binius_core/polynomial/
arith_circuit.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{fmt::Debug, mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithCircuit, ArithCircuitStep, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
8use stackalloc::{
9	helpers::{slice_assume_init, slice_assume_init_mut},
10	stackalloc_uninit,
11};
12
13/// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence.
14fn convert_circuit_steps<F: Field>(
15	expr: &ArithCircuit<F>,
16) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
17	/// This struct is used to the steps in the original circuit and the converted one and back.
18	struct StepsMapping {
19		original_to_converted: Vec<Option<usize>>,
20		converted_to_original: Vec<Option<usize>>,
21	}
22
23	impl StepsMapping {
24		fn new(original_size: usize) -> Self {
25			Self {
26				original_to_converted: vec![None; original_size],
27				// The size of this vector isn't known at the start, but that's a reasonable guess
28				converted_to_original: vec![None; original_size],
29			}
30		}
31
32		fn register(&mut self, original_step: usize, converted_step: usize) {
33			self.original_to_converted[original_step] = Some(converted_step);
34
35			if converted_step >= self.converted_to_original.len() {
36				self.converted_to_original.resize(converted_step + 1, None);
37			}
38			self.converted_to_original[converted_step] = Some(original_step);
39		}
40
41		fn clear_step(&mut self, converted_step: usize) {
42			if self.converted_to_original.len() <= converted_step {
43				return;
44			}
45
46			if let Some(original_step) = self.converted_to_original[converted_step].take() {
47				self.original_to_converted[original_step] = None;
48			}
49		}
50
51		fn get_converted_step(&self, original_step: usize) -> Option<usize> {
52			self.original_to_converted[original_step]
53		}
54	}
55
56	fn convert_step<F: Field>(
57		original_step: usize,
58		original_steps: &[ArithCircuitStep<F>],
59		result: &mut Vec<CircuitStep<F>>,
60		node_to_step: &mut StepsMapping,
61	) -> CircuitStepArgument<F> {
62		if let Some(converted_step) = node_to_step.get_converted_step(original_step) {
63			return CircuitStepArgument::Expr(CircuitNode::Slot(converted_step));
64		}
65
66		match &original_steps[original_step] {
67			ArithCircuitStep::Const(constant) => CircuitStepArgument::Const(*constant),
68			ArithCircuitStep::Var(var) => CircuitStepArgument::Expr(CircuitNode::Var(*var)),
69			ArithCircuitStep::Add(left, right) => {
70				let left = convert_step(*left, original_steps, result, node_to_step);
71
72				if let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left {
73					if let ArithCircuitStep::Mul(mleft, mright) = &original_steps[*right] {
74						// Only handling e1 + (e2 * e3), not (e1 * e2) + e3, as latter was not observed in practice
75						// (the former can be enforced by rewriting expression)
76						let mleft = convert_step(*mleft, original_steps, result, node_to_step);
77						let mright = convert_step(*mright, original_steps, result, node_to_step);
78
79						// Since we'we changed the value of `left` to a new value, we need to clear the cache for it
80						node_to_step.clear_step(left);
81						node_to_step.register(original_step, left);
82						result.push(CircuitStep::AddMul(left, mleft, mright));
83
84						return CircuitStepArgument::Expr(CircuitNode::Slot(left));
85					}
86				}
87
88				let right = convert_step(*right, original_steps, result, node_to_step);
89
90				node_to_step.register(original_step, result.len());
91				result.push(CircuitStep::Add(left, right));
92				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
93			}
94			ArithCircuitStep::Mul(left, right) => {
95				let left = convert_step(*left, original_steps, result, node_to_step);
96				let right = convert_step(*right, original_steps, result, node_to_step);
97
98				node_to_step.register(original_step, result.len());
99				result.push(CircuitStep::Mul(left, right));
100				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
101			}
102			ArithCircuitStep::Pow(base, exp) => {
103				let mut acc = convert_step(*base, original_steps, result, node_to_step);
104				let base_expr = acc;
105				let highest_bit = exp.ilog2();
106
107				for i in (0..highest_bit).rev() {
108					if i == 0 {
109						node_to_step.register(original_step, result.len());
110					}
111					result.push(CircuitStep::Square(acc));
112					acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
113
114					if (exp >> i) & 1 != 0 {
115						result.push(CircuitStep::Mul(acc, base_expr));
116						acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
117					}
118				}
119
120				acc
121			}
122		}
123	}
124
125	let mut steps = Vec::new();
126	let mut steps_mapping = StepsMapping::new(expr.steps().len());
127	let ret = convert_step(expr.steps().len() - 1, expr.steps(), &mut steps, &mut steps_mapping);
128	(steps, ret)
129}
130
/// A non-constant operand of a circuit calculation step: either an input variable or the value
/// stored in the slot of an earlier step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitNode {
	/// Input variable, identified by its index into the query (or into the rows of a batch query)
	Var(usize),
	/// Evaluation at one of the previous steps, identified by its slot index
	Slot(usize),
}
139
140impl CircuitNode {
141	/// Return either the input column or the slice with evaluations at one of the previous steps.
142	/// This method is used for batch evaluation.
143	fn get_sparse_chunk<'a, P: PackedField>(
144		&'a self,
145		inputs: &'a RowsBatchRef<'a, P>,
146		evals: &'a [P],
147		row_len: usize,
148	) -> &'a [P] {
149		match self {
150			Self::Var(index) => inputs.rows()[*index],
151			Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
152		}
153	}
154}
155
/// An operand of a circuit calculation step: either a node whose value is looked up at
/// evaluation time, or a constant field element that is broadcast to a packed value when used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitStepArgument<F> {
	/// An input variable or the slot of an earlier step
	Expr(CircuitNode),
	/// A constant scalar
	Const(F),
}
161
/// Describes computation symbolically. This is used internally by ArithCircuitPoly.
///
/// Steps are evaluated in order, and the step at index `i` may only read slots with an index less
/// than `i`. This ensures the step list represents a directed acyclic graph that can be computed
/// in sequence.
#[derive(Debug)]
enum CircuitStep<F: Field> {
	/// Sum of the two operands, stored in this step's slot
	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// Product of the two operands, stored in this step's slot
	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// Square of the operand, stored in this step's slot
	Square(CircuitStepArgument<F>),
	/// Adds the product of the two operands in place to the earlier slot given by the first field;
	/// later references to the accumulated value use that slot
	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
}
173
/// Describes polynomial evaluations using a directed acyclic graph of expressions.
///
/// This is meant as an alternative to a hard-coded CompositionPoly.
///
/// The advantage over a hard-coded CompositionPoly is that this can be constructed and manipulated
/// dynamically at runtime, and the objects representing different polynomials can be stored in a
/// homogeneous collection.
#[derive(Debug, Clone)]
pub struct ArithCircuitPoly<F: Field> {
	/// The optimized arithmetic expression; together with `n_vars` it is what equality and
	/// serialization are based on
	expr: ArithCircuit<F>,
	/// Evaluation steps lowered from `expr`; stored behind an `Arc` so clones share them
	steps: Arc<[CircuitStep<F>]>,
	/// The "top level expression", which depends on circuit expression evaluations
	retval: CircuitStepArgument<F>,
	/// Cached `expr.degree()`
	degree: usize,
	/// Number of variables of the polynomial; may exceed the number of variables read by `expr`
	n_vars: usize,
	/// Cached `expr.binary_tower_level()`
	tower_level: usize,
}
190
191impl<F: Field> PartialEq for ArithCircuitPoly<F> {
192	fn eq(&self, other: &Self) -> bool {
193		self.n_vars == other.n_vars && self.expr == other.expr
194	}
195}
196
197impl<F: Field> Eq for ArithCircuitPoly<F> {}
198
199impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
200	fn serialize(
201		&self,
202		mut write_buf: impl bytes::BufMut,
203		mode: SerializationMode,
204	) -> Result<(), SerializationError> {
205		(&self.expr, self.n_vars).serialize(&mut write_buf, mode)
206	}
207}
208
209impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
210	fn deserialize(
211		read_buf: impl bytes::Buf,
212		mode: SerializationMode,
213	) -> Result<Self, SerializationError>
214	where
215		Self: Sized,
216	{
217		let (expr, n_vars) = <(ArithCircuit<F>, usize)>::deserialize(read_buf, mode)?;
218		Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
219			name: "ArithCircuitPoly",
220		})
221	}
222}
223
224impl<F: TowerField> ArithCircuitPoly<F> {
225	pub fn new(mut expr: ArithCircuit<F>) -> Self {
226		expr.optimize_in_place();
227
228		let degree = expr.degree();
229		let n_vars = expr.n_vars();
230		let tower_level = expr.binary_tower_level();
231		let (exprs, retval) = convert_circuit_steps(&expr);
232
233		Self {
234			expr,
235			steps: exprs.into(),
236			retval,
237			degree,
238			n_vars,
239			tower_level,
240		}
241	}
242	/// Constructs an [`ArithCircuitPoly`] with the given number of variables.
243	///
244	/// The number of variables may be greater than the number of variables actually read in the
245	/// arithmetic expression.
246	pub fn with_n_vars(n_vars: usize, mut expr: ArithCircuit<F>) -> Result<Self, Error> {
247		expr.optimize_in_place();
248
249		let degree = expr.degree();
250		let tower_level = expr.binary_tower_level();
251		if n_vars < expr.n_vars() {
252			return Err(Error::IncorrectNumberOfVariables {
253				expected: expr.n_vars(),
254				actual: n_vars,
255			});
256		}
257		let (steps, retval) = convert_circuit_steps(&expr);
258
259		Ok(Self {
260			expr,
261			steps: steps.into(),
262			retval,
263			n_vars,
264			degree,
265			tower_level,
266		})
267	}
268}
269
270impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
271	for ArithCircuitPoly<F>
272{
273	fn degree(&self) -> usize {
274		self.degree
275	}
276
277	fn n_vars(&self) -> usize {
278		self.n_vars
279	}
280
281	fn binary_tower_level(&self) -> usize {
282		self.tower_level
283	}
284
285	fn expression(&self) -> ArithCircuit<P::Scalar> {
286		self.expr.convert_field()
287	}
288
289	fn evaluate(&self, query: &[P]) -> Result<P, Error> {
290		if query.len() != self.n_vars {
291			return Err(Error::IncorrectQuerySize {
292				expected: self.n_vars,
293			});
294		}
295
296		fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
297			// Safety: The index is guaranteed to be within bounds because
298			// we initialize at least `self.steps.len()` using `stackalloc`.
299			unsafe {
300				target.get_unchecked_mut(0).write(value);
301			}
302		}
303
304		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
305		stackalloc_uninit::<P, _, _>(self.steps.len().max(1), |evals| {
306			let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
307				// Safety: The index is guaranteed to be within bounds by the construction of the circuit
308				CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
309					*query.get_unchecked(index)
310				},
311				// Safety: The index is guaranteed to be within bounds by the circuit evaluation order
312				CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
313					*evals.get_unchecked(slot)
314				},
315				CircuitStepArgument::Const(value) => P::broadcast(value.into()),
316			};
317
318			for (i, expr) in self.steps.iter().enumerate() {
319				// Safety: previous evaluations are initialized by the previous loop iterations (if dereferenced)
320				let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
321				let before = unsafe { slice_assume_init_mut(before) };
322				match expr {
323					CircuitStep::Add(x, y) => write_result(
324						after,
325						get_argument_value(*x, before) + get_argument_value(*y, before),
326					),
327					CircuitStep::AddMul(target_slot, x, y) => {
328						let intermediate =
329							get_argument_value(*x, before) * get_argument_value(*y, before);
330						// Safety: we know by evaluation order and construction of steps that `target.slot` is initialized
331						let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
332						*target_slot += intermediate;
333					}
334					CircuitStep::Mul(x, y) => write_result(
335						after,
336						get_argument_value(*x, before) * get_argument_value(*y, before),
337					),
338					CircuitStep::Square(x) => {
339						write_result(after, get_argument_value(*x, before).square())
340					}
341				};
342			}
343
344			// Some slots in `evals` might be empty, but we're guaranted that
345			// if `self.retval` points to a slot, that this slot is initialized.
346			unsafe {
347				let evals = slice_assume_init(evals);
348				Ok(get_argument_value(self.retval, evals))
349			}
350		})
351	}
352
353	fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
354		let row_len = evals.len();
355		if batch_query.row_len() != row_len {
356			bail!(Error::BatchEvaluateSizeMismatch {
357				expected: row_len,
358				actual: batch_query.row_len(),
359			});
360		}
361
362		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
363		stackalloc_uninit::<P, (), _>((self.steps.len() * row_len).max(1), |sparse_evals| {
364			for (i, expr) in self.steps.iter().enumerate() {
365				let (before, current) = sparse_evals.split_at_mut(i * row_len);
366
367				// Safety: `before` is guaranteed to be initialized by the previous loop iterations (if dereferenced).
368				let before = unsafe { slice_assume_init_mut(before) };
369				let current = &mut current[..row_len];
370
371				match expr {
372					CircuitStep::Add(left, right) => {
373						apply_binary_op(
374							left,
375							right,
376							batch_query,
377							before,
378							current,
379							|left, right, out| {
380								out.write(left + right);
381							},
382						);
383					}
384					CircuitStep::Mul(left, right) => {
385						apply_binary_op(
386							left,
387							right,
388							batch_query,
389							before,
390							current,
391							|left, right, out| {
392								out.write(left * right);
393							},
394						);
395					}
396					CircuitStep::Square(arg) => {
397						match arg {
398							CircuitStepArgument::Expr(node) => {
399								let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
400								for j in 0..row_len {
401									// Safety: `current` and `id_chunk` have length equal to `row_len`
402									unsafe {
403										current
404											.get_unchecked_mut(j)
405											.write(id_chunk.get_unchecked(j).square());
406									}
407								}
408							}
409							CircuitStepArgument::Const(value) => {
410								let value: P = P::broadcast((*value).into());
411								let result = value.square();
412								for j in 0..row_len {
413									// Safety: `current` has length equal to `row_len`
414									unsafe {
415										current.get_unchecked_mut(j).write(result);
416									}
417								}
418							}
419						}
420					}
421					CircuitStep::AddMul(target, left, right) => {
422						let target = &before[row_len * target..(target + 1) * row_len];
423						// Safety: by construction of steps and evaluation order we know
424						// that `target` is not borrowed elsewhere.
425						let target: &mut [MaybeUninit<P>] = unsafe {
426							std::slice::from_raw_parts_mut(
427								target.as_ptr() as *mut MaybeUninit<P>,
428								target.len(),
429							)
430						};
431						apply_binary_op(
432							left,
433							right,
434							batch_query,
435							before,
436							target,
437							// Safety: by construction of steps and evaluation order we know
438							// that `target`/`out` is initialized.
439							|left, right, out| unsafe {
440								let out = out.assume_init_mut();
441								*out += left * right;
442							},
443						);
444					}
445				}
446			}
447
448			match self.retval {
449				CircuitStepArgument::Expr(node) => {
450					// Safety: `sparse_evals` is fully initialized by the previous loop iterations
451					let sparse_evals = unsafe { slice_assume_init(sparse_evals) };
452					evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
453				}
454				CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
455			}
456		});
457
458		Ok(())
459	}
460}
461
462/// Apply a binary operation to two arguments and store the result in `current_evals`.
463/// `op` must be a function that takes two arguments and initialized the result with the third argument.
464fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
465	left: &CircuitStepArgument<F>,
466	right: &CircuitStepArgument<F>,
467	batch_query: &RowsBatchRef<P>,
468	evals_before: &[P],
469	current_evals: &mut [MaybeUninit<P>],
470	op: impl Fn(P, P, &mut MaybeUninit<P>),
471) {
472	let row_len = current_evals.len();
473
474	match (left, right) {
475		(CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
476			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
477			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
478			for j in 0..row_len {
479				// Safety: `current`, `left` and `right` have length equal to `row_len`
480				unsafe {
481					op(
482						*left.get_unchecked(j),
483						*right.get_unchecked(j),
484						current_evals.get_unchecked_mut(j),
485					)
486				}
487			}
488		}
489		(CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
490			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
491			let right = P::broadcast((*right).into());
492			for j in 0..row_len {
493				// Safety: `current` and `left` have length equal to `row_len`
494				unsafe {
495					op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
496				}
497			}
498		}
499		(CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
500			let left = P::broadcast((*left).into());
501			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
502			for j in 0..row_len {
503				// Safety: `current` and `right` have length equal to `row_len`
504				unsafe {
505					op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
506				}
507			}
508		}
509		(CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
510			let left = P::broadcast((*left).into());
511			let right = P::broadcast((*right).into());
512			let mut result = MaybeUninit::uninit();
513			op(left, right, &mut result);
514			for j in 0..row_len {
515				// Safety:
516				// - `current` has length equal to `row_len`
517				// - `result` is initialized by `op`
518				unsafe {
519					current_evals
520						.get_unchecked_mut(j)
521						.write(result.assume_init());
522				}
523			}
524		}
525	}
526}
527
528#[cfg(test)]
529mod tests {
530	use binius_field::{
531		BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField,
532	};
533	use binius_math::{ArithExpr, CompositionPoly, RowsBatch};
534	use binius_utils::felts;
535
536	use super::*;
537
538	#[test]
539	fn test_constant() {
540		type F = BinaryField8b;
541		type P = PackedBinaryField8x16b;
542
543		let expr = ArithExpr::Const(F::new(123));
544		let circuit = ArithCircuitPoly::<F>::new(expr.into());
545
546		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
547		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
548		assert_eq!(typed_circuit.degree(), 0);
549		assert_eq!(typed_circuit.n_vars(), 0);
550
551		assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));
552
553		let mut evals = [P::default()];
554		typed_circuit
555			.batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
556			.unwrap();
557		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
558	}
559
560	#[test]
561	fn test_identity() {
562		type F = BinaryField8b;
563		type P = PackedBinaryField8x16b;
564
565		// x0
566		let expr = ArithExpr::Var(0);
567		let circuit = ArithCircuitPoly::<F>::new(expr.into());
568
569		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
570		assert_eq!(typed_circuit.binary_tower_level(), 0);
571		assert_eq!(typed_circuit.degree(), 1);
572		assert_eq!(typed_circuit.n_vars(), 1);
573
574		assert_eq!(
575			typed_circuit
576				.evaluate(&[P::broadcast(F::new(123).into())])
577				.unwrap(),
578			P::broadcast(F::new(123).into())
579		);
580
581		let mut evals = [P::default()];
582		let batch_query = [[P::broadcast(F::new(123).into())]; 1];
583		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
584		typed_circuit
585			.batch_evaluate(&batch_query.get_ref(), &mut evals)
586			.unwrap();
587		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
588	}
589
590	#[test]
591	fn test_add() {
592		type F = BinaryField8b;
593		type P = PackedBinaryField8x16b;
594
595		// 123 + x0
596		let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
597		let circuit = ArithCircuitPoly::<F>::new(expr.into());
598
599		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
600		assert_eq!(typed_circuit.binary_tower_level(), 3);
601		assert_eq!(typed_circuit.degree(), 1);
602		assert_eq!(typed_circuit.n_vars(), 1);
603
604		assert_eq!(
605			CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
606			P::broadcast(F::new(123).into())
607		);
608	}
609
610	#[test]
611	fn test_mul() {
612		type F = BinaryField8b;
613		type P = PackedBinaryField8x16b;
614
615		// 123 * x0
616		let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
617		let circuit = ArithCircuitPoly::<F>::new(expr.into());
618
619		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
620		assert_eq!(typed_circuit.binary_tower_level(), 3);
621		assert_eq!(typed_circuit.degree(), 1);
622		assert_eq!(typed_circuit.n_vars(), 1);
623
624		assert_eq!(
625			CompositionPoly::evaluate(
626				&circuit,
627				&[P::from_scalars(
628					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
629				)]
630			)
631			.unwrap(),
632			P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
633		);
634	}
635
636	#[test]
637	fn test_pow() {
638		type F = BinaryField8b;
639		type P = PackedBinaryField8x16b;
640
641		// x0^13
642		let expr = ArithExpr::Var(0).pow(13);
643		let circuit = ArithCircuitPoly::<F>::new(expr.into());
644
645		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
646		assert_eq!(typed_circuit.binary_tower_level(), 0);
647		assert_eq!(typed_circuit.degree(), 13);
648		assert_eq!(typed_circuit.n_vars(), 1);
649
650		assert_eq!(
651			CompositionPoly::evaluate(
652				&circuit,
653				&[P::from_scalars(
654					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
655				)]
656			)
657			.unwrap(),
658			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
659		);
660	}
661
662	#[test]
663	fn test_mixed() {
664		type F = BinaryField8b;
665		type P = PackedBinaryField8x16b;
666
667		// x0^2 * (x1 + 123)
668		let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
669		let circuit = ArithCircuitPoly::<F>::new(expr.into());
670
671		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
672		assert_eq!(typed_circuit.binary_tower_level(), 3);
673		assert_eq!(typed_circuit.degree(), 3);
674		assert_eq!(typed_circuit.n_vars(), 2);
675
676		// test evaluate
677		assert_eq!(
678			CompositionPoly::evaluate(
679				&circuit,
680				&[
681					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
682					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
683				]
684			)
685			.unwrap(),
686			P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
687		);
688
689		// test batch evaluate
690		let query1 = &[
691			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
692			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
693		];
694		let query2 = &[
695			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
696			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
697		];
698		let query3 = &[
699			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
700			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
701		];
702		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
703		let expected2 =
704			P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
705		let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));
706
707		let mut batch_result = vec![P::zero(); 3];
708		let batch_query = &[
709			&[query1[0], query2[0], query3[0]],
710			&[query1[1], query2[1], query3[1]],
711		];
712		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
713
714		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
715			.unwrap();
716		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
717	}
718
719	#[test]
720	fn test_const_fold() {
721		type F = BinaryField8b;
722		type P = PackedBinaryField8x16b;
723
724		// x0 * ((122 * 123) + (124 + 125)) + x1
725		let expr = ArithExpr::Var(0)
726			* ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
727				+ (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
728			+ ArithExpr::Var(1);
729		let circuit = ArithCircuitPoly::<F>::new(expr.into());
730		assert_eq!(circuit.steps.len(), 2);
731
732		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
733		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
734		assert_eq!(typed_circuit.degree(), 1);
735		assert_eq!(typed_circuit.n_vars(), 2);
736
737		// test evaluate
738		assert_eq!(
739			CompositionPoly::evaluate(
740				&circuit,
741				&[
742					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
743					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
744				]
745			)
746			.unwrap(),
747			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
748		);
749
750		// test batch evaluate
751		let query1 = &[
752			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
753			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
754		];
755		let query2 = &[
756			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
757			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
758		];
759		let query3 = &[
760			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
761			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
762		];
763		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
764		let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
765		let expected3 =
766			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));
767
768		let mut batch_result = vec![P::zero(); 3];
769		let batch_query = &[
770			&[query1[0], query2[0], query3[0]],
771			&[query1[1], query2[1], query3[1]],
772		];
773		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
774		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
775			.unwrap();
776		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
777	}
778
779	#[test]
780	fn test_pow_const_fold() {
781		type F = BinaryField8b;
782		type P = PackedBinaryField8x16b;
783
784		// x0 + 2^5
785		let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
786		let circuit = ArithCircuitPoly::<F>::new(expr.into());
787		assert_eq!(circuit.steps.len(), 1);
788
789		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
790		assert_eq!(typed_circuit.binary_tower_level(), 1);
791		assert_eq!(typed_circuit.degree(), 1);
792		assert_eq!(typed_circuit.n_vars(), 1);
793
794		assert_eq!(
795			CompositionPoly::evaluate(
796				&circuit,
797				&[P::from_scalars(
798					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
799				)]
800			)
801			.unwrap(),
802			P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
803		);
804	}
805
806	#[test]
807	fn test_pow_nested() {
808		type F = BinaryField8b;
809		type P = PackedBinaryField8x16b;
810
811		// ((x0^2)^3)^4
812		let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
813		let circuit = ArithCircuitPoly::<F>::new(expr.into());
814		assert_eq!(circuit.steps.len(), 5);
815
816		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
817		assert_eq!(typed_circuit.binary_tower_level(), 0);
818		assert_eq!(typed_circuit.degree(), 24);
819		assert_eq!(typed_circuit.n_vars(), 1);
820
821		assert_eq!(
822			CompositionPoly::evaluate(
823				&circuit,
824				&[P::from_scalars(
825					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
826				)]
827			)
828			.unwrap(),
829			P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
830		);
831	}
832
833	#[test]
834	fn test_circuit_steps_for_expr_constant() {
835		type F = BinaryField8b;
836
837		let expr = ArithExpr::Const(F::new(5));
838		let (steps, retval) = convert_circuit_steps(&expr.into());
839
840		assert!(steps.is_empty(), "No steps should be generated for a constant");
841		assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
842	}
843
844	#[test]
845	fn test_circuit_steps_for_expr_variable() {
846		type F = BinaryField8b;
847
848		let expr = ArithExpr::<F>::Var(18);
849		let (steps, retval) = convert_circuit_steps(&expr.into());
850
851		assert!(steps.is_empty(), "No steps should be generated for a variable");
852		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
853	}
854
855	#[test]
856	fn test_circuit_steps_for_expr_addition() {
857		type F = BinaryField8b;
858
859		let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
860		let (steps, retval) = convert_circuit_steps(&expr.into());
861
862		assert_eq!(steps.len(), 1, "One addition step should be generated");
863		assert!(matches!(
864			steps[0],
865			CircuitStep::Add(
866				CircuitStepArgument::Expr(CircuitNode::Var(14)),
867				CircuitStepArgument::Expr(CircuitNode::Var(56))
868			)
869		));
870		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
871	}
872
873	#[test]
874	fn test_circuit_steps_for_expr_multiplication() {
875		type F = BinaryField8b;
876
877		let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
878		let (steps, retval) = convert_circuit_steps(&expr.into());
879
880		assert_eq!(steps.len(), 1, "One multiplication step should be generated");
881		assert!(matches!(
882			steps[0],
883			CircuitStep::Mul(
884				CircuitStepArgument::Expr(CircuitNode::Var(36)),
885				CircuitStepArgument::Expr(CircuitNode::Var(26))
886			)
887		));
888		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
889	}
890
891	#[test]
892	fn test_circuit_steps_for_expr_pow_1() {
893		type F = BinaryField8b;
894
895		let expr = ArithExpr::<F>::Var(12).pow(1);
896		let (steps, retval) = convert_circuit_steps(&expr.into());
897
898		// No steps should be generated for x^1
899		assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");
900
901		// The return value should just be the variable itself
902		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
903	}
904
905	#[test]
906	fn test_circuit_steps_for_expr_pow_2() {
907		type F = BinaryField8b;
908
909		let expr = ArithExpr::<F>::Var(10).pow(2);
910		let (steps, retval) = convert_circuit_steps(&expr.into());
911
912		assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
913		assert!(matches!(
914			steps[0],
915			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
916		));
917		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
918	}
919
920	#[test]
921	fn test_circuit_steps_for_expr_pow_3() {
922		type F = BinaryField8b;
923
924		let expr = ArithExpr::<F>::Var(5).pow(3);
925		let (steps, retval) = convert_circuit_steps(&expr.into());
926
927		assert_eq!(
928			steps.len(),
929			2,
930			"Pow(3) should generate one squaring and one multiplication step"
931		);
932		assert!(matches!(
933			steps[0],
934			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
935		));
936		assert!(matches!(
937			steps[1],
938			CircuitStep::Mul(
939				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
940				CircuitStepArgument::Expr(CircuitNode::Var(5))
941			)
942		));
943		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
944	}
945
946	#[test]
947	fn test_circuit_steps_for_expr_pow_4() {
948		type F = BinaryField8b;
949
950		let expr = ArithExpr::<F>::Var(7).pow(4);
951		let (steps, retval) = convert_circuit_steps(&expr.into());
952
953		assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
954		assert!(matches!(
955			steps[0],
956			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
957		));
958
959		assert!(matches!(
960			steps[1],
961			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
962		));
963
964		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
965	}
966
967	#[test]
968	fn test_circuit_steps_for_expr_pow_5() {
969		type F = BinaryField8b;
970
971		let expr = ArithExpr::<F>::Var(3).pow(5);
972		let (steps, retval) = convert_circuit_steps(&expr.into());
973
974		assert_eq!(
975			steps.len(),
976			3,
977			"Pow(5) should generate two squaring steps and one multiplication"
978		);
979		assert!(matches!(
980			steps[0],
981			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
982		));
983		assert!(matches!(
984			steps[1],
985			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
986		));
987		assert!(matches!(
988			steps[2],
989			CircuitStep::Mul(
990				CircuitStepArgument::Expr(CircuitNode::Slot(1)),
991				CircuitStepArgument::Expr(CircuitNode::Var(3))
992			)
993		));
994
995		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
996	}
997
998	#[test]
999	fn test_circuit_steps_for_expr_pow_8() {
1000		type F = BinaryField8b;
1001
1002		let expr = ArithExpr::<F>::Var(4).pow(8);
1003		let (steps, retval) = convert_circuit_steps(&expr.into());
1004
1005		assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
1006		assert!(matches!(
1007			steps[0],
1008			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
1009		));
1010		assert!(matches!(
1011			steps[1],
1012			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1013		));
1014		assert!(matches!(
1015			steps[2],
1016			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1017		));
1018
1019		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1020	}
1021
1022	#[test]
1023	fn test_circuit_steps_for_expr_pow_9() {
1024		type F = BinaryField8b;
1025
1026		let expr = ArithExpr::<F>::Var(8).pow(9);
1027		let (steps, retval) = convert_circuit_steps(&expr.into());
1028
1029		assert_eq!(
1030			steps.len(),
1031			4,
1032			"Pow(9) should generate three squaring steps and one multiplication"
1033		);
1034		assert!(matches!(
1035			steps[0],
1036			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
1037		));
1038		assert!(matches!(
1039			steps[1],
1040			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1041		));
1042		assert!(matches!(
1043			steps[2],
1044			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1045		));
1046		assert!(matches!(
1047			steps[3],
1048			CircuitStep::Mul(
1049				CircuitStepArgument::Expr(CircuitNode::Slot(2)),
1050				CircuitStepArgument::Expr(CircuitNode::Var(8))
1051			)
1052		));
1053
1054		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1055	}
1056
1057	#[test]
1058	fn test_circuit_steps_for_expr_pow_12() {
1059		type F = BinaryField8b;
1060		let expr = ArithExpr::<F>::Var(6).pow(12);
1061		let (steps, retval) = convert_circuit_steps(&expr.into());
1062
1063		assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");
1064
1065		assert!(matches!(
1066			steps[0],
1067			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
1068		));
1069		assert!(matches!(
1070			steps[1],
1071			CircuitStep::Mul(
1072				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1073				CircuitStepArgument::Expr(CircuitNode::Var(6))
1074			)
1075		));
1076		assert!(matches!(
1077			steps[2],
1078			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1079		));
1080		assert!(matches!(
1081			steps[3],
1082			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1083		));
1084
1085		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1086	}
1087
1088	#[test]
1089	fn test_circuit_steps_for_expr_pow_13() {
1090		type F = BinaryField8b;
1091		let expr = ArithExpr::<F>::Var(7).pow(13);
1092		let (steps, retval) = convert_circuit_steps(&expr.into());
1093
1094		assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
1095		assert!(matches!(
1096			steps[0],
1097			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1098		));
1099		assert!(matches!(
1100			steps[1],
1101			CircuitStep::Mul(
1102				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1103				CircuitStepArgument::Expr(CircuitNode::Var(7))
1104			)
1105		));
1106		assert!(matches!(
1107			steps[2],
1108			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1109		));
1110		assert!(matches!(
1111			steps[3],
1112			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1113		));
1114		assert!(matches!(
1115			steps[4],
1116			CircuitStep::Mul(
1117				CircuitStepArgument::Expr(CircuitNode::Slot(3)),
1118				CircuitStepArgument::Expr(CircuitNode::Var(7))
1119			)
1120		));
1121		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
1122	}
1123
1124	#[test]
1125	fn test_circuit_steps_for_expr_complex() {
1126		type F = BinaryField8b;
1127
1128		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1129			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
1130			- ArithExpr::Var(3);
1131
1132		let (steps, retval) = convert_circuit_steps(&expr.into());
1133
1134		assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");
1135
1136		assert!(
1137			matches!(
1138				steps[0],
1139				CircuitStep::Mul(
1140					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1141					CircuitStepArgument::Expr(CircuitNode::Var(1))
1142				)
1143			),
1144			"First step should be multiplication x0 * x1"
1145		);
1146
1147		assert!(
1148			matches!(
1149				steps[1],
1150				CircuitStep::Add(
1151					CircuitStepArgument::Const(F::ONE),
1152					CircuitStepArgument::Expr(CircuitNode::Var(0))
1153				)
1154			),
1155			"Second step should be (1 - x0)"
1156		);
1157
1158		assert!(
1159			matches!(
1160				steps[2],
1161				CircuitStep::AddMul(
1162					0,
1163					CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1164					CircuitStepArgument::Expr(CircuitNode::Var(2))
1165				)
1166			),
1167			"Third step should be (1 - x0) * x2"
1168		);
1169
1170		assert!(
1171			matches!(
1172				steps[3],
1173				CircuitStep::Add(
1174					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1175					CircuitStepArgument::Expr(CircuitNode::Var(3))
1176				)
1177			),
1178			"Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
1179		);
1180
1181		assert!(
1182			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
1183			"Final result should be stored in Slot(3)"
1184		);
1185	}
1186
1187	#[test]
1188	fn check_deduplication_in_steps() {
1189		type F = BinaryField8b;
1190
1191		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1192			+ (ArithExpr::<F>::Var(0) * ArithExpr::Var(1)) * ArithExpr::Var(2)
1193			- ArithExpr::Var(3);
1194		let expr = ArithCircuit::from(&expr);
1195		let expr = expr.optimize();
1196
1197		let (steps, retval) = convert_circuit_steps(&expr);
1198
1199		assert_eq!(steps.len(), 3, "Expression should generate 3 computation steps");
1200
1201		assert!(
1202			matches!(
1203				steps[0],
1204				CircuitStep::Mul(
1205					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1206					CircuitStepArgument::Expr(CircuitNode::Var(1))
1207				)
1208			),
1209			"First step should be multiplication x0 * x1"
1210		);
1211
1212		assert!(
1213			matches!(
1214				steps[1],
1215				CircuitStep::AddMul(
1216					0,
1217					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1218					CircuitStepArgument::Expr(CircuitNode::Var(2))
1219				)
1220			),
1221			"Second step should be (x0 * x1) * x2"
1222		);
1223
1224		assert!(
1225			matches!(
1226				steps[2],
1227				CircuitStep::Add(
1228					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1229					CircuitStepArgument::Expr(CircuitNode::Var(3))
1230				)
1231			),
1232			"Third step should be x0 * x1 + (x0 * x1) * x2 + x3"
1233		);
1234
1235		assert!(
1236			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))),
1237			"Final result should be stored in Slot(2)"
1238		);
1239	}
1240}