binius_core/polynomial/
arith_circuit.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{fmt::Debug, mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithExpr, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
8use stackalloc::{
9	helpers::{slice_assume_init, slice_assume_init_mut},
10	stackalloc_uninit,
11};
12
13/// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence.
14fn circuit_steps_for_expr<F: Field>(
15	expr: &ArithExpr<F>,
16) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
17	let mut steps = Vec::new();
18
19	fn to_circuit_inner<F: Field>(
20		expr: &ArithExpr<F>,
21		result: &mut Vec<CircuitStep<F>>,
22	) -> CircuitStepArgument<F> {
23		match expr {
24			ArithExpr::Const(value) => CircuitStepArgument::Const(*value),
25			ArithExpr::Var(index) => CircuitStepArgument::Expr(CircuitNode::Var(*index)),
26			ArithExpr::Add(left, right) => match &**right {
27				ArithExpr::Mul(mleft, mright) if left.is_composite() => {
28					// Only handling e1 + (e2 * e3), not (e1 * e2) + e3, as latter was not observed in practice
29					// (the former can be enforced by rewriting expression).
30					let left = to_circuit_inner(left, result);
31					let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left else {
32						unreachable!("guaranteed by `is_composite` check above")
33					};
34					let mleft = to_circuit_inner(mleft, result);
35					let mright = to_circuit_inner(mright, result);
36					result.push(CircuitStep::AddMul(left, mleft, mright));
37					CircuitStepArgument::Expr(CircuitNode::Slot(left))
38				}
39				_ => {
40					let left = to_circuit_inner(left, result);
41					let right = to_circuit_inner(right, result);
42					result.push(CircuitStep::Add(left, right));
43					CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
44				}
45			},
46			ArithExpr::Mul(left, right) => {
47				let left = to_circuit_inner(left, result);
48				let right = to_circuit_inner(right, result);
49				result.push(CircuitStep::Mul(left, right));
50				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
51			}
52			ArithExpr::Pow(base, exp) => {
53				let mut acc = to_circuit_inner(base, result);
54				let base_expr = acc;
55				let highest_bit = exp.ilog2();
56
57				for i in (0..highest_bit).rev() {
58					result.push(CircuitStep::Square(acc));
59					acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
60
61					if (exp >> i) & 1 != 0 {
62						result.push(CircuitStep::Mul(acc, base_expr));
63						acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
64					}
65				}
66
67				acc
68			}
69		}
70	}
71
72	let ret = to_circuit_inner(&expr.optimize(), &mut steps);
73	(steps, ret)
74}
75
76/// Input of the circuit calculation step that is not a constant: it refers either to one of
/// the polynomial's input variables or to the value stored by an earlier step.
77#[derive(Debug, Clone, Copy, PartialEq, Eq)]
78enum CircuitNode {
79	/// Input variable, identified by its index in the query
80	Var(usize),
81	/// Evaluation at one of the previous steps, identified by the slot index
82	Slot(usize),
83}
84
85impl CircuitNode {
86	/// Return either the input column or the slice with evaluations at one of the previous steps.
87	/// This method is used for batch evaluation.
88	fn get_sparse_chunk<'a, P: PackedField>(
89		&'a self,
90		inputs: &'a RowsBatchRef<'a, P>,
91		evals: &'a [P],
92		row_len: usize,
93	) -> &'a [P] {
94		match self {
95			Self::Var(index) => inputs.rows()[*index],
96			Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
97		}
98	}
99}
100
/// Operand of a circuit step, or the final result of a whole circuit.
101#[derive(Debug, Clone, Copy, PartialEq, Eq)]
102enum CircuitStepArgument<F> {
	/// Value read from an input variable or from the slot of an earlier step
103	Expr(CircuitNode),
	/// Constant field element
104	Const(F),
105}
106
107/// Describes computation symbolically. This is used internally by ArithCircuitPoly.
108///
109/// Slots referenced by a step have to be less than the index of the step itself within the ArithCircuitPoly,
110/// to ensure it represents a directed acyclic graph that can be computed in sequence.
111#[derive(Debug)]
112enum CircuitStep<F: Field> {
	/// Sum of the two arguments
113	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// Product of the two arguments
114	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// Square of the argument
115	Square(CircuitStepArgument<F>),
	/// Adds the product of the two arguments into the already computed slot with the given
	/// index; the result lives in that slot rather than in the slot of this step.
116	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
117}
118
119/// Describes polynomial evaluations using a directed acyclic graph of expressions.
120///
121/// This is meant as an alternative to a hard-coded CompositionPoly.
122///
123/// The advantage over a hard coded CompositionPoly is that this can be constructed and manipulated dynamically at runtime
124/// and the object representing different polynomials can be stored in a homogeneous collection.
125#[derive(Debug, Clone)]
126pub struct ArithCircuitPoly<F: Field> {
	/// The expression this polynomial was built from
127	expr: ArithExpr<F>,
	/// The steps the expression was lowered to, in evaluation order
128	steps: Arc<[CircuitStep<F>]>,
129	/// The "top level expression", which depends on circuit expression evaluations
130	retval: CircuitStepArgument<F>,
	/// Cached `expr.degree()`
131	degree: usize,
	/// Number of variables a query must supply; at least `expr.n_vars()`
132	n_vars: usize,
	/// Cached `expr.binary_tower_level()`
133	tower_level: usize,
134}
135
136impl<F: Field> PartialEq for ArithCircuitPoly<F> {
137	fn eq(&self, other: &Self) -> bool {
138		self.n_vars == other.n_vars && self.expr == other.expr
139	}
140}
141
142impl<F: Field> Eq for ArithCircuitPoly<F> {}
143
144impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
145	fn serialize(
146		&self,
147		mut write_buf: impl bytes::BufMut,
148		mode: SerializationMode,
149	) -> Result<(), SerializationError> {
150		(&self.expr, self.n_vars).serialize(&mut write_buf, mode)
151	}
152}
153
154impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
155	fn deserialize(
156		read_buf: impl bytes::Buf,
157		mode: SerializationMode,
158	) -> Result<Self, SerializationError>
159	where
160		Self: Sized,
161	{
162		let (expr, n_vars) = <(ArithExpr<F>, usize)>::deserialize(read_buf, mode)?;
163		Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
164			name: "ArithCircuitPoly",
165		})
166	}
167}
168
169impl<F: TowerField> ArithCircuitPoly<F> {
170	pub fn new(expr: ArithExpr<F>) -> Self {
171		let degree = expr.degree();
172		let n_vars = expr.n_vars();
173		let tower_level = expr.binary_tower_level();
174		let (exprs, retval) = circuit_steps_for_expr(&expr);
175
176		Self {
177			expr,
178			steps: exprs.into(),
179			retval,
180			degree,
181			n_vars,
182			tower_level,
183		}
184	}
185
186	/// Constructs an [`ArithCircuitPoly`] with the given number of variables.
187	///
188	/// The number of variables may be greater than the number of variables actually read in the
189	/// arithmetic expression.
190	pub fn with_n_vars(n_vars: usize, expr: ArithExpr<F>) -> Result<Self, Error> {
191		let degree = expr.degree();
192		let tower_level = expr.binary_tower_level();
193		if n_vars < expr.n_vars() {
194			return Err(Error::IncorrectNumberOfVariables {
195				expected: expr.n_vars(),
196				actual: n_vars,
197			});
198		}
199		let (exprs, retval) = circuit_steps_for_expr(&expr);
200
201		Ok(Self {
202			expr,
203			steps: exprs.into(),
204			retval,
205			n_vars,
206			degree,
207			tower_level,
208		})
209	}
210}
211
212impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
213	for ArithCircuitPoly<F>
214{
215	fn degree(&self) -> usize {
216		self.degree
217	}
218
219	fn n_vars(&self) -> usize {
220		self.n_vars
221	}
222
223	fn binary_tower_level(&self) -> usize {
224		self.tower_level
225	}
226
227	fn expression(&self) -> ArithExpr<P::Scalar> {
228		self.expr.convert_field()
229	}
230
231	fn evaluate(&self, query: &[P]) -> Result<P, Error> {
232		if query.len() != self.n_vars {
233			return Err(Error::IncorrectQuerySize {
234				expected: self.n_vars,
235			});
236		}
237
238		fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
239			// Safety: The index is guaranteed to be within bounds because
240			// we initialize at least `self.steps.len()` using `stackalloc`.
241			unsafe {
242				target.get_unchecked_mut(0).write(value);
243			}
244		}
245
246		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
247		stackalloc_uninit::<P, _, _>(self.steps.len().max(1), |evals| {
248			let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
249				// Safety: The index is guaranteed to be within bounds by the construction of the circuit
250				CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
251					*query.get_unchecked(index)
252				},
253				// Safety: The index is guaranteed to be within bounds by the circuit evaluation order
254				CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
255					*evals.get_unchecked(slot)
256				},
257				CircuitStepArgument::Const(value) => P::broadcast(value.into()),
258			};
259
260			for (i, expr) in self.steps.iter().enumerate() {
261				// Safety: previous evaluations are initialized by the previous loop iterations (if dereferenced)
262				let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
263				let before = unsafe { slice_assume_init_mut(before) };
264				match expr {
265					CircuitStep::Add(x, y) => write_result(
266						after,
267						get_argument_value(*x, before) + get_argument_value(*y, before),
268					),
269					CircuitStep::AddMul(target_slot, x, y) => {
270						let intermediate =
271							get_argument_value(*x, before) * get_argument_value(*y, before);
272						// Safety: we know by evaluation order and construction of steps that `target.slot` is initialized
273						let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
274						*target_slot += intermediate;
275					}
276					CircuitStep::Mul(x, y) => write_result(
277						after,
278						get_argument_value(*x, before) * get_argument_value(*y, before),
279					),
280					CircuitStep::Square(x) => {
281						write_result(after, get_argument_value(*x, before).square())
282					}
283				};
284			}
285
286			// Some slots in `evals` might be empty, but we're guaranted that
287			// if `self.retval` points to a slot, that this slot is initialized.
288			unsafe {
289				let evals = slice_assume_init(evals);
290				Ok(get_argument_value(self.retval, evals))
291			}
292		})
293	}
294
295	fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
296		let row_len = evals.len();
297		if batch_query.row_len() != row_len {
298			bail!(Error::BatchEvaluateSizeMismatch {
299				expected: row_len,
300				actual: batch_query.row_len(),
301			});
302		}
303
304		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
305		stackalloc_uninit::<P, (), _>((self.steps.len() * row_len).max(1), |sparse_evals| {
306			for (i, expr) in self.steps.iter().enumerate() {
307				let (before, current) = sparse_evals.split_at_mut(i * row_len);
308
309				// Safety: `before` is guaranteed to be initialized by the previous loop iterations (if dereferenced).
310				let before = unsafe { slice_assume_init_mut(before) };
311				let current = &mut current[..row_len];
312
313				match expr {
314					CircuitStep::Add(left, right) => {
315						apply_binary_op(
316							left,
317							right,
318							batch_query,
319							before,
320							current,
321							|left, right, out| {
322								out.write(left + right);
323							},
324						);
325					}
326					CircuitStep::Mul(left, right) => {
327						apply_binary_op(
328							left,
329							right,
330							batch_query,
331							before,
332							current,
333							|left, right, out| {
334								out.write(left * right);
335							},
336						);
337					}
338					CircuitStep::Square(arg) => {
339						match arg {
340							CircuitStepArgument::Expr(node) => {
341								let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
342								for j in 0..row_len {
343									// Safety: `current` and `id_chunk` have length equal to `row_len`
344									unsafe {
345										current
346											.get_unchecked_mut(j)
347											.write(id_chunk.get_unchecked(j).square());
348									}
349								}
350							}
351							CircuitStepArgument::Const(value) => {
352								let value: P = P::broadcast((*value).into());
353								let result = value.square();
354								for j in 0..row_len {
355									// Safety: `current` has length equal to `row_len`
356									unsafe {
357										current.get_unchecked_mut(j).write(result);
358									}
359								}
360							}
361						}
362					}
363					CircuitStep::AddMul(target, left, right) => {
364						let target = &before[row_len * target..(target + 1) * row_len];
365						// Safety: by construction of steps and evaluation order we know
366						// that `target` is not borrowed elsewhere.
367						let target: &mut [MaybeUninit<P>] = unsafe {
368							std::slice::from_raw_parts_mut(
369								target.as_ptr() as *mut MaybeUninit<P>,
370								target.len(),
371							)
372						};
373						apply_binary_op(
374							left,
375							right,
376							batch_query,
377							before,
378							target,
379							// Safety: by construction of steps and evaluation order we know
380							// that `target`/`out` is initialized.
381							|left, right, out| unsafe {
382								let out = out.assume_init_mut();
383								*out += left * right;
384							},
385						);
386					}
387				}
388			}
389
390			match self.retval {
391				CircuitStepArgument::Expr(node) => {
392					// Safety: `sparse_evals` is fully initialized by the previous loop iterations
393					let sparse_evals = unsafe { slice_assume_init(sparse_evals) };
394					evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
395				}
396				CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
397			}
398		});
399
400		Ok(())
401	}
402}
403
404/// Apply a binary operation to two arguments and store the result in `current_evals`.
405/// `op` must be a function that takes two arguments and initialized the result with the third argument.
406fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
407	left: &CircuitStepArgument<F>,
408	right: &CircuitStepArgument<F>,
409	batch_query: &RowsBatchRef<P>,
410	evals_before: &[P],
411	current_evals: &mut [MaybeUninit<P>],
412	op: impl Fn(P, P, &mut MaybeUninit<P>),
413) {
414	let row_len = current_evals.len();
415
416	match (left, right) {
417		(CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
418			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
419			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
420			for j in 0..row_len {
421				// Safety: `current`, `left` and `right` have length equal to `row_len`
422				unsafe {
423					op(
424						*left.get_unchecked(j),
425						*right.get_unchecked(j),
426						current_evals.get_unchecked_mut(j),
427					)
428				}
429			}
430		}
431		(CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
432			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
433			let right = P::broadcast((*right).into());
434			for j in 0..row_len {
435				// Safety: `current` and `left` have length equal to `row_len`
436				unsafe {
437					op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
438				}
439			}
440		}
441		(CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
442			let left = P::broadcast((*left).into());
443			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
444			for j in 0..row_len {
445				// Safety: `current` and `right` have length equal to `row_len`
446				unsafe {
447					op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
448				}
449			}
450		}
451		(CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
452			let left = P::broadcast((*left).into());
453			let right = P::broadcast((*right).into());
454			let mut result = MaybeUninit::uninit();
455			op(left, right, &mut result);
456			for j in 0..row_len {
457				// Safety:
458				// - `current` has length equal to `row_len`
459				// - `result` is initialized by `op`
460				unsafe {
461					current_evals
462						.get_unchecked_mut(j)
463						.write(result.assume_init());
464				}
465			}
466		}
467	}
468}
469
470#[cfg(test)]
471mod tests {
472	use binius_field::{
473		BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField,
474	};
475	use binius_math::{CompositionPoly, RowsBatch};
476	use binius_utils::felts;
477
478	use super::*;
479
480	#[test]
481	fn test_constant() {
482		type F = BinaryField8b;
483		type P = PackedBinaryField8x16b;
484
485		let expr = ArithExpr::Const(F::new(123));
486		let circuit = ArithCircuitPoly::<F>::new(expr);
487
488		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
489		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
490		assert_eq!(typed_circuit.degree(), 0);
491		assert_eq!(typed_circuit.n_vars(), 0);
492
493		assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));
494
495		let mut evals = [P::default()];
496		typed_circuit
497			.batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
498			.unwrap();
499		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
500	}
501
502	#[test]
503	fn test_identity() {
504		type F = BinaryField8b;
505		type P = PackedBinaryField8x16b;
506
507		// x0
508		let expr = ArithExpr::Var(0);
509		let circuit = ArithCircuitPoly::<F>::new(expr);
510
511		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
512		assert_eq!(typed_circuit.binary_tower_level(), 0);
513		assert_eq!(typed_circuit.degree(), 1);
514		assert_eq!(typed_circuit.n_vars(), 1);
515
516		assert_eq!(
517			typed_circuit
518				.evaluate(&[P::broadcast(F::new(123).into())])
519				.unwrap(),
520			P::broadcast(F::new(123).into())
521		);
522
523		let mut evals = [P::default()];
524		let batch_query = [[P::broadcast(F::new(123).into())]; 1];
525		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
526		typed_circuit
527			.batch_evaluate(&batch_query.get_ref(), &mut evals)
528			.unwrap();
529		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
530	}
531
532	#[test]
533	fn test_add() {
534		type F = BinaryField8b;
535		type P = PackedBinaryField8x16b;
536
537		// 123 + x0
538		let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
539		let circuit = ArithCircuitPoly::<F>::new(expr);
540
541		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
542		assert_eq!(typed_circuit.binary_tower_level(), 3);
543		assert_eq!(typed_circuit.degree(), 1);
544		assert_eq!(typed_circuit.n_vars(), 1);
545
546		assert_eq!(
547			CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
548			P::broadcast(F::new(123).into())
549		);
550	}
551
552	#[test]
553	fn test_mul() {
554		type F = BinaryField8b;
555		type P = PackedBinaryField8x16b;
556
557		// 123 * x0
558		let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
559		let circuit = ArithCircuitPoly::<F>::new(expr);
560
561		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
562		assert_eq!(typed_circuit.binary_tower_level(), 3);
563		assert_eq!(typed_circuit.degree(), 1);
564		assert_eq!(typed_circuit.n_vars(), 1);
565
566		assert_eq!(
567			CompositionPoly::evaluate(
568				&circuit,
569				&[P::from_scalars(
570					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
571				)]
572			)
573			.unwrap(),
574			P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
575		);
576	}
577
578	#[test]
579	fn test_pow() {
580		type F = BinaryField8b;
581		type P = PackedBinaryField8x16b;
582
583		// x0^13
584		let expr = ArithExpr::Var(0).pow(13);
585		let circuit = ArithCircuitPoly::<F>::new(expr);
586
587		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
588		assert_eq!(typed_circuit.binary_tower_level(), 0);
589		assert_eq!(typed_circuit.degree(), 13);
590		assert_eq!(typed_circuit.n_vars(), 1);
591
592		assert_eq!(
593			CompositionPoly::evaluate(
594				&circuit,
595				&[P::from_scalars(
596					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
597				)]
598			)
599			.unwrap(),
600			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
601		);
602	}
603
604	#[test]
605	fn test_mixed() {
606		type F = BinaryField8b;
607		type P = PackedBinaryField8x16b;
608
609		// x0^2 * (x1 + 123)
610		let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
611		let circuit = ArithCircuitPoly::<F>::new(expr);
612
613		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
614		assert_eq!(typed_circuit.binary_tower_level(), 3);
615		assert_eq!(typed_circuit.degree(), 3);
616		assert_eq!(typed_circuit.n_vars(), 2);
617
618		// test evaluate
619		assert_eq!(
620			CompositionPoly::evaluate(
621				&circuit,
622				&[
623					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
624					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
625				]
626			)
627			.unwrap(),
628			P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
629		);
630
631		// test batch evaluate
632		let query1 = &[
633			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
634			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
635		];
636		let query2 = &[
637			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
638			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
639		];
640		let query3 = &[
641			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
642			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
643		];
644		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
645		let expected2 =
646			P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
647		let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));
648
649		let mut batch_result = vec![P::zero(); 3];
650		let batch_query = &[
651			&[query1[0], query2[0], query3[0]],
652			&[query1[1], query2[1], query3[1]],
653		];
654		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
655
656		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
657			.unwrap();
658		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
659	}
660
661	#[test]
662	fn test_const_fold() {
663		type F = BinaryField8b;
664		type P = PackedBinaryField8x16b;
665
666		// x0 * ((122 * 123) + (124 + 125)) + x1
667		let expr = ArithExpr::Var(0)
668			* ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
669				+ (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
670			+ ArithExpr::Var(1);
671		let circuit = ArithCircuitPoly::<F>::new(expr);
672		assert_eq!(circuit.steps.len(), 2);
673
674		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
675		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
676		assert_eq!(typed_circuit.degree(), 1);
677		assert_eq!(typed_circuit.n_vars(), 2);
678
679		// test evaluate
680		assert_eq!(
681			CompositionPoly::evaluate(
682				&circuit,
683				&[
684					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
685					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
686				]
687			)
688			.unwrap(),
689			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
690		);
691
692		// test batch evaluate
693		let query1 = &[
694			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
695			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
696		];
697		let query2 = &[
698			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
699			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
700		];
701		let query3 = &[
702			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
703			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
704		];
705		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
706		let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
707		let expected3 =
708			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));
709
710		let mut batch_result = vec![P::zero(); 3];
711		let batch_query = &[
712			&[query1[0], query2[0], query3[0]],
713			&[query1[1], query2[1], query3[1]],
714		];
715		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
716		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
717			.unwrap();
718		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
719	}
720
721	#[test]
722	fn test_pow_const_fold() {
723		type F = BinaryField8b;
724		type P = PackedBinaryField8x16b;
725
726		// x0 + 2^5
727		let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
728		let circuit = ArithCircuitPoly::<F>::new(expr);
729		assert_eq!(circuit.steps.len(), 1);
730
731		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
732		assert_eq!(typed_circuit.binary_tower_level(), 1);
733		assert_eq!(typed_circuit.degree(), 1);
734		assert_eq!(typed_circuit.n_vars(), 1);
735
736		assert_eq!(
737			CompositionPoly::evaluate(
738				&circuit,
739				&[P::from_scalars(
740					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
741				)]
742			)
743			.unwrap(),
744			P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
745		);
746	}
747
748	#[test]
749	fn test_pow_nested() {
750		type F = BinaryField8b;
751		type P = PackedBinaryField8x16b;
752
753		// ((x0^2)^3)^4
754		let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
755		let circuit = ArithCircuitPoly::<F>::new(expr);
756		assert_eq!(circuit.steps.len(), 5);
757
758		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
759		assert_eq!(typed_circuit.binary_tower_level(), 0);
760		assert_eq!(typed_circuit.degree(), 24);
761		assert_eq!(typed_circuit.n_vars(), 1);
762
763		assert_eq!(
764			CompositionPoly::evaluate(
765				&circuit,
766				&[P::from_scalars(
767					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
768				)]
769			)
770			.unwrap(),
771			P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
772		);
773	}
774
775	#[test]
776	fn test_circuit_steps_for_expr_constant() {
777		type F = BinaryField8b;
778
779		let expr = ArithExpr::Const(F::new(5));
780		let (steps, retval) = circuit_steps_for_expr(&expr);
781
782		assert!(steps.is_empty(), "No steps should be generated for a constant");
783		assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
784	}
785
786	#[test]
787	fn test_circuit_steps_for_expr_variable() {
788		type F = BinaryField8b;
789
790		let expr = ArithExpr::<F>::Var(18);
791		let (steps, retval) = circuit_steps_for_expr(&expr);
792
793		assert!(steps.is_empty(), "No steps should be generated for a variable");
794		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
795	}
796
797	#[test]
798	fn test_circuit_steps_for_expr_addition() {
799		type F = BinaryField8b;
800
801		let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
802		let (steps, retval) = circuit_steps_for_expr(&expr);
803
804		assert_eq!(steps.len(), 1, "One addition step should be generated");
805		assert!(matches!(
806			steps[0],
807			CircuitStep::Add(
808				CircuitStepArgument::Expr(CircuitNode::Var(14)),
809				CircuitStepArgument::Expr(CircuitNode::Var(56))
810			)
811		));
812		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
813	}
814
815	#[test]
816	fn test_circuit_steps_for_expr_multiplication() {
817		type F = BinaryField8b;
818
819		let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
820		let (steps, retval) = circuit_steps_for_expr(&expr);
821
822		assert_eq!(steps.len(), 1, "One multiplication step should be generated");
823		assert!(matches!(
824			steps[0],
825			CircuitStep::Mul(
826				CircuitStepArgument::Expr(CircuitNode::Var(36)),
827				CircuitStepArgument::Expr(CircuitNode::Var(26))
828			)
829		));
830		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
831	}
832
833	#[test]
834	fn test_circuit_steps_for_expr_pow_1() {
835		type F = BinaryField8b;
836
837		let expr = ArithExpr::<F>::Var(12).pow(1);
838		let (steps, retval) = circuit_steps_for_expr(&expr);
839
840		// No steps should be generated for x^1
841		assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");
842
843		// The return value should just be the variable itself
844		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
845	}
846
847	#[test]
848	fn test_circuit_steps_for_expr_pow_2() {
849		type F = BinaryField8b;
850
851		let expr = ArithExpr::<F>::Var(10).pow(2);
852		let (steps, retval) = circuit_steps_for_expr(&expr);
853
854		assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
855		assert!(matches!(
856			steps[0],
857			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
858		));
859		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
860	}
861
862	#[test]
863	fn test_circuit_steps_for_expr_pow_3() {
864		type F = BinaryField8b;
865
866		let expr = ArithExpr::<F>::Var(5).pow(3);
867		let (steps, retval) = circuit_steps_for_expr(&expr);
868
869		assert_eq!(
870			steps.len(),
871			2,
872			"Pow(3) should generate one squaring and one multiplication step"
873		);
874		assert!(matches!(
875			steps[0],
876			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
877		));
878		assert!(matches!(
879			steps[1],
880			CircuitStep::Mul(
881				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
882				CircuitStepArgument::Expr(CircuitNode::Var(5))
883			)
884		));
885		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
886	}
887
888	#[test]
889	fn test_circuit_steps_for_expr_pow_4() {
890		type F = BinaryField8b;
891
892		let expr = ArithExpr::<F>::Var(7).pow(4);
893		let (steps, retval) = circuit_steps_for_expr(&expr);
894
895		assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
896		assert!(matches!(
897			steps[0],
898			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
899		));
900
901		assert!(matches!(
902			steps[1],
903			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
904		));
905
906		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
907	}
908
909	#[test]
910	fn test_circuit_steps_for_expr_pow_5() {
911		type F = BinaryField8b;
912
913		let expr = ArithExpr::<F>::Var(3).pow(5);
914		let (steps, retval) = circuit_steps_for_expr(&expr);
915
916		assert_eq!(
917			steps.len(),
918			3,
919			"Pow(5) should generate two squaring steps and one multiplication"
920		);
921		assert!(matches!(
922			steps[0],
923			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
924		));
925		assert!(matches!(
926			steps[1],
927			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
928		));
929		assert!(matches!(
930			steps[2],
931			CircuitStep::Mul(
932				CircuitStepArgument::Expr(CircuitNode::Slot(1)),
933				CircuitStepArgument::Expr(CircuitNode::Var(3))
934			)
935		));
936
937		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
938	}
939
940	#[test]
941	fn test_circuit_steps_for_expr_pow_8() {
942		type F = BinaryField8b;
943
944		let expr = ArithExpr::<F>::Var(4).pow(8);
945		let (steps, retval) = circuit_steps_for_expr(&expr);
946
947		assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
948		assert!(matches!(
949			steps[0],
950			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
951		));
952		assert!(matches!(
953			steps[1],
954			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
955		));
956		assert!(matches!(
957			steps[2],
958			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
959		));
960
961		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
962	}
963
964	#[test]
965	fn test_circuit_steps_for_expr_pow_9() {
966		type F = BinaryField8b;
967
968		let expr = ArithExpr::<F>::Var(8).pow(9);
969		let (steps, retval) = circuit_steps_for_expr(&expr);
970
971		assert_eq!(
972			steps.len(),
973			4,
974			"Pow(9) should generate three squaring steps and one multiplication"
975		);
976		assert!(matches!(
977			steps[0],
978			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
979		));
980		assert!(matches!(
981			steps[1],
982			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
983		));
984		assert!(matches!(
985			steps[2],
986			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
987		));
988		assert!(matches!(
989			steps[3],
990			CircuitStep::Mul(
991				CircuitStepArgument::Expr(CircuitNode::Slot(2)),
992				CircuitStepArgument::Expr(CircuitNode::Var(8))
993			)
994		));
995
996		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
997	}
998
999	#[test]
1000	fn test_circuit_steps_for_expr_pow_12() {
1001		type F = BinaryField8b;
1002		let expr = ArithExpr::<F>::Var(6).pow(12);
1003		let (steps, retval) = circuit_steps_for_expr(&expr);
1004
1005		assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");
1006
1007		assert!(matches!(
1008			steps[0],
1009			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
1010		));
1011		assert!(matches!(
1012			steps[1],
1013			CircuitStep::Mul(
1014				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1015				CircuitStepArgument::Expr(CircuitNode::Var(6))
1016			)
1017		));
1018		assert!(matches!(
1019			steps[2],
1020			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1021		));
1022		assert!(matches!(
1023			steps[3],
1024			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1025		));
1026
1027		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1028	}
1029
1030	#[test]
1031	fn test_circuit_steps_for_expr_pow_13() {
1032		type F = BinaryField8b;
1033		let expr = ArithExpr::<F>::Var(7).pow(13);
1034		let (steps, retval) = circuit_steps_for_expr(&expr);
1035
1036		assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
1037		assert!(matches!(
1038			steps[0],
1039			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1040		));
1041		assert!(matches!(
1042			steps[1],
1043			CircuitStep::Mul(
1044				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1045				CircuitStepArgument::Expr(CircuitNode::Var(7))
1046			)
1047		));
1048		assert!(matches!(
1049			steps[2],
1050			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1051		));
1052		assert!(matches!(
1053			steps[3],
1054			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1055		));
1056		assert!(matches!(
1057			steps[4],
1058			CircuitStep::Mul(
1059				CircuitStepArgument::Expr(CircuitNode::Slot(3)),
1060				CircuitStepArgument::Expr(CircuitNode::Var(7))
1061			)
1062		));
1063		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
1064	}
1065
1066	#[test]
1067	fn test_circuit_steps_for_expr_complex() {
1068		type F = BinaryField8b;
1069
1070		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1071			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
1072			- ArithExpr::Var(3);
1073
1074		let (steps, retval) = circuit_steps_for_expr(&expr);
1075
1076		assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");
1077
1078		assert!(
1079			matches!(
1080				steps[0],
1081				CircuitStep::Mul(
1082					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1083					CircuitStepArgument::Expr(CircuitNode::Var(1))
1084				)
1085			),
1086			"First step should be multiplication x0 * x1"
1087		);
1088
1089		assert!(
1090			matches!(
1091				steps[1],
1092				CircuitStep::Add(
1093					CircuitStepArgument::Const(F::ONE),
1094					CircuitStepArgument::Expr(CircuitNode::Var(0))
1095				)
1096			),
1097			"Second step should be (1 - x0)"
1098		);
1099
1100		assert!(
1101			matches!(
1102				steps[2],
1103				CircuitStep::AddMul(
1104					0,
1105					CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1106					CircuitStepArgument::Expr(CircuitNode::Var(2))
1107				)
1108			),
1109			"Third step should be (1 - x0) * x2"
1110		);
1111
1112		assert!(
1113			matches!(
1114				steps[3],
1115				CircuitStep::Add(
1116					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1117					CircuitStepArgument::Expr(CircuitNode::Var(3))
1118				)
1119			),
1120			"Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
1121		);
1122
1123		assert!(
1124			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
1125			"Final result should be stored in Slot(3)"
1126		);
1127	}
1128}