binius_fast_compute/
arith_circuit.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithCircuit, ArithCircuitStep, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{
8	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail,
9	mem::{slice_assume_init_mut, slice_assume_init_ref},
10};
11
12/// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence.
13fn convert_circuit_steps<F: Field>(
14	expr: &ArithCircuit<F>,
15) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
16	/// This struct is used to the steps in the original circuit and the converted one and back.
17	struct StepsMapping {
18		original_to_converted: Vec<Option<usize>>,
19		converted_to_original: Vec<Option<usize>>,
20	}
21
22	impl StepsMapping {
23		fn new(original_size: usize) -> Self {
24			Self {
25				original_to_converted: vec![None; original_size],
26				// The size of this vector isn't known at the start, but that's a reasonable guess
27				converted_to_original: vec![None; original_size],
28			}
29		}
30
31		fn register(&mut self, original_step: usize, converted_step: usize) {
32			self.original_to_converted[original_step] = Some(converted_step);
33
34			if converted_step >= self.converted_to_original.len() {
35				self.converted_to_original.resize(converted_step + 1, None);
36			}
37			self.converted_to_original[converted_step] = Some(original_step);
38		}
39
40		fn clear_step(&mut self, converted_step: usize) {
41			if self.converted_to_original.len() <= converted_step {
42				return;
43			}
44
45			if let Some(original_step) = self.converted_to_original[converted_step].take() {
46				self.original_to_converted[original_step] = None;
47			}
48		}
49
50		fn get_converted_step(&self, original_step: usize) -> Option<usize> {
51			self.original_to_converted[original_step]
52		}
53	}
54
55	fn convert_step<F: Field>(
56		original_step: usize,
57		original_steps: &[ArithCircuitStep<F>],
58		result: &mut Vec<CircuitStep<F>>,
59		node_to_step: &mut StepsMapping,
60	) -> CircuitStepArgument<F> {
61		if let Some(converted_step) = node_to_step.get_converted_step(original_step) {
62			return CircuitStepArgument::Expr(CircuitNode::Slot(converted_step));
63		}
64
65		match &original_steps[original_step] {
66			ArithCircuitStep::Const(constant) => CircuitStepArgument::Const(*constant),
67			ArithCircuitStep::Var(var) => CircuitStepArgument::Expr(CircuitNode::Var(*var)),
68			ArithCircuitStep::Add(left, right) => {
69				let left = convert_step(*left, original_steps, result, node_to_step);
70
71				if let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left {
72					if let ArithCircuitStep::Mul(mleft, mright) = &original_steps[*right] {
73						// Only handling e1 + (e2 * e3), not (e1 * e2) + e3, as latter was not
74						// observed in practice (the former can be enforced by rewriting
75						// expression)
76						let mleft = convert_step(*mleft, original_steps, result, node_to_step);
77						let mright = convert_step(*mright, original_steps, result, node_to_step);
78
79						// Since we'we changed the value of `left` to a new value, we need to clear
80						// the cache for it
81						node_to_step.clear_step(left);
82						node_to_step.register(original_step, left);
83						result.push(CircuitStep::AddMul(left, mleft, mright));
84
85						return CircuitStepArgument::Expr(CircuitNode::Slot(left));
86					}
87				}
88
89				let right = convert_step(*right, original_steps, result, node_to_step);
90
91				node_to_step.register(original_step, result.len());
92				result.push(CircuitStep::Add(left, right));
93				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
94			}
95			ArithCircuitStep::Mul(left, right) => {
96				let left = convert_step(*left, original_steps, result, node_to_step);
97				let right = convert_step(*right, original_steps, result, node_to_step);
98
99				node_to_step.register(original_step, result.len());
100				result.push(CircuitStep::Mul(left, right));
101				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
102			}
103			ArithCircuitStep::Pow(base, exp) => {
104				let mut acc = convert_step(*base, original_steps, result, node_to_step);
105				let base_expr = acc;
106				let highest_bit = exp.ilog2();
107
108				for i in (0..highest_bit).rev() {
109					if i == 0 {
110						node_to_step.register(original_step, result.len());
111					}
112					result.push(CircuitStep::Square(acc));
113					acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
114
115					if (exp >> i) & 1 != 0 {
116						result.push(CircuitStep::Mul(acc, base_expr));
117						acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
118					}
119				}
120
121				acc
122			}
123		}
124	}
125
126	let mut steps = Vec::new();
127	let mut steps_mapping = StepsMapping::new(expr.steps().len());
128	let ret = convert_step(expr.steps().len() - 1, expr.steps(), &mut steps, &mut steps_mapping);
129	(steps, ret)
130}
131
132/// Input of the circuit calculation step
133#[derive(Debug, Clone, Copy, PartialEq, Eq)]
134enum CircuitNode {
135	/// Input variable
136	Var(usize),
137	/// Evaluation at one of the previous steps
138	Slot(usize),
139}
140
141impl CircuitNode {
142	/// Return either the input column or the slice with evaluations at one of the previous steps.
143	/// This method is used for batch evaluation.
144	fn get_sparse_chunk<'a, P: PackedField>(
145		&'a self,
146		inputs: &'a RowsBatchRef<'a, P>,
147		evals: &'a [P],
148		row_len: usize,
149	) -> &'a [P] {
150		match self {
151			Self::Var(index) => inputs.row(*index),
152			Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
153		}
154	}
155}
156
157#[derive(Debug, Clone, Copy, PartialEq, Eq)]
158enum CircuitStepArgument<F> {
159	Expr(CircuitNode),
160	Const(F),
161}
162
163/// Describes computation symbolically. This is used internally by ArithCircuitPoly.
164///
165/// ExprIds used by an Expr has to be less than the index of the Expr itself within the
166/// ArithCircuitPoly, to ensure it represents a directed acyclic graph that can be computed in
167/// sequence.
168#[derive(Debug)]
169enum CircuitStep<F: Field> {
170	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
171	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
172	Square(CircuitStepArgument<F>),
173	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
174}
175
176/// Describes polynomial evaluations using a directed acyclic graph of expressions.
177///
178/// This is meant as an alternative to a hard-coded CompositionPoly.
179///
180/// The advantage over a hard coded CompositionPoly is that this can be constructed and manipulated
181/// dynamically at runtime and the object representing different polnomials can be stored in a
182/// homogeneous collection.
183#[derive(Debug, Clone)]
184pub struct ArithCircuitPoly<F: Field> {
185	expr: ArithCircuit<F>,
186	steps: Arc<[CircuitStep<F>]>,
187	/// The "top level expression", which depends on circuit expression evaluations
188	retval: CircuitStepArgument<F>,
189	degree: usize,
190	n_vars: usize,
191	tower_level: usize,
192}
193
194impl<F: Field> PartialEq for ArithCircuitPoly<F> {
195	fn eq(&self, other: &Self) -> bool {
196		self.n_vars == other.n_vars && self.expr == other.expr
197	}
198}
199
200impl<F: Field> Eq for ArithCircuitPoly<F> {}
201
202impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
203	fn serialize(
204		&self,
205		mut write_buf: impl bytes::BufMut,
206		mode: SerializationMode,
207	) -> Result<(), SerializationError> {
208		(&self.expr, self.n_vars).serialize(&mut write_buf, mode)
209	}
210}
211
212impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
213	fn deserialize(
214		read_buf: impl bytes::Buf,
215		mode: SerializationMode,
216	) -> Result<Self, SerializationError>
217	where
218		Self: Sized,
219	{
220		let (expr, n_vars) = <(ArithCircuit<F>, usize)>::deserialize(read_buf, mode)?;
221		Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
222			name: "ArithCircuitPoly",
223		})
224	}
225}
226
227impl<F: TowerField> ArithCircuitPoly<F> {
228	pub fn new(mut expr: ArithCircuit<F>) -> Self {
229		expr.optimize_in_place();
230
231		let degree = expr.degree();
232		let n_vars = expr.n_vars();
233		let tower_level = expr.binary_tower_level();
234		let (exprs, retval) = convert_circuit_steps(&expr);
235
236		Self {
237			expr,
238			steps: exprs.into(),
239			retval,
240			degree,
241			n_vars,
242			tower_level,
243		}
244	}
245	/// Constructs an [`ArithCircuitPoly`] with the given number of variables.
246	///
247	/// The number of variables may be greater than the number of variables actually read in the
248	/// arithmetic expression.
249	pub fn with_n_vars(n_vars: usize, mut expr: ArithCircuit<F>) -> Result<Self, Error> {
250		expr.optimize_in_place();
251
252		let degree = expr.degree();
253		let tower_level = expr.binary_tower_level();
254		if n_vars < expr.n_vars() {
255			return Err(Error::IncorrectNumberOfVariables {
256				expected: expr.n_vars(),
257				actual: n_vars,
258			});
259		}
260		let (steps, retval) = convert_circuit_steps(&expr);
261
262		Ok(Self {
263			expr,
264			steps: steps.into(),
265			retval,
266			n_vars,
267			degree,
268			tower_level,
269		})
270	}
271
272	/// Returns an underlying circuit from which this polynomial was constructed.
273	pub fn expr(&self) -> &ArithCircuit<F> {
274		&self.expr
275	}
276}
277
278impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
279	for ArithCircuitPoly<F>
280{
281	fn degree(&self) -> usize {
282		self.degree
283	}
284
285	fn n_vars(&self) -> usize {
286		self.n_vars
287	}
288
289	fn binary_tower_level(&self) -> usize {
290		self.tower_level
291	}
292
293	fn expression(&self) -> ArithCircuit<P::Scalar> {
294		self.expr.convert_field()
295	}
296
297	fn evaluate(&self, query: &[P]) -> Result<P, Error> {
298		if query.len() != self.n_vars {
299			return Err(Error::IncorrectQuerySize {
300				expected: self.n_vars,
301				actual: query.len(),
302			});
303		}
304
305		fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
306			// Safety: The index is guaranteed to be within bounds because
307			// we initialize at least `self.steps.len()` using `stackalloc`.
308			unsafe {
309				target.get_unchecked_mut(0).write(value);
310			}
311		}
312
313		alloc_scratch_space::<P, _, _>(self.steps.len(), |evals| {
314			let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
315				// Safety: The index is guaranteed to be within bounds by the construction of the
316				// circuit
317				CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
318					*query.get_unchecked(index)
319				},
320				// Safety: The index is guaranteed to be within bounds by the circuit evaluation
321				// order
322				CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
323					*evals.get_unchecked(slot)
324				},
325				CircuitStepArgument::Const(value) => P::broadcast(value.into()),
326			};
327
328			for (i, expr) in self.steps.iter().enumerate() {
329				// Safety: previous evaluations are initialized by the previous loop iterations (if
330				// dereferenced)
331				let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
332				let before = unsafe { slice_assume_init_mut(before) };
333				match expr {
334					CircuitStep::Add(x, y) => write_result(
335						after,
336						get_argument_value(*x, before) + get_argument_value(*y, before),
337					),
338					CircuitStep::AddMul(target_slot, x, y) => {
339						let intermediate =
340							get_argument_value(*x, before) * get_argument_value(*y, before);
341						// Safety: we know by evaluation order and construction of steps that
342						// `target.slot` is initialized
343						let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
344						*target_slot += intermediate;
345					}
346					CircuitStep::Mul(x, y) => write_result(
347						after,
348						get_argument_value(*x, before) * get_argument_value(*y, before),
349					),
350					CircuitStep::Square(x) => {
351						write_result(after, get_argument_value(*x, before).square())
352					}
353				};
354			}
355
356			// Some slots in `evals` might be empty, but we're guaranteed that
357			// if `self.retval` points to a slot, that this slot is initialized.
358			unsafe {
359				let evals = slice_assume_init_ref(evals);
360				Ok(get_argument_value(self.retval, evals))
361			}
362		})
363	}
364
365	fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
366		let row_len = evals.len();
367		if batch_query.row_len() != row_len {
368			bail!(Error::BatchEvaluateSizeMismatch {
369				expected: row_len,
370				actual: batch_query.row_len(),
371			});
372		}
373
374		alloc_scratch_space::<P, (), _>(self.steps.len() * row_len, |sparse_evals| {
375			for (i, expr) in self.steps.iter().enumerate() {
376				let (before, current) = sparse_evals.split_at_mut(i * row_len);
377
378				// Safety: `before` is guaranteed to be initialized by the previous loop iterations
379				// (if dereferenced).
380				let before = unsafe { slice_assume_init_mut(before) };
381				let current = &mut current[..row_len];
382
383				match expr {
384					CircuitStep::Add(left, right) => {
385						apply_binary_op(
386							left,
387							right,
388							batch_query,
389							before,
390							current,
391							|left, right, out| {
392								out.write(left + right);
393							},
394						);
395					}
396					CircuitStep::Mul(left, right) => {
397						apply_binary_op(
398							left,
399							right,
400							batch_query,
401							before,
402							current,
403							|left, right, out| {
404								out.write(left * right);
405							},
406						);
407					}
408					CircuitStep::Square(arg) => {
409						match arg {
410							CircuitStepArgument::Expr(node) => {
411								let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
412								for j in 0..row_len {
413									// Safety: `current` and `id_chunk` have length equal to
414									// `row_len`
415									unsafe {
416										current
417											.get_unchecked_mut(j)
418											.write(id_chunk.get_unchecked(j).square());
419									}
420								}
421							}
422							CircuitStepArgument::Const(value) => {
423								let value: P = P::broadcast((*value).into());
424								let result = value.square();
425								for j in 0..row_len {
426									// Safety: `current` has length equal to `row_len`
427									unsafe {
428										current.get_unchecked_mut(j).write(result);
429									}
430								}
431							}
432						}
433					}
434					CircuitStep::AddMul(target, left, right) => {
435						let target = &mut before[row_len * target..(target + 1) * row_len];
436						// Safety: by construction of steps and evaluation order we know
437						// that `target` is not borrowed elsewhere.
438						let target: &mut [MaybeUninit<P>] = unsafe {
439							std::slice::from_raw_parts_mut(
440								target.as_mut_ptr() as *mut MaybeUninit<P>,
441								target.len(),
442							)
443						};
444						apply_binary_op(
445							left,
446							right,
447							batch_query,
448							before,
449							target,
450							// Safety: by construction of steps and evaluation order we know
451							// that `target`/`out` is initialized.
452							|left, right, out| unsafe {
453								let out = out.assume_init_mut();
454								*out += left * right;
455							},
456						);
457					}
458				}
459			}
460
461			match self.retval {
462				CircuitStepArgument::Expr(node) => {
463					// Safety: `sparse_evals` is fully initialized by the previous loop iterations
464					let sparse_evals = unsafe { slice_assume_init_ref(sparse_evals) };
465					evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
466				}
467				CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
468			}
469		});
470
471		Ok(())
472	}
473}
474
475/// Apply a binary operation to two arguments and store the result in `current_evals`.
476/// `op` must be a function that takes two arguments and initialized the result with the third
477/// argument.
478fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
479	left: &CircuitStepArgument<F>,
480	right: &CircuitStepArgument<F>,
481	batch_query: &RowsBatchRef<P>,
482	evals_before: &[P],
483	current_evals: &mut [MaybeUninit<P>],
484	op: impl Fn(P, P, &mut MaybeUninit<P>),
485) {
486	let row_len = current_evals.len();
487
488	match (left, right) {
489		(CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
490			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
491			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
492			for j in 0..row_len {
493				// Safety: `current`, `left` and `right` have length equal to `row_len`
494				unsafe {
495					op(
496						*left.get_unchecked(j),
497						*right.get_unchecked(j),
498						current_evals.get_unchecked_mut(j),
499					)
500				}
501			}
502		}
503		(CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
504			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
505			let right = P::broadcast((*right).into());
506			for j in 0..row_len {
507				// Safety: `current` and `left` have length equal to `row_len`
508				unsafe {
509					op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
510				}
511			}
512		}
513		(CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
514			let left = P::broadcast((*left).into());
515			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
516			for j in 0..row_len {
517				// Safety: `current` and `right` have length equal to `row_len`
518				unsafe {
519					op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
520				}
521			}
522		}
523		(CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
524			let left = P::broadcast((*left).into());
525			let right = P::broadcast((*right).into());
526			let mut result = MaybeUninit::uninit();
527			op(left, right, &mut result);
528			for j in 0..row_len {
529				// Safety:
530				// - `current` has length equal to `row_len`
531				// - `result` is initialized by `op`
532				unsafe {
533					current_evals
534						.get_unchecked_mut(j)
535						.write(result.assume_init());
536				}
537			}
538		}
539	}
540}
541
542fn alloc_scratch_space<T, U, F>(size: usize, callback: F) -> U
543where
544	F: FnOnce(&mut [MaybeUninit<T>]) -> U,
545{
546	use std::mem;
547	// We don't want to deal with running destructors.
548	assert!(!mem::needs_drop::<T>());
549
550	#[cfg(miri)]
551	{
552		let mut scratch_space = Vec::<T>::with_capacity(size);
553		let out = callback(scratch_space.spare_capacity_mut());
554		drop(scratch_space);
555		out
556	}
557	#[cfg(not(miri))]
558	{
559		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
560		let size = size.max(1);
561		stackalloc::stackalloc_uninit(size, callback)
562	}
563}
564
565#[cfg(test)]
566mod tests {
567	use binius_field::{
568		BinaryField8b, BinaryField16b, PackedBinaryField8x16b, PackedField, TowerField,
569	};
570	use binius_math::{ArithExpr, CompositionPoly, RowsBatch};
571	use binius_utils::felts;
572
573	use super::*;
574
575	#[test]
576	fn test_constant() {
577		type F = BinaryField8b;
578		type P = PackedBinaryField8x16b;
579
580		let expr = ArithExpr::Const(F::new(123));
581		let circuit = ArithCircuitPoly::<F>::new(expr.into());
582
583		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
584		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
585		assert_eq!(typed_circuit.degree(), 0);
586		assert_eq!(typed_circuit.n_vars(), 0);
587
588		assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));
589
590		let mut evals = [P::default()];
591		typed_circuit
592			.batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
593			.unwrap();
594		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
595	}
596
597	#[test]
598	fn test_identity() {
599		type F = BinaryField8b;
600		type P = PackedBinaryField8x16b;
601
602		// x0
603		let expr = ArithExpr::Var(0);
604		let circuit = ArithCircuitPoly::<F>::new(expr.into());
605
606		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
607		assert_eq!(typed_circuit.binary_tower_level(), 0);
608		assert_eq!(typed_circuit.degree(), 1);
609		assert_eq!(typed_circuit.n_vars(), 1);
610
611		assert_eq!(
612			typed_circuit
613				.evaluate(&[P::broadcast(F::new(123).into())])
614				.unwrap(),
615			P::broadcast(F::new(123).into())
616		);
617
618		let mut evals = [P::default()];
619		let batch_query = [[P::broadcast(F::new(123).into())]; 1];
620		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
621		typed_circuit
622			.batch_evaluate(&batch_query.get_ref(), &mut evals)
623			.unwrap();
624		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
625	}
626
627	#[test]
628	fn test_add() {
629		type F = BinaryField8b;
630		type P = PackedBinaryField8x16b;
631
632		// 123 + x0
633		let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
634		let circuit = ArithCircuitPoly::<F>::new(expr.into());
635
636		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
637		assert_eq!(typed_circuit.binary_tower_level(), 3);
638		assert_eq!(typed_circuit.degree(), 1);
639		assert_eq!(typed_circuit.n_vars(), 1);
640
641		assert_eq!(
642			CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
643			P::broadcast(F::new(123).into())
644		);
645	}
646
647	#[test]
648	fn test_mul() {
649		type F = BinaryField8b;
650		type P = PackedBinaryField8x16b;
651
652		// 123 * x0
653		let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
654		let circuit = ArithCircuitPoly::<F>::new(expr.into());
655
656		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
657		assert_eq!(typed_circuit.binary_tower_level(), 3);
658		assert_eq!(typed_circuit.degree(), 1);
659		assert_eq!(typed_circuit.n_vars(), 1);
660
661		assert_eq!(
662			CompositionPoly::evaluate(
663				&circuit,
664				&[P::from_scalars(
665					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
666				)]
667			)
668			.unwrap(),
669			P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
670		);
671	}
672
673	#[test]
674	fn test_pow() {
675		type F = BinaryField8b;
676		type P = PackedBinaryField8x16b;
677
678		// x0^13
679		let expr = ArithExpr::Var(0).pow(13);
680		let circuit = ArithCircuitPoly::<F>::new(expr.into());
681
682		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
683		assert_eq!(typed_circuit.binary_tower_level(), 0);
684		assert_eq!(typed_circuit.degree(), 13);
685		assert_eq!(typed_circuit.n_vars(), 1);
686
687		assert_eq!(
688			CompositionPoly::evaluate(
689				&circuit,
690				&[P::from_scalars(
691					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
692				)]
693			)
694			.unwrap(),
695			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
696		);
697	}
698
699	#[test]
700	fn test_mixed() {
701		type F = BinaryField8b;
702		type P = PackedBinaryField8x16b;
703
704		// x0^2 * (x1 + 123)
705		let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
706		let circuit = ArithCircuitPoly::<F>::new(expr.into());
707
708		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
709		assert_eq!(typed_circuit.binary_tower_level(), 3);
710		assert_eq!(typed_circuit.degree(), 3);
711		assert_eq!(typed_circuit.n_vars(), 2);
712
713		// test evaluate
714		assert_eq!(
715			CompositionPoly::evaluate(
716				&circuit,
717				&[
718					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
719					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
720				]
721			)
722			.unwrap(),
723			P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
724		);
725
726		// test batch evaluate
727		let query1 = &[
728			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
729			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
730		];
731		let query2 = &[
732			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
733			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
734		];
735		let query3 = &[
736			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
737			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
738		];
739		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
740		let expected2 =
741			P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
742		let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));
743
744		let mut batch_result = vec![P::zero(); 3];
745		let batch_query = &[
746			&[query1[0], query2[0], query3[0]],
747			&[query1[1], query2[1], query3[1]],
748		];
749		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
750
751		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
752			.unwrap();
753		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
754	}
755
756	#[test]
757	fn batch_evaluate_add_mul() {
758		// This test is focused on exposing the the currently present stacked borrows violation. It
759		// passes but still triggers `miri`.
760
761		type F = BinaryField8b;
762		type P = PackedBinaryField8x16b;
763
764		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(0))
765			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(0)
766			- ArithExpr::Var(0);
767		let circuit = ArithCircuitPoly::<F>::new(expr.into());
768
769		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
770		assert_eq!(typed_circuit.binary_tower_level(), 0);
771		assert_eq!(typed_circuit.degree(), 2);
772		assert_eq!(typed_circuit.n_vars(), 1);
773
774		let mut evals = [P::default(); 1];
775		let batch_query = [[P::broadcast(F::new(1).into())]; 1];
776		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
777		typed_circuit
778			.batch_evaluate(&batch_query.get_ref(), &mut evals)
779			.unwrap();
780	}
781
782	#[test]
783	fn test_const_fold() {
784		type F = BinaryField8b;
785		type P = PackedBinaryField8x16b;
786
787		// x0 * ((122 * 123) + (124 + 125)) + x1
788		let expr = ArithExpr::Var(0)
789			* ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
790				+ (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
791			+ ArithExpr::Var(1);
792		let circuit = ArithCircuitPoly::<F>::new(expr.into());
793		assert_eq!(circuit.steps.len(), 2);
794
795		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
796		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
797		assert_eq!(typed_circuit.degree(), 1);
798		assert_eq!(typed_circuit.n_vars(), 2);
799
800		// test evaluate
801		assert_eq!(
802			CompositionPoly::evaluate(
803				&circuit,
804				&[
805					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
806					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
807				]
808			)
809			.unwrap(),
810			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
811		);
812
813		// test batch evaluate
814		let query1 = &[
815			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
816			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
817		];
818		let query2 = &[
819			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
820			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
821		];
822		let query3 = &[
823			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
824			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
825		];
826		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
827		let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
828		let expected3 =
829			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));
830
831		let mut batch_result = vec![P::zero(); 3];
832		let batch_query = &[
833			&[query1[0], query2[0], query3[0]],
834			&[query1[1], query2[1], query3[1]],
835		];
836		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
837		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
838			.unwrap();
839		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
840	}
841
842	#[test]
843	fn test_pow_const_fold() {
844		type F = BinaryField8b;
845		type P = PackedBinaryField8x16b;
846
847		// x0 + 2^5
848		let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
849		let circuit = ArithCircuitPoly::<F>::new(expr.into());
850		assert_eq!(circuit.steps.len(), 1);
851
852		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
853		assert_eq!(typed_circuit.binary_tower_level(), 1);
854		assert_eq!(typed_circuit.degree(), 1);
855		assert_eq!(typed_circuit.n_vars(), 1);
856
857		assert_eq!(
858			CompositionPoly::evaluate(
859				&circuit,
860				&[P::from_scalars(
861					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
862				)]
863			)
864			.unwrap(),
865			P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
866		);
867	}
868
869	#[test]
870	fn test_pow_nested() {
871		type F = BinaryField8b;
872		type P = PackedBinaryField8x16b;
873
874		// ((x0^2)^3)^4
875		let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
876		let circuit = ArithCircuitPoly::<F>::new(expr.into());
877		assert_eq!(circuit.steps.len(), 5);
878
879		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
880		assert_eq!(typed_circuit.binary_tower_level(), 0);
881		assert_eq!(typed_circuit.degree(), 24);
882		assert_eq!(typed_circuit.n_vars(), 1);
883
884		assert_eq!(
885			CompositionPoly::evaluate(
886				&circuit,
887				&[P::from_scalars(
888					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
889				)]
890			)
891			.unwrap(),
892			P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
893		);
894	}
895
896	#[test]
897	fn test_circuit_steps_for_expr_constant() {
898		type F = BinaryField8b;
899
900		let expr = ArithExpr::Const(F::new(5));
901		let (steps, retval) = convert_circuit_steps(&expr.into());
902
903		assert!(steps.is_empty(), "No steps should be generated for a constant");
904		assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
905	}
906
907	#[test]
908	fn test_circuit_steps_for_expr_variable() {
909		type F = BinaryField8b;
910
911		let expr = ArithExpr::<F>::Var(18);
912		let (steps, retval) = convert_circuit_steps(&expr.into());
913
914		assert!(steps.is_empty(), "No steps should be generated for a variable");
915		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
916	}
917
918	#[test]
919	fn test_circuit_steps_for_expr_addition() {
920		type F = BinaryField8b;
921
922		let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
923		let (steps, retval) = convert_circuit_steps(&expr.into());
924
925		assert_eq!(steps.len(), 1, "One addition step should be generated");
926		assert!(matches!(
927			steps[0],
928			CircuitStep::Add(
929				CircuitStepArgument::Expr(CircuitNode::Var(14)),
930				CircuitStepArgument::Expr(CircuitNode::Var(56))
931			)
932		));
933		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
934	}
935
936	#[test]
937	fn test_circuit_steps_for_expr_multiplication() {
938		type F = BinaryField8b;
939
940		let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
941		let (steps, retval) = convert_circuit_steps(&expr.into());
942
943		assert_eq!(steps.len(), 1, "One multiplication step should be generated");
944		assert!(matches!(
945			steps[0],
946			CircuitStep::Mul(
947				CircuitStepArgument::Expr(CircuitNode::Var(36)),
948				CircuitStepArgument::Expr(CircuitNode::Var(26))
949			)
950		));
951		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
952	}
953
954	#[test]
955	fn test_circuit_steps_for_expr_pow_1() {
956		type F = BinaryField8b;
957
958		let expr = ArithExpr::<F>::Var(12).pow(1);
959		let (steps, retval) = convert_circuit_steps(&expr.into());
960
961		// No steps should be generated for x^1
962		assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");
963
964		// The return value should just be the variable itself
965		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
966	}
967
968	#[test]
969	fn test_circuit_steps_for_expr_pow_2() {
970		type F = BinaryField8b;
971
972		let expr = ArithExpr::<F>::Var(10).pow(2);
973		let (steps, retval) = convert_circuit_steps(&expr.into());
974
975		assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
976		assert!(matches!(
977			steps[0],
978			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
979		));
980		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
981	}
982
983	#[test]
984	fn test_circuit_steps_for_expr_pow_3() {
985		type F = BinaryField8b;
986
987		let expr = ArithExpr::<F>::Var(5).pow(3);
988		let (steps, retval) = convert_circuit_steps(&expr.into());
989
990		assert_eq!(
991			steps.len(),
992			2,
993			"Pow(3) should generate one squaring and one multiplication step"
994		);
995		assert!(matches!(
996			steps[0],
997			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
998		));
999		assert!(matches!(
1000			steps[1],
1001			CircuitStep::Mul(
1002				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1003				CircuitStepArgument::Expr(CircuitNode::Var(5))
1004			)
1005		));
1006		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1007	}
1008
1009	#[test]
1010	fn test_circuit_steps_for_expr_pow_4() {
1011		type F = BinaryField8b;
1012
1013		let expr = ArithExpr::<F>::Var(7).pow(4);
1014		let (steps, retval) = convert_circuit_steps(&expr.into());
1015
1016		assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
1017		assert!(matches!(
1018			steps[0],
1019			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1020		));
1021
1022		assert!(matches!(
1023			steps[1],
1024			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1025		));
1026
1027		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1028	}
1029
1030	#[test]
1031	fn test_circuit_steps_for_expr_pow_5() {
1032		type F = BinaryField8b;
1033
1034		let expr = ArithExpr::<F>::Var(3).pow(5);
1035		let (steps, retval) = convert_circuit_steps(&expr.into());
1036
1037		assert_eq!(
1038			steps.len(),
1039			3,
1040			"Pow(5) should generate two squaring steps and one multiplication"
1041		);
1042		assert!(matches!(
1043			steps[0],
1044			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
1045		));
1046		assert!(matches!(
1047			steps[1],
1048			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1049		));
1050		assert!(matches!(
1051			steps[2],
1052			CircuitStep::Mul(
1053				CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1054				CircuitStepArgument::Expr(CircuitNode::Var(3))
1055			)
1056		));
1057
1058		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1059	}
1060
1061	#[test]
1062	fn test_circuit_steps_for_expr_pow_8() {
1063		type F = BinaryField8b;
1064
1065		let expr = ArithExpr::<F>::Var(4).pow(8);
1066		let (steps, retval) = convert_circuit_steps(&expr.into());
1067
1068		assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
1069		assert!(matches!(
1070			steps[0],
1071			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
1072		));
1073		assert!(matches!(
1074			steps[1],
1075			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1076		));
1077		assert!(matches!(
1078			steps[2],
1079			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1080		));
1081
1082		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1083	}
1084
1085	#[test]
1086	fn test_circuit_steps_for_expr_pow_9() {
1087		type F = BinaryField8b;
1088
1089		let expr = ArithExpr::<F>::Var(8).pow(9);
1090		let (steps, retval) = convert_circuit_steps(&expr.into());
1091
1092		assert_eq!(
1093			steps.len(),
1094			4,
1095			"Pow(9) should generate three squaring steps and one multiplication"
1096		);
1097		assert!(matches!(
1098			steps[0],
1099			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
1100		));
1101		assert!(matches!(
1102			steps[1],
1103			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1104		));
1105		assert!(matches!(
1106			steps[2],
1107			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1108		));
1109		assert!(matches!(
1110			steps[3],
1111			CircuitStep::Mul(
1112				CircuitStepArgument::Expr(CircuitNode::Slot(2)),
1113				CircuitStepArgument::Expr(CircuitNode::Var(8))
1114			)
1115		));
1116
1117		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1118	}
1119
1120	#[test]
1121	fn test_circuit_steps_for_expr_pow_12() {
1122		type F = BinaryField8b;
1123		let expr = ArithExpr::<F>::Var(6).pow(12);
1124		let (steps, retval) = convert_circuit_steps(&expr.into());
1125
1126		assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");
1127
1128		assert!(matches!(
1129			steps[0],
1130			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
1131		));
1132		assert!(matches!(
1133			steps[1],
1134			CircuitStep::Mul(
1135				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1136				CircuitStepArgument::Expr(CircuitNode::Var(6))
1137			)
1138		));
1139		assert!(matches!(
1140			steps[2],
1141			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1142		));
1143		assert!(matches!(
1144			steps[3],
1145			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1146		));
1147
1148		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1149	}
1150
1151	#[test]
1152	fn test_circuit_steps_for_expr_pow_13() {
1153		type F = BinaryField8b;
1154		let expr = ArithExpr::<F>::Var(7).pow(13);
1155		let (steps, retval) = convert_circuit_steps(&expr.into());
1156
1157		assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
1158		assert!(matches!(
1159			steps[0],
1160			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1161		));
1162		assert!(matches!(
1163			steps[1],
1164			CircuitStep::Mul(
1165				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1166				CircuitStepArgument::Expr(CircuitNode::Var(7))
1167			)
1168		));
1169		assert!(matches!(
1170			steps[2],
1171			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1172		));
1173		assert!(matches!(
1174			steps[3],
1175			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1176		));
1177		assert!(matches!(
1178			steps[4],
1179			CircuitStep::Mul(
1180				CircuitStepArgument::Expr(CircuitNode::Slot(3)),
1181				CircuitStepArgument::Expr(CircuitNode::Var(7))
1182			)
1183		));
1184		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
1185	}
1186
1187	#[test]
1188	fn test_circuit_steps_for_expr_complex() {
1189		type F = BinaryField8b;
1190
1191		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1192			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
1193			- ArithExpr::Var(3);
1194
1195		let (steps, retval) = convert_circuit_steps(&expr.into());
1196
1197		assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");
1198
1199		assert!(
1200			matches!(
1201				steps[0],
1202				CircuitStep::Mul(
1203					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1204					CircuitStepArgument::Expr(CircuitNode::Var(1))
1205				)
1206			),
1207			"First step should be multiplication x0 * x1"
1208		);
1209
1210		assert!(
1211			matches!(
1212				steps[1],
1213				CircuitStep::Add(
1214					CircuitStepArgument::Const(F::ONE),
1215					CircuitStepArgument::Expr(CircuitNode::Var(0))
1216				)
1217			),
1218			"Second step should be (1 - x0)"
1219		);
1220
1221		assert!(
1222			matches!(
1223				steps[2],
1224				CircuitStep::AddMul(
1225					0,
1226					CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1227					CircuitStepArgument::Expr(CircuitNode::Var(2))
1228				)
1229			),
1230			"Third step should be (1 - x0) * x2"
1231		);
1232
1233		assert!(
1234			matches!(
1235				steps[3],
1236				CircuitStep::Add(
1237					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1238					CircuitStepArgument::Expr(CircuitNode::Var(3))
1239				)
1240			),
1241			"Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
1242		);
1243
1244		assert!(
1245			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
1246			"Final result should be stored in Slot(3)"
1247		);
1248	}
1249
1250	#[test]
1251	fn check_deduplication_in_steps() {
1252		type F = BinaryField8b;
1253
1254		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1255			+ (ArithExpr::<F>::Var(0) * ArithExpr::Var(1)) * ArithExpr::Var(2)
1256			- ArithExpr::Var(3);
1257		let expr = ArithCircuit::from(&expr);
1258		let expr = expr.optimize();
1259
1260		let (steps, retval) = convert_circuit_steps(&expr);
1261
1262		assert_eq!(steps.len(), 3, "Expression should generate 3 computation steps");
1263
1264		assert!(
1265			matches!(
1266				steps[0],
1267				CircuitStep::Mul(
1268					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1269					CircuitStepArgument::Expr(CircuitNode::Var(1))
1270				)
1271			),
1272			"First step should be multiplication x0 * x1"
1273		);
1274
1275		assert!(
1276			matches!(
1277				steps[1],
1278				CircuitStep::AddMul(
1279					0,
1280					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1281					CircuitStepArgument::Expr(CircuitNode::Var(2))
1282				)
1283			),
1284			"Second step should be (x0 * x1) * x2"
1285		);
1286
1287		assert!(
1288			matches!(
1289				steps[2],
1290				CircuitStep::Add(
1291					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1292					CircuitStepArgument::Expr(CircuitNode::Var(3))
1293				)
1294			),
1295			"Third step should be x0 * x1 + (x0 * x1) * x2 + x3"
1296		);
1297
1298		assert!(
1299			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))),
1300			"Final result should be stored in Slot(2)"
1301		);
1302	}
1303}