binius_core/polynomial/arith_circuit.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{fmt::Debug, mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithExpr, CompositionPoly, Error};
7use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
8use stackalloc::{
9	helpers::{slice_assume_init, slice_assume_init_mut},
10	stackalloc_uninit,
11};
12
13/// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence.
14fn circuit_steps_for_expr<F: Field>(
15	expr: &ArithExpr<F>,
16) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
17	let mut steps = Vec::new();
18
19	fn to_circuit_inner<F: Field>(
20		expr: &ArithExpr<F>,
21		result: &mut Vec<CircuitStep<F>>,
22	) -> CircuitStepArgument<F> {
23		match expr {
24			ArithExpr::Const(value) => CircuitStepArgument::Const(*value),
25			ArithExpr::Var(index) => CircuitStepArgument::Expr(CircuitNode::Var(*index)),
26			ArithExpr::Add(left, right) => match &**right {
27				ArithExpr::Mul(mleft, mright) if left.is_composite() => {
28					// Only handling e1 + (e2 * e3), not (e1 * e2) + e3, as latter was not observed in practice
29					// (the former can be enforced by rewriting expression).
30					let left = to_circuit_inner(left, result);
31					let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left else {
32						unreachable!("guaranteed by `is_composite` check above")
33					};
34					let mleft = to_circuit_inner(mleft, result);
35					let mright = to_circuit_inner(mright, result);
36					result.push(CircuitStep::AddMul(left, mleft, mright));
37					CircuitStepArgument::Expr(CircuitNode::Slot(left))
38				}
39				_ => {
40					let left = to_circuit_inner(left, result);
41					let right = to_circuit_inner(right, result);
42					result.push(CircuitStep::Add(left, right));
43					CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
44				}
45			},
46			ArithExpr::Mul(left, right) => {
47				let left = to_circuit_inner(left, result);
48				let right = to_circuit_inner(right, result);
49				result.push(CircuitStep::Mul(left, right));
50				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
51			}
52			ArithExpr::Pow(base, exp) => {
53				let mut acc = to_circuit_inner(base, result);
54				let base_expr = acc;
55				let highest_bit = exp.ilog2();
56
57				for i in (0..highest_bit).rev() {
58					result.push(CircuitStep::Square(acc));
59					acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
60
61					if (exp >> i) & 1 != 0 {
62						result.push(CircuitStep::Mul(acc, base_expr));
63						acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
64					}
65				}
66
67				acc
68			}
69		}
70	}
71
72	let ret = to_circuit_inner(&expr.optimize(), &mut steps);
73	(steps, ret)
74}
75
76/// Input of the circuit calculation step
77#[derive(Debug, Clone, Copy, PartialEq, Eq)]
78enum CircuitNode {
79	/// Input variable
80	Var(usize),
81	/// Evaluation at one of the previous steps
82	Slot(usize),
83}
84
85impl CircuitNode {
86	/// Return either the input column or the slice with evaluations at one of the previous steps.
87	/// This method is used for batch evaluation.
88	fn get_sparse_chunk<'a, P: PackedField>(
89		&self,
90		inputs: &[&'a [P]],
91		evals: &'a [P],
92		row_len: usize,
93	) -> &'a [P] {
94		match self {
95			Self::Var(index) => inputs[*index],
96			Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
97		}
98	}
99}
100
101#[derive(Debug, Clone, Copy, PartialEq, Eq)]
102enum CircuitStepArgument<F> {
103	Expr(CircuitNode),
104	Const(F),
105}
106
107/// Describes computation symbolically. This is used internally by ArithCircuitPoly.
108///
109/// ExprIds used by an Expr has to be less than the index of the Expr itself within the ArithCircuitPoly,
110/// to ensure it represents a directed acyclic graph that can be computed in sequence.
111#[derive(Debug)]
112enum CircuitStep<F: Field> {
113	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
114	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
115	Square(CircuitStepArgument<F>),
116	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
117}
118
119/// Describes polynomial evaluations using a directed acyclic graph of expressions.
120///
121/// This is meant as an alternative to a hard-coded CompositionPoly.
122///
123/// The advantage over a hard coded CompositionPoly is that this can be constructed and manipulated dynamically at runtime
124/// and the object representing different polnomials can be stored in a homogeneous collection.
125#[derive(Debug, Clone)]
126pub struct ArithCircuitPoly<F: Field> {
127	expr: ArithExpr<F>,
128	steps: Arc<[CircuitStep<F>]>,
129	/// The "top level expression", which depends on circuit expression evaluations
130	retval: CircuitStepArgument<F>,
131	degree: usize,
132	n_vars: usize,
133	tower_level: usize,
134}
135
136impl<F: Field> PartialEq for ArithCircuitPoly<F> {
137	fn eq(&self, other: &Self) -> bool {
138		self.n_vars == other.n_vars && self.expr == other.expr
139	}
140}
141
142impl<F: Field> Eq for ArithCircuitPoly<F> {}
143
144impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
145	fn serialize(
146		&self,
147		mut write_buf: impl bytes::BufMut,
148		mode: SerializationMode,
149	) -> Result<(), SerializationError> {
150		(&self.expr, self.n_vars).serialize(&mut write_buf, mode)
151	}
152}
153
154impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
155	fn deserialize(
156		read_buf: impl bytes::Buf,
157		mode: SerializationMode,
158	) -> Result<Self, SerializationError>
159	where
160		Self: Sized,
161	{
162		let (expr, n_vars) = <(ArithExpr<F>, usize)>::deserialize(read_buf, mode)?;
163		Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
164			name: "ArithCircuitPoly",
165		})
166	}
167}
168
169impl<F: TowerField> ArithCircuitPoly<F> {
170	pub fn new(expr: ArithExpr<F>) -> Self {
171		let degree = expr.degree();
172		let n_vars = expr.n_vars();
173		let tower_level = expr.binary_tower_level();
174		let (exprs, retval) = circuit_steps_for_expr(&expr);
175
176		Self {
177			expr,
178			steps: exprs.into(),
179			retval,
180			degree,
181			n_vars,
182			tower_level,
183		}
184	}
185
186	/// Constructs an [`ArithCircuitPoly`] with the given number of variables.
187	///
188	/// The number of variables may be greater than the number of variables actually read in the
189	/// arithmetic expression.
190	pub fn with_n_vars(n_vars: usize, expr: ArithExpr<F>) -> Result<Self, Error> {
191		let degree = expr.degree();
192		let tower_level = expr.binary_tower_level();
193		if n_vars < expr.n_vars() {
194			return Err(Error::IncorrectNumberOfVariables {
195				expected: expr.n_vars(),
196				actual: n_vars,
197			});
198		}
199		let (exprs, retval) = circuit_steps_for_expr(&expr);
200
201		Ok(Self {
202			expr,
203			steps: exprs.into(),
204			retval,
205			n_vars,
206			degree,
207			tower_level,
208		})
209	}
210}
211
212impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
213	for ArithCircuitPoly<F>
214{
215	fn degree(&self) -> usize {
216		self.degree
217	}
218
219	fn n_vars(&self) -> usize {
220		self.n_vars
221	}
222
223	fn binary_tower_level(&self) -> usize {
224		self.tower_level
225	}
226
227	fn expression(&self) -> ArithExpr<P::Scalar> {
228		self.expr.convert_field()
229	}
230
231	fn evaluate(&self, query: &[P]) -> Result<P, Error> {
232		if query.len() != self.n_vars {
233			return Err(Error::IncorrectQuerySize {
234				expected: self.n_vars,
235			});
236		}
237
238		fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
239			// Safety: The index is guaranteed to be within bounds because
240			// we initialize at least `self.steps.len()` using `stackalloc`.
241			unsafe {
242				target.get_unchecked_mut(0).write(value);
243			}
244		}
245
246		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
247		stackalloc_uninit::<P, _, _>(self.steps.len().max(1), |evals| {
248			let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
249				// Safety: The index is guaranteed to be within bounds by the construction of the circuit
250				CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
251					*query.get_unchecked(index)
252				},
253				// Safety: The index is guaranteed to be within bounds by the circuit evaluation order
254				CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
255					*evals.get_unchecked(slot)
256				},
257				CircuitStepArgument::Const(value) => P::broadcast(value.into()),
258			};
259
260			for (i, expr) in self.steps.iter().enumerate() {
261				// Safety: previous evaluations are initialized by the previous loop iterations (if dereferenced)
262				let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
263				let before = unsafe { slice_assume_init_mut(before) };
264				match expr {
265					CircuitStep::Add(x, y) => write_result(
266						after,
267						get_argument_value(*x, before) + get_argument_value(*y, before),
268					),
269					CircuitStep::AddMul(target_slot, x, y) => {
270						let intermediate =
271							get_argument_value(*x, before) * get_argument_value(*y, before);
272						// Safety: we know by evaluation order and construction of steps that `target.slot` is initialized
273						let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
274						*target_slot += intermediate;
275					}
276					CircuitStep::Mul(x, y) => write_result(
277						after,
278						get_argument_value(*x, before) * get_argument_value(*y, before),
279					),
280					CircuitStep::Square(x) => {
281						write_result(after, get_argument_value(*x, before).square())
282					}
283				};
284			}
285
286			// Some slots in `evals` might be empty, but we're guaranted that
287			// if `self.retval` points to a slot, that this slot is initialized.
288			unsafe {
289				let evals = slice_assume_init(evals);
290				Ok(get_argument_value(self.retval, evals))
291			}
292		})
293	}
294
295	fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> {
296		let row_len = evals.len();
297		if batch_query.iter().any(|row| row.len() != row_len) {
298			return Err(Error::BatchEvaluateSizeMismatch);
299		}
300
301		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
302		stackalloc_uninit::<P, (), _>((self.steps.len() * row_len).max(1), |sparse_evals| {
303			for (i, expr) in self.steps.iter().enumerate() {
304				let (before, current) = sparse_evals.split_at_mut(i * row_len);
305
306				// Safety: `before` is guaranteed to be initialized by the previous loop iterations (if dereferenced).
307				let before = unsafe { slice_assume_init_mut(before) };
308				let current = &mut current[..row_len];
309
310				match expr {
311					CircuitStep::Add(left, right) => {
312						apply_binary_op(
313							left,
314							right,
315							batch_query,
316							before,
317							current,
318							|left, right, out| {
319								out.write(left + right);
320							},
321						);
322					}
323					CircuitStep::Mul(left, right) => {
324						apply_binary_op(
325							left,
326							right,
327							batch_query,
328							before,
329							current,
330							|left, right, out| {
331								out.write(left * right);
332							},
333						);
334					}
335					CircuitStep::Square(arg) => {
336						match arg {
337							CircuitStepArgument::Expr(node) => {
338								let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
339								for j in 0..row_len {
340									// Safety: `current` and `id_chunk` have length equal to `row_len`
341									unsafe {
342										current
343											.get_unchecked_mut(j)
344											.write(id_chunk.get_unchecked(j).square());
345									}
346								}
347							}
348							CircuitStepArgument::Const(value) => {
349								let value: P = P::broadcast((*value).into());
350								let result = value.square();
351								for j in 0..row_len {
352									// Safety: `current` has length equal to `row_len`
353									unsafe {
354										current.get_unchecked_mut(j).write(result);
355									}
356								}
357							}
358						}
359					}
360					CircuitStep::AddMul(target, left, right) => {
361						let target = &before[row_len * target..(target + 1) * row_len];
362						// Safety: by construction of steps and evaluation order we know
363						// that `target` is not borrowed elsewhere.
364						let target: &mut [MaybeUninit<P>] = unsafe {
365							std::slice::from_raw_parts_mut(
366								target.as_ptr() as *mut MaybeUninit<P>,
367								target.len(),
368							)
369						};
370						apply_binary_op(
371							left,
372							right,
373							batch_query,
374							before,
375							target,
376							// Safety: by construction of steps and evaluation order we know
377							// that `target`/`out` is initialized.
378							|left, right, out| unsafe {
379								let out = out.assume_init_mut();
380								*out += left * right;
381							},
382						);
383					}
384				}
385			}
386
387			match self.retval {
388				CircuitStepArgument::Expr(node) => {
389					// Safety: `sparse_evals` is fully initialized by the previous loop iterations
390					let sparse_evals = unsafe { slice_assume_init(sparse_evals) };
391					evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
392				}
393				CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
394			}
395		});
396
397		Ok(())
398	}
399}
400
401/// Apply a binary operation to two arguments and store the result in `current_evals`.
402/// `op` must be a function that takes two arguments and initialized the result with the third argument.
403fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
404	left: &CircuitStepArgument<F>,
405	right: &CircuitStepArgument<F>,
406	batch_query: &[&[P]],
407	evals_before: &[P],
408	current_evals: &mut [MaybeUninit<P>],
409	op: impl Fn(P, P, &mut MaybeUninit<P>),
410) {
411	let row_len = current_evals.len();
412
413	match (left, right) {
414		(CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
415			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
416			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
417			for j in 0..row_len {
418				// Safety: `current`, `left` and `right` have length equal to `row_len`
419				unsafe {
420					op(
421						*left.get_unchecked(j),
422						*right.get_unchecked(j),
423						current_evals.get_unchecked_mut(j),
424					)
425				}
426			}
427		}
428		(CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
429			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
430			let right = P::broadcast((*right).into());
431			for j in 0..row_len {
432				// Safety: `current` and `left` have length equal to `row_len`
433				unsafe {
434					op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
435				}
436			}
437		}
438		(CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
439			let left = P::broadcast((*left).into());
440			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
441			for j in 0..row_len {
442				// Safety: `current` and `right` have length equal to `row_len`
443				unsafe {
444					op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
445				}
446			}
447		}
448		(CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
449			let left = P::broadcast((*left).into());
450			let right = P::broadcast((*right).into());
451			let mut result = MaybeUninit::uninit();
452			op(left, right, &mut result);
453			for j in 0..row_len {
454				// Safety:
455				// - `current` has length equal to `row_len`
456				// - `result` is initialized by `op`
457				unsafe {
458					current_evals
459						.get_unchecked_mut(j)
460						.write(result.assume_init());
461				}
462			}
463		}
464	}
465}
466
467#[cfg(test)]
468mod tests {
469	use binius_field::{
470		BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField,
471	};
472	use binius_math::CompositionPoly;
473	use binius_utils::felts;
474
475	use super::*;
476
477	#[test]
478	fn test_constant() {
479		type F = BinaryField8b;
480		type P = PackedBinaryField8x16b;
481
482		let expr = ArithExpr::Const(F::new(123));
483		let circuit = ArithCircuitPoly::<F>::new(expr);
484
485		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
486		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
487		assert_eq!(typed_circuit.degree(), 0);
488		assert_eq!(typed_circuit.n_vars(), 0);
489
490		assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));
491
492		let mut evals = [P::default()];
493		typed_circuit.batch_evaluate(&[], &mut evals).unwrap();
494		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
495	}
496
497	#[test]
498	fn test_identity() {
499		type F = BinaryField8b;
500		type P = PackedBinaryField8x16b;
501
502		// x0
503		let expr = ArithExpr::Var(0);
504		let circuit = ArithCircuitPoly::<F>::new(expr);
505
506		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
507		assert_eq!(typed_circuit.binary_tower_level(), 0);
508		assert_eq!(typed_circuit.degree(), 1);
509		assert_eq!(typed_circuit.n_vars(), 1);
510
511		assert_eq!(
512			typed_circuit
513				.evaluate(&[P::broadcast(F::new(123).into())])
514				.unwrap(),
515			P::broadcast(F::new(123).into())
516		);
517
518		let mut evals = [P::default()];
519		typed_circuit
520			.batch_evaluate(&[&[P::broadcast(F::new(123).into())]], &mut evals)
521			.unwrap();
522		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
523	}
524
525	#[test]
526	fn test_add() {
527		type F = BinaryField8b;
528		type P = PackedBinaryField8x16b;
529
530		// 123 + x0
531		let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
532		let circuit = ArithCircuitPoly::<F>::new(expr);
533
534		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
535		assert_eq!(typed_circuit.binary_tower_level(), 3);
536		assert_eq!(typed_circuit.degree(), 1);
537		assert_eq!(typed_circuit.n_vars(), 1);
538
539		assert_eq!(
540			CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
541			P::broadcast(F::new(123).into())
542		);
543	}
544
545	#[test]
546	fn test_mul() {
547		type F = BinaryField8b;
548		type P = PackedBinaryField8x16b;
549
550		// 123 * x0
551		let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
552		let circuit = ArithCircuitPoly::<F>::new(expr);
553
554		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
555		assert_eq!(typed_circuit.binary_tower_level(), 3);
556		assert_eq!(typed_circuit.degree(), 1);
557		assert_eq!(typed_circuit.n_vars(), 1);
558
559		assert_eq!(
560			CompositionPoly::evaluate(
561				&circuit,
562				&[P::from_scalars(
563					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
564				)]
565			)
566			.unwrap(),
567			P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
568		);
569	}
570
571	#[test]
572	fn test_pow() {
573		type F = BinaryField8b;
574		type P = PackedBinaryField8x16b;
575
576		// x0^13
577		let expr = ArithExpr::Var(0).pow(13);
578		let circuit = ArithCircuitPoly::<F>::new(expr);
579
580		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
581		assert_eq!(typed_circuit.binary_tower_level(), 0);
582		assert_eq!(typed_circuit.degree(), 13);
583		assert_eq!(typed_circuit.n_vars(), 1);
584
585		assert_eq!(
586			CompositionPoly::evaluate(
587				&circuit,
588				&[P::from_scalars(
589					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
590				)]
591			)
592			.unwrap(),
593			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
594		);
595	}
596
597	#[test]
598	fn test_mixed() {
599		type F = BinaryField8b;
600		type P = PackedBinaryField8x16b;
601
602		// x0^2 * (x1 + 123)
603		let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
604		let circuit = ArithCircuitPoly::<F>::new(expr);
605
606		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
607		assert_eq!(typed_circuit.binary_tower_level(), 3);
608		assert_eq!(typed_circuit.degree(), 3);
609		assert_eq!(typed_circuit.n_vars(), 2);
610
611		// test evaluate
612		assert_eq!(
613			CompositionPoly::evaluate(
614				&circuit,
615				&[
616					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
617					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
618				]
619			)
620			.unwrap(),
621			P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
622		);
623
624		// test batch evaluate
625		let query1 = &[
626			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
627			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
628		];
629		let query2 = &[
630			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
631			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
632		];
633		let query3 = &[
634			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
635			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
636		];
637		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
638		let expected2 =
639			P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
640		let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));
641
642		let mut batch_result = vec![P::zero(); 3];
643		CompositionPoly::batch_evaluate(
644			&circuit,
645			&[
646				&[query1[0], query2[0], query3[0]],
647				&[query1[1], query2[1], query3[1]],
648			],
649			&mut batch_result,
650		)
651		.unwrap();
652		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
653	}
654
655	#[test]
656	fn test_const_fold() {
657		type F = BinaryField8b;
658		type P = PackedBinaryField8x16b;
659
660		// x0 * ((122 * 123) + (124 + 125)) + x1
661		let expr = ArithExpr::Var(0)
662			* ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
663				+ (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
664			+ ArithExpr::Var(1);
665		let circuit = ArithCircuitPoly::<F>::new(expr);
666		assert_eq!(circuit.steps.len(), 2);
667
668		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
669		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
670		assert_eq!(typed_circuit.degree(), 1);
671		assert_eq!(typed_circuit.n_vars(), 2);
672
673		// test evaluate
674		assert_eq!(
675			CompositionPoly::evaluate(
676				&circuit,
677				&[
678					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
679					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
680				]
681			)
682			.unwrap(),
683			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
684		);
685
686		// test batch evaluate
687		let query1 = &[
688			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
689			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
690		];
691		let query2 = &[
692			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
693			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
694		];
695		let query3 = &[
696			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
697			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
698		];
699		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
700		let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
701		let expected3 =
702			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));
703
704		let mut batch_result = vec![P::zero(); 3];
705		CompositionPoly::batch_evaluate(
706			&circuit,
707			&[
708				&[query1[0], query2[0], query3[0]],
709				&[query1[1], query2[1], query3[1]],
710			],
711			&mut batch_result,
712		)
713		.unwrap();
714		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
715	}
716
717	#[test]
718	fn test_pow_const_fold() {
719		type F = BinaryField8b;
720		type P = PackedBinaryField8x16b;
721
722		// x0 + 2^5
723		let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
724		let circuit = ArithCircuitPoly::<F>::new(expr);
725		assert_eq!(circuit.steps.len(), 1);
726
727		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
728		assert_eq!(typed_circuit.binary_tower_level(), 1);
729		assert_eq!(typed_circuit.degree(), 1);
730		assert_eq!(typed_circuit.n_vars(), 1);
731
732		assert_eq!(
733			CompositionPoly::evaluate(
734				&circuit,
735				&[P::from_scalars(
736					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
737				)]
738			)
739			.unwrap(),
740			P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
741		);
742	}
743
744	#[test]
745	fn test_pow_nested() {
746		type F = BinaryField8b;
747		type P = PackedBinaryField8x16b;
748
749		// ((x0^2)^3)^4
750		let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
751		let circuit = ArithCircuitPoly::<F>::new(expr);
752		assert_eq!(circuit.steps.len(), 5);
753
754		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
755		assert_eq!(typed_circuit.binary_tower_level(), 0);
756		assert_eq!(typed_circuit.degree(), 24);
757		assert_eq!(typed_circuit.n_vars(), 1);
758
759		assert_eq!(
760			CompositionPoly::evaluate(
761				&circuit,
762				&[P::from_scalars(
763					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
764				)]
765			)
766			.unwrap(),
767			P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
768		);
769	}
770
771	#[test]
772	fn test_circuit_steps_for_expr_constant() {
773		type F = BinaryField8b;
774
775		let expr = ArithExpr::Const(F::new(5));
776		let (steps, retval) = circuit_steps_for_expr(&expr);
777
778		assert!(steps.is_empty(), "No steps should be generated for a constant");
779		assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
780	}
781
782	#[test]
783	fn test_circuit_steps_for_expr_variable() {
784		type F = BinaryField8b;
785
786		let expr = ArithExpr::<F>::Var(18);
787		let (steps, retval) = circuit_steps_for_expr(&expr);
788
789		assert!(steps.is_empty(), "No steps should be generated for a variable");
790		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
791	}
792
793	#[test]
794	fn test_circuit_steps_for_expr_addition() {
795		type F = BinaryField8b;
796
797		let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
798		let (steps, retval) = circuit_steps_for_expr(&expr);
799
800		assert_eq!(steps.len(), 1, "One addition step should be generated");
801		assert!(matches!(
802			steps[0],
803			CircuitStep::Add(
804				CircuitStepArgument::Expr(CircuitNode::Var(14)),
805				CircuitStepArgument::Expr(CircuitNode::Var(56))
806			)
807		));
808		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
809	}
810
811	#[test]
812	fn test_circuit_steps_for_expr_multiplication() {
813		type F = BinaryField8b;
814
815		let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
816		let (steps, retval) = circuit_steps_for_expr(&expr);
817
818		assert_eq!(steps.len(), 1, "One multiplication step should be generated");
819		assert!(matches!(
820			steps[0],
821			CircuitStep::Mul(
822				CircuitStepArgument::Expr(CircuitNode::Var(36)),
823				CircuitStepArgument::Expr(CircuitNode::Var(26))
824			)
825		));
826		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
827	}
828
829	#[test]
830	fn test_circuit_steps_for_expr_pow_1() {
831		type F = BinaryField8b;
832
833		let expr = ArithExpr::<F>::Var(12).pow(1);
834		let (steps, retval) = circuit_steps_for_expr(&expr);
835
836		// No steps should be generated for x^1
837		assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");
838
839		// The return value should just be the variable itself
840		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
841	}
842
843	#[test]
844	fn test_circuit_steps_for_expr_pow_2() {
845		type F = BinaryField8b;
846
847		let expr = ArithExpr::<F>::Var(10).pow(2);
848		let (steps, retval) = circuit_steps_for_expr(&expr);
849
850		assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
851		assert!(matches!(
852			steps[0],
853			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
854		));
855		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
856	}
857
858	#[test]
859	fn test_circuit_steps_for_expr_pow_3() {
860		type F = BinaryField8b;
861
862		let expr = ArithExpr::<F>::Var(5).pow(3);
863		let (steps, retval) = circuit_steps_for_expr(&expr);
864
865		assert_eq!(
866			steps.len(),
867			2,
868			"Pow(3) should generate one squaring and one multiplication step"
869		);
870		assert!(matches!(
871			steps[0],
872			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
873		));
874		assert!(matches!(
875			steps[1],
876			CircuitStep::Mul(
877				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
878				CircuitStepArgument::Expr(CircuitNode::Var(5))
879			)
880		));
881		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
882	}
883
884	#[test]
885	fn test_circuit_steps_for_expr_pow_4() {
886		type F = BinaryField8b;
887
888		let expr = ArithExpr::<F>::Var(7).pow(4);
889		let (steps, retval) = circuit_steps_for_expr(&expr);
890
891		assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
892		assert!(matches!(
893			steps[0],
894			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
895		));
896
897		assert!(matches!(
898			steps[1],
899			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
900		));
901
902		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
903	}
904
905	#[test]
906	fn test_circuit_steps_for_expr_pow_5() {
907		type F = BinaryField8b;
908
909		let expr = ArithExpr::<F>::Var(3).pow(5);
910		let (steps, retval) = circuit_steps_for_expr(&expr);
911
912		assert_eq!(
913			steps.len(),
914			3,
915			"Pow(5) should generate two squaring steps and one multiplication"
916		);
917		assert!(matches!(
918			steps[0],
919			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
920		));
921		assert!(matches!(
922			steps[1],
923			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
924		));
925		assert!(matches!(
926			steps[2],
927			CircuitStep::Mul(
928				CircuitStepArgument::Expr(CircuitNode::Slot(1)),
929				CircuitStepArgument::Expr(CircuitNode::Var(3))
930			)
931		));
932
933		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
934	}
935
936	#[test]
937	fn test_circuit_steps_for_expr_pow_8() {
938		type F = BinaryField8b;
939
940		let expr = ArithExpr::<F>::Var(4).pow(8);
941		let (steps, retval) = circuit_steps_for_expr(&expr);
942
943		assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
944		assert!(matches!(
945			steps[0],
946			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
947		));
948		assert!(matches!(
949			steps[1],
950			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
951		));
952		assert!(matches!(
953			steps[2],
954			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
955		));
956
957		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
958	}
959
960	#[test]
961	fn test_circuit_steps_for_expr_pow_9() {
962		type F = BinaryField8b;
963
964		let expr = ArithExpr::<F>::Var(8).pow(9);
965		let (steps, retval) = circuit_steps_for_expr(&expr);
966
967		assert_eq!(
968			steps.len(),
969			4,
970			"Pow(9) should generate three squaring steps and one multiplication"
971		);
972		assert!(matches!(
973			steps[0],
974			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
975		));
976		assert!(matches!(
977			steps[1],
978			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
979		));
980		assert!(matches!(
981			steps[2],
982			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
983		));
984		assert!(matches!(
985			steps[3],
986			CircuitStep::Mul(
987				CircuitStepArgument::Expr(CircuitNode::Slot(2)),
988				CircuitStepArgument::Expr(CircuitNode::Var(8))
989			)
990		));
991
992		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
993	}
994
995	#[test]
996	fn test_circuit_steps_for_expr_pow_12() {
997		type F = BinaryField8b;
998		let expr = ArithExpr::<F>::Var(6).pow(12);
999		let (steps, retval) = circuit_steps_for_expr(&expr);
1000
1001		assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");
1002
1003		assert!(matches!(
1004			steps[0],
1005			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
1006		));
1007		assert!(matches!(
1008			steps[1],
1009			CircuitStep::Mul(
1010				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1011				CircuitStepArgument::Expr(CircuitNode::Var(6))
1012			)
1013		));
1014		assert!(matches!(
1015			steps[2],
1016			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1017		));
1018		assert!(matches!(
1019			steps[3],
1020			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1021		));
1022
1023		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1024	}
1025
1026	#[test]
1027	fn test_circuit_steps_for_expr_pow_13() {
1028		type F = BinaryField8b;
1029		let expr = ArithExpr::<F>::Var(7).pow(13);
1030		let (steps, retval) = circuit_steps_for_expr(&expr);
1031
1032		assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
1033		assert!(matches!(
1034			steps[0],
1035			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1036		));
1037		assert!(matches!(
1038			steps[1],
1039			CircuitStep::Mul(
1040				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1041				CircuitStepArgument::Expr(CircuitNode::Var(7))
1042			)
1043		));
1044		assert!(matches!(
1045			steps[2],
1046			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1047		));
1048		assert!(matches!(
1049			steps[3],
1050			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1051		));
1052		assert!(matches!(
1053			steps[4],
1054			CircuitStep::Mul(
1055				CircuitStepArgument::Expr(CircuitNode::Slot(3)),
1056				CircuitStepArgument::Expr(CircuitNode::Var(7))
1057			)
1058		));
1059		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
1060	}
1061
1062	#[test]
1063	fn test_circuit_steps_for_expr_complex() {
1064		type F = BinaryField8b;
1065
1066		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1067			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
1068			- ArithExpr::Var(3);
1069
1070		let (steps, retval) = circuit_steps_for_expr(&expr);
1071
1072		assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");
1073
1074		assert!(
1075			matches!(
1076				steps[0],
1077				CircuitStep::Mul(
1078					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1079					CircuitStepArgument::Expr(CircuitNode::Var(1))
1080				)
1081			),
1082			"First step should be multiplication x0 * x1"
1083		);
1084
1085		assert!(
1086			matches!(
1087				steps[1],
1088				CircuitStep::Add(
1089					CircuitStepArgument::Const(F::ONE),
1090					CircuitStepArgument::Expr(CircuitNode::Var(0))
1091				)
1092			),
1093			"Second step should be (1 - x0)"
1094		);
1095
1096		assert!(
1097			matches!(
1098				steps[2],
1099				CircuitStep::AddMul(
1100					0,
1101					CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1102					CircuitStepArgument::Expr(CircuitNode::Var(2))
1103				)
1104			),
1105			"Third step should be (1 - x0) * x2"
1106		);
1107
1108		assert!(
1109			matches!(
1110				steps[3],
1111				CircuitStep::Add(
1112					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1113					CircuitStepArgument::Expr(CircuitNode::Var(3))
1114				)
1115			),
1116			"Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
1117		);
1118
1119		assert!(
1120			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
1121			"Final result should be stored in Slot(3)"
1122		);
1123	}
1124}