// binius_fast_compute/arith_circuit.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithCircuit, ArithCircuitStep, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{
8	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail,
9	mem::{slice_assume_init_mut, slice_assume_init_ref},
10};
11
12/// Convert the expression to a sequence of arithmetic operations that can be evaluated in sequence.
13fn convert_circuit_steps<F: Field>(
14	expr: &ArithCircuit<F>,
15) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
16	/// This struct is used to the steps in the original circuit and the converted one and back.
17	struct StepsMapping {
18		original_to_converted: Vec<Option<usize>>,
19		converted_to_original: Vec<Option<usize>>,
20	}
21
22	impl StepsMapping {
23		fn new(original_size: usize) -> Self {
24			Self {
25				original_to_converted: vec![None; original_size],
26				// The size of this vector isn't known at the start, but that's a reasonable guess
27				converted_to_original: vec![None; original_size],
28			}
29		}
30
31		fn register(&mut self, original_step: usize, converted_step: usize) {
32			self.original_to_converted[original_step] = Some(converted_step);
33
34			if converted_step >= self.converted_to_original.len() {
35				self.converted_to_original.resize(converted_step + 1, None);
36			}
37			self.converted_to_original[converted_step] = Some(original_step);
38		}
39
40		fn clear_step(&mut self, converted_step: usize) {
41			if self.converted_to_original.len() <= converted_step {
42				return;
43			}
44
45			if let Some(original_step) = self.converted_to_original[converted_step].take() {
46				self.original_to_converted[original_step] = None;
47			}
48		}
49
50		fn get_converted_step(&self, original_step: usize) -> Option<usize> {
51			self.original_to_converted[original_step]
52		}
53	}
54
55	fn convert_step<F: Field>(
56		original_step: usize,
57		original_steps: &[ArithCircuitStep<F>],
58		result: &mut Vec<CircuitStep<F>>,
59		node_to_step: &mut StepsMapping,
60	) -> CircuitStepArgument<F> {
61		if let Some(converted_step) = node_to_step.get_converted_step(original_step) {
62			return CircuitStepArgument::Expr(CircuitNode::Slot(converted_step));
63		}
64
65		match &original_steps[original_step] {
66			ArithCircuitStep::Const(constant) => CircuitStepArgument::Const(*constant),
67			ArithCircuitStep::Var(var) => CircuitStepArgument::Expr(CircuitNode::Var(*var)),
68			ArithCircuitStep::Add(left, right) => {
69				let left = convert_step(*left, original_steps, result, node_to_step);
70
71				if let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left {
72					if let ArithCircuitStep::Mul(mleft, mright) = &original_steps[*right] {
73						// Only handling e1 + (e2 * e3), not (e1 * e2) + e3, as latter was not
74						// observed in practice (the former can be enforced by rewriting
75						// expression)
76						let mleft = convert_step(*mleft, original_steps, result, node_to_step);
77						let mright = convert_step(*mright, original_steps, result, node_to_step);
78
79						// Since we'we changed the value of `left` to a new value, we need to clear
80						// the cache for it
81						node_to_step.clear_step(left);
82						node_to_step.register(original_step, left);
83						result.push(CircuitStep::AddMul(left, mleft, mright));
84
85						return CircuitStepArgument::Expr(CircuitNode::Slot(left));
86					}
87				}
88
89				let right = convert_step(*right, original_steps, result, node_to_step);
90
91				node_to_step.register(original_step, result.len());
92				result.push(CircuitStep::Add(left, right));
93				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
94			}
95			ArithCircuitStep::Mul(left, right) => {
96				let left = convert_step(*left, original_steps, result, node_to_step);
97				let right = convert_step(*right, original_steps, result, node_to_step);
98
99				node_to_step.register(original_step, result.len());
100				result.push(CircuitStep::Mul(left, right));
101				CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
102			}
103			ArithCircuitStep::Pow(base, exp) => {
104				let mut acc = convert_step(*base, original_steps, result, node_to_step);
105				let base_expr = acc;
106				let highest_bit = exp.ilog2();
107
108				for i in (0..highest_bit).rev() {
109					if i == 0 {
110						node_to_step.register(original_step, result.len());
111					}
112					result.push(CircuitStep::Square(acc));
113					acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
114
115					if (exp >> i) & 1 != 0 {
116						result.push(CircuitStep::Mul(acc, base_expr));
117						acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
118					}
119				}
120
121				acc
122			}
123		}
124	}
125
126	let mut steps = Vec::new();
127	let mut steps_mapping = StepsMapping::new(expr.steps().len());
128	let ret = convert_step(expr.steps().len() - 1, expr.steps(), &mut steps, &mut steps_mapping);
129	(steps, ret)
130}
131
/// Input of the circuit calculation step
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitNode {
	/// Input variable, identified by its index in the query
	Var(usize),
	/// Evaluation at one of the previous steps, identified by that step's index (its slot)
	Slot(usize),
}
140
141impl CircuitNode {
142	/// Return either the input column or the slice with evaluations at one of the previous steps.
143	/// This method is used for batch evaluation.
144	fn get_sparse_chunk<'a, P: PackedField>(
145		&'a self,
146		inputs: &'a RowsBatchRef<'a, P>,
147		evals: &'a [P],
148		row_len: usize,
149	) -> &'a [P] {
150		match self {
151			Self::Var(index) => inputs.row(*index),
152			Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
153		}
154	}
155}
156
/// Operand of a [`CircuitStep`]: either a node whose value is looked up during evaluation, or a
/// constant that is broadcast to a packed value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitStepArgument<F> {
	/// Value of an input variable or of a previous step.
	Expr(CircuitNode),
	/// Constant field element.
	Const(F),
}
162
/// Describes computation symbolically. This is used internally by ArithCircuitPoly.
///
/// Steps are evaluated in order. Every slot a step refers to, through [`CircuitNode::Slot`] or
/// as an `AddMul` target, must belong to an earlier step. This guarantees the steps form a
/// directed acyclic graph that can be computed in sequence.
#[derive(Debug)]
enum CircuitStep<F: Field> {
	/// Sum of the two arguments, stored in this step's slot.
	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// Product of the two arguments, stored in this step's slot.
	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// Square of the argument, stored in this step's slot.
	Square(CircuitStepArgument<F>),
	/// Adds the product of the two arguments, in place, to the value in the slot given by the
	/// `usize`, which belongs to an earlier step.
	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
}
175
/// Describes polynomial evaluations using a directed acyclic graph of expressions.
///
/// This is meant as an alternative to a hard-coded CompositionPoly.
///
/// The advantage over a hard coded CompositionPoly is that this can be constructed and manipulated
/// dynamically at runtime and the object representing different polynomials can be stored in a
/// homogeneous collection.
#[derive(Debug, Clone)]
pub struct ArithCircuitPoly<F: Field> {
	/// The circuit this polynomial was built from (optimized by the constructors).
	expr: ArithCircuit<F>,
	/// Evaluation steps derived from `expr`, executed in order.
	steps: Arc<[CircuitStep<F>]>,
	/// The "top level expression", which depends on circuit expression evaluations
	retval: CircuitStepArgument<F>,
	/// Degree, as reported by `ArithCircuit::degree` for `expr`.
	degree: usize,
	/// Number of variables; may exceed the number of variables actually read by `expr`.
	n_vars: usize,
	/// Binary tower level, as reported by `ArithCircuit::binary_tower_level` for `expr`.
	tower_level: usize,
}
193
194impl<F: Field> PartialEq for ArithCircuitPoly<F> {
195	fn eq(&self, other: &Self) -> bool {
196		self.n_vars == other.n_vars && self.expr == other.expr
197	}
198}
199
200impl<F: Field> Eq for ArithCircuitPoly<F> {}
201
/// Serializes the polynomial as the tuple `(expr, n_vars)`; the derived evaluation steps are not
/// serialized and get rebuilt on deserialization.
impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
	fn serialize(
		&self,
		mut write_buf: impl bytes::BufMut,
		mode: SerializationMode,
	) -> Result<(), SerializationError> {
		(&self.expr, self.n_vars).serialize(&mut write_buf, mode)
	}
}

/// Deserializes an `(expr, n_vars)` tuple and rebuilds the polynomial with
/// [`ArithCircuitPoly::with_n_vars`].
impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
	fn deserialize(
		read_buf: impl bytes::Buf,
		mode: SerializationMode,
	) -> Result<Self, SerializationError>
	where
		Self: Sized,
	{
		let (expr, n_vars) = <(ArithCircuit<F>, usize)>::deserialize(read_buf, mode)?;
		// Construction fails (e.g. when `n_vars` is smaller than the circuit's number of
		// variables); the underlying error is replaced by `InvalidConstruction`.
		Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
			name: "ArithCircuitPoly",
		})
	}
}
226
227impl<F: TowerField> ArithCircuitPoly<F> {
228	pub fn new(mut expr: ArithCircuit<F>) -> Self {
229		expr.optimize_in_place();
230
231		let degree = expr.degree();
232		let n_vars = expr.n_vars();
233		let tower_level = expr.binary_tower_level();
234		let (exprs, retval) = convert_circuit_steps(&expr);
235
236		Self {
237			expr,
238			steps: exprs.into(),
239			retval,
240			degree,
241			n_vars,
242			tower_level,
243		}
244	}
245	/// Constructs an [`ArithCircuitPoly`] with the given number of variables.
246	///
247	/// The number of variables may be greater than the number of variables actually read in the
248	/// arithmetic expression.
249	pub fn with_n_vars(n_vars: usize, mut expr: ArithCircuit<F>) -> Result<Self, Error> {
250		expr.optimize_in_place();
251
252		let degree = expr.degree();
253		let tower_level = expr.binary_tower_level();
254		if n_vars < expr.n_vars() {
255			return Err(Error::IncorrectNumberOfVariables {
256				expected: expr.n_vars(),
257				actual: n_vars,
258			});
259		}
260		let (steps, retval) = convert_circuit_steps(&expr);
261
262		Ok(Self {
263			expr,
264			steps: steps.into(),
265			retval,
266			n_vars,
267			degree,
268			tower_level,
269		})
270	}
271}
272
273impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
274	for ArithCircuitPoly<F>
275{
276	fn degree(&self) -> usize {
277		self.degree
278	}
279
280	fn n_vars(&self) -> usize {
281		self.n_vars
282	}
283
284	fn binary_tower_level(&self) -> usize {
285		self.tower_level
286	}
287
288	fn expression(&self) -> ArithCircuit<P::Scalar> {
289		self.expr.convert_field()
290	}
291
292	fn evaluate(&self, query: &[P]) -> Result<P, Error> {
293		if query.len() != self.n_vars {
294			return Err(Error::IncorrectQuerySize {
295				expected: self.n_vars,
296				actual: query.len(),
297			});
298		}
299
300		fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
301			// Safety: The index is guaranteed to be within bounds because
302			// we initialize at least `self.steps.len()` using `stackalloc`.
303			unsafe {
304				target.get_unchecked_mut(0).write(value);
305			}
306		}
307
308		alloc_scratch_space::<P, _, _>(self.steps.len(), |evals| {
309			let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
310				// Safety: The index is guaranteed to be within bounds by the construction of the
311				// circuit
312				CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
313					*query.get_unchecked(index)
314				},
315				// Safety: The index is guaranteed to be within bounds by the circuit evaluation
316				// order
317				CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
318					*evals.get_unchecked(slot)
319				},
320				CircuitStepArgument::Const(value) => P::broadcast(value.into()),
321			};
322
323			for (i, expr) in self.steps.iter().enumerate() {
324				// Safety: previous evaluations are initialized by the previous loop iterations (if
325				// dereferenced)
326				let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
327				let before = unsafe { slice_assume_init_mut(before) };
328				match expr {
329					CircuitStep::Add(x, y) => write_result(
330						after,
331						get_argument_value(*x, before) + get_argument_value(*y, before),
332					),
333					CircuitStep::AddMul(target_slot, x, y) => {
334						let intermediate =
335							get_argument_value(*x, before) * get_argument_value(*y, before);
336						// Safety: we know by evaluation order and construction of steps that
337						// `target.slot` is initialized
338						let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
339						*target_slot += intermediate;
340					}
341					CircuitStep::Mul(x, y) => write_result(
342						after,
343						get_argument_value(*x, before) * get_argument_value(*y, before),
344					),
345					CircuitStep::Square(x) => {
346						write_result(after, get_argument_value(*x, before).square())
347					}
348				};
349			}
350
351			// Some slots in `evals` might be empty, but we're guaranteed that
352			// if `self.retval` points to a slot, that this slot is initialized.
353			unsafe {
354				let evals = slice_assume_init_ref(evals);
355				Ok(get_argument_value(self.retval, evals))
356			}
357		})
358	}
359
360	fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
361		let row_len = evals.len();
362		if batch_query.row_len() != row_len {
363			bail!(Error::BatchEvaluateSizeMismatch {
364				expected: row_len,
365				actual: batch_query.row_len(),
366			});
367		}
368
369		alloc_scratch_space::<P, (), _>(self.steps.len() * row_len, |sparse_evals| {
370			for (i, expr) in self.steps.iter().enumerate() {
371				let (before, current) = sparse_evals.split_at_mut(i * row_len);
372
373				// Safety: `before` is guaranteed to be initialized by the previous loop iterations
374				// (if dereferenced).
375				let before = unsafe { slice_assume_init_mut(before) };
376				let current = &mut current[..row_len];
377
378				match expr {
379					CircuitStep::Add(left, right) => {
380						apply_binary_op(
381							left,
382							right,
383							batch_query,
384							before,
385							current,
386							|left, right, out| {
387								out.write(left + right);
388							},
389						);
390					}
391					CircuitStep::Mul(left, right) => {
392						apply_binary_op(
393							left,
394							right,
395							batch_query,
396							before,
397							current,
398							|left, right, out| {
399								out.write(left * right);
400							},
401						);
402					}
403					CircuitStep::Square(arg) => {
404						match arg {
405							CircuitStepArgument::Expr(node) => {
406								let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
407								for j in 0..row_len {
408									// Safety: `current` and `id_chunk` have length equal to
409									// `row_len`
410									unsafe {
411										current
412											.get_unchecked_mut(j)
413											.write(id_chunk.get_unchecked(j).square());
414									}
415								}
416							}
417							CircuitStepArgument::Const(value) => {
418								let value: P = P::broadcast((*value).into());
419								let result = value.square();
420								for j in 0..row_len {
421									// Safety: `current` has length equal to `row_len`
422									unsafe {
423										current.get_unchecked_mut(j).write(result);
424									}
425								}
426							}
427						}
428					}
429					CircuitStep::AddMul(target, left, right) => {
430						let target = &mut before[row_len * target..(target + 1) * row_len];
431						// Safety: by construction of steps and evaluation order we know
432						// that `target` is not borrowed elsewhere.
433						let target: &mut [MaybeUninit<P>] = unsafe {
434							std::slice::from_raw_parts_mut(
435								target.as_mut_ptr() as *mut MaybeUninit<P>,
436								target.len(),
437							)
438						};
439						apply_binary_op(
440							left,
441							right,
442							batch_query,
443							before,
444							target,
445							// Safety: by construction of steps and evaluation order we know
446							// that `target`/`out` is initialized.
447							|left, right, out| unsafe {
448								let out = out.assume_init_mut();
449								*out += left * right;
450							},
451						);
452					}
453				}
454			}
455
456			match self.retval {
457				CircuitStepArgument::Expr(node) => {
458					// Safety: `sparse_evals` is fully initialized by the previous loop iterations
459					let sparse_evals = unsafe { slice_assume_init_ref(sparse_evals) };
460					evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
461				}
462				CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
463			}
464		});
465
466		Ok(())
467	}
468}
469
470/// Apply a binary operation to two arguments and store the result in `current_evals`.
471/// `op` must be a function that takes two arguments and initialized the result with the third
472/// argument.
473fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
474	left: &CircuitStepArgument<F>,
475	right: &CircuitStepArgument<F>,
476	batch_query: &RowsBatchRef<P>,
477	evals_before: &[P],
478	current_evals: &mut [MaybeUninit<P>],
479	op: impl Fn(P, P, &mut MaybeUninit<P>),
480) {
481	let row_len = current_evals.len();
482
483	match (left, right) {
484		(CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
485			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
486			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
487			for j in 0..row_len {
488				// Safety: `current`, `left` and `right` have length equal to `row_len`
489				unsafe {
490					op(
491						*left.get_unchecked(j),
492						*right.get_unchecked(j),
493						current_evals.get_unchecked_mut(j),
494					)
495				}
496			}
497		}
498		(CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
499			let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
500			let right = P::broadcast((*right).into());
501			for j in 0..row_len {
502				// Safety: `current` and `left` have length equal to `row_len`
503				unsafe {
504					op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
505				}
506			}
507		}
508		(CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
509			let left = P::broadcast((*left).into());
510			let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
511			for j in 0..row_len {
512				// Safety: `current` and `right` have length equal to `row_len`
513				unsafe {
514					op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
515				}
516			}
517		}
518		(CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
519			let left = P::broadcast((*left).into());
520			let right = P::broadcast((*right).into());
521			let mut result = MaybeUninit::uninit();
522			op(left, right, &mut result);
523			for j in 0..row_len {
524				// Safety:
525				// - `current` has length equal to `row_len`
526				// - `result` is initialized by `op`
527				unsafe {
528					current_evals
529						.get_unchecked_mut(j)
530						.write(result.assume_init());
531				}
532			}
533		}
534	}
535}
536
537fn alloc_scratch_space<T, U, F>(size: usize, callback: F) -> U
538where
539	F: FnOnce(&mut [MaybeUninit<T>]) -> U,
540{
541	use std::mem;
542	// We don't want to deal with running destructors.
543	assert!(!mem::needs_drop::<T>());
544
545	#[cfg(miri)]
546	{
547		let mut scratch_space = Vec::<T>::with_capacity(size);
548		let out = callback(scratch_space.spare_capacity_mut());
549		drop(scratch_space);
550		out
551	}
552	#[cfg(not(miri))]
553	{
554		// `stackalloc_uninit` throws a debug assert if `size` is 0, so set minimum of 1.
555		let size = size.max(1);
556		stackalloc::stackalloc_uninit(size, callback)
557	}
558}
559
560#[cfg(test)]
561mod tests {
562	use binius_field::{
563		BinaryField8b, BinaryField16b, PackedBinaryField8x16b, PackedField, TowerField,
564	};
565	use binius_math::{ArithExpr, CompositionPoly, RowsBatch};
566	use binius_utils::felts;
567
568	use super::*;
569
570	#[test]
571	fn test_constant() {
572		type F = BinaryField8b;
573		type P = PackedBinaryField8x16b;
574
575		let expr = ArithExpr::Const(F::new(123));
576		let circuit = ArithCircuitPoly::<F>::new(expr.into());
577
578		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
579		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
580		assert_eq!(typed_circuit.degree(), 0);
581		assert_eq!(typed_circuit.n_vars(), 0);
582
583		assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));
584
585		let mut evals = [P::default()];
586		typed_circuit
587			.batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
588			.unwrap();
589		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
590	}
591
592	#[test]
593	fn test_identity() {
594		type F = BinaryField8b;
595		type P = PackedBinaryField8x16b;
596
597		// x0
598		let expr = ArithExpr::Var(0);
599		let circuit = ArithCircuitPoly::<F>::new(expr.into());
600
601		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
602		assert_eq!(typed_circuit.binary_tower_level(), 0);
603		assert_eq!(typed_circuit.degree(), 1);
604		assert_eq!(typed_circuit.n_vars(), 1);
605
606		assert_eq!(
607			typed_circuit
608				.evaluate(&[P::broadcast(F::new(123).into())])
609				.unwrap(),
610			P::broadcast(F::new(123).into())
611		);
612
613		let mut evals = [P::default()];
614		let batch_query = [[P::broadcast(F::new(123).into())]; 1];
615		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
616		typed_circuit
617			.batch_evaluate(&batch_query.get_ref(), &mut evals)
618			.unwrap();
619		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
620	}
621
622	#[test]
623	fn test_add() {
624		type F = BinaryField8b;
625		type P = PackedBinaryField8x16b;
626
627		// 123 + x0
628		let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
629		let circuit = ArithCircuitPoly::<F>::new(expr.into());
630
631		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
632		assert_eq!(typed_circuit.binary_tower_level(), 3);
633		assert_eq!(typed_circuit.degree(), 1);
634		assert_eq!(typed_circuit.n_vars(), 1);
635
636		assert_eq!(
637			CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
638			P::broadcast(F::new(123).into())
639		);
640	}
641
642	#[test]
643	fn test_mul() {
644		type F = BinaryField8b;
645		type P = PackedBinaryField8x16b;
646
647		// 123 * x0
648		let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
649		let circuit = ArithCircuitPoly::<F>::new(expr.into());
650
651		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
652		assert_eq!(typed_circuit.binary_tower_level(), 3);
653		assert_eq!(typed_circuit.degree(), 1);
654		assert_eq!(typed_circuit.n_vars(), 1);
655
656		assert_eq!(
657			CompositionPoly::evaluate(
658				&circuit,
659				&[P::from_scalars(
660					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
661				)]
662			)
663			.unwrap(),
664			P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
665		);
666	}
667
668	#[test]
669	fn test_pow() {
670		type F = BinaryField8b;
671		type P = PackedBinaryField8x16b;
672
673		// x0^13
674		let expr = ArithExpr::Var(0).pow(13);
675		let circuit = ArithCircuitPoly::<F>::new(expr.into());
676
677		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
678		assert_eq!(typed_circuit.binary_tower_level(), 0);
679		assert_eq!(typed_circuit.degree(), 13);
680		assert_eq!(typed_circuit.n_vars(), 1);
681
682		assert_eq!(
683			CompositionPoly::evaluate(
684				&circuit,
685				&[P::from_scalars(
686					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
687				)]
688			)
689			.unwrap(),
690			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
691		);
692	}
693
694	#[test]
695	fn test_mixed() {
696		type F = BinaryField8b;
697		type P = PackedBinaryField8x16b;
698
699		// x0^2 * (x1 + 123)
700		let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
701		let circuit = ArithCircuitPoly::<F>::new(expr.into());
702
703		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
704		assert_eq!(typed_circuit.binary_tower_level(), 3);
705		assert_eq!(typed_circuit.degree(), 3);
706		assert_eq!(typed_circuit.n_vars(), 2);
707
708		// test evaluate
709		assert_eq!(
710			CompositionPoly::evaluate(
711				&circuit,
712				&[
713					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
714					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
715				]
716			)
717			.unwrap(),
718			P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
719		);
720
721		// test batch evaluate
722		let query1 = &[
723			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
724			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
725		];
726		let query2 = &[
727			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
728			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
729		];
730		let query3 = &[
731			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
732			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
733		];
734		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
735		let expected2 =
736			P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
737		let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));
738
739		let mut batch_result = vec![P::zero(); 3];
740		let batch_query = &[
741			&[query1[0], query2[0], query3[0]],
742			&[query1[1], query2[1], query3[1]],
743		];
744		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
745
746		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
747			.unwrap();
748		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
749	}
750
	#[test]
	fn batch_evaluate_add_mul() {
		// This test exercises the in-place `AddMul` accumulation path of `batch_evaluate`, which is
		// prone to stacked borrows violations; run it under `miri` to detect them.

		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(0))
			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(0)
			- ArithExpr::Var(0);
		let circuit = ArithCircuitPoly::<F>::new(expr.into());

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 0);
		assert_eq!(typed_circuit.degree(), 2);
		assert_eq!(typed_circuit.n_vars(), 1);

		let mut evals = [P::default(); 1];
		let batch_query = [[P::broadcast(F::new(1).into())]; 1];
		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
		typed_circuit
			.batch_evaluate(&batch_query.get_ref(), &mut evals)
			.unwrap();
	}
776
777	#[test]
778	fn test_const_fold() {
779		type F = BinaryField8b;
780		type P = PackedBinaryField8x16b;
781
782		// x0 * ((122 * 123) + (124 + 125)) + x1
783		let expr = ArithExpr::Var(0)
784			* ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
785				+ (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
786			+ ArithExpr::Var(1);
787		let circuit = ArithCircuitPoly::<F>::new(expr.into());
788		assert_eq!(circuit.steps.len(), 2);
789
790		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
791		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
792		assert_eq!(typed_circuit.degree(), 1);
793		assert_eq!(typed_circuit.n_vars(), 2);
794
795		// test evaluate
796		assert_eq!(
797			CompositionPoly::evaluate(
798				&circuit,
799				&[
800					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
801					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
802				]
803			)
804			.unwrap(),
805			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
806		);
807
808		// test batch evaluate
809		let query1 = &[
810			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
811			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
812		];
813		let query2 = &[
814			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
815			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
816		];
817		let query3 = &[
818			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
819			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
820		];
821		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
822		let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
823		let expected3 =
824			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));
825
826		let mut batch_result = vec![P::zero(); 3];
827		let batch_query = &[
828			&[query1[0], query2[0], query3[0]],
829			&[query1[1], query2[1], query3[1]],
830		];
831		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
832		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
833			.unwrap();
834		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
835	}
836
	#[test]
	fn test_pow_const_fold() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// x0 + 2^4
		let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
		let circuit = ArithCircuitPoly::<F>::new(expr.into());
		// The constant power is folded, leaving a single addition step.
		assert_eq!(circuit.steps.len(), 1);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 1);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
		);
	}
863
864	#[test]
865	fn test_pow_nested() {
866		type F = BinaryField8b;
867		type P = PackedBinaryField8x16b;
868
869		// ((x0^2)^3)^4
870		let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
871		let circuit = ArithCircuitPoly::<F>::new(expr.into());
872		assert_eq!(circuit.steps.len(), 5);
873
874		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
875		assert_eq!(typed_circuit.binary_tower_level(), 0);
876		assert_eq!(typed_circuit.degree(), 24);
877		assert_eq!(typed_circuit.n_vars(), 1);
878
879		assert_eq!(
880			CompositionPoly::evaluate(
881				&circuit,
882				&[P::from_scalars(
883					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
884				)]
885			)
886			.unwrap(),
887			P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
888		);
889	}
890
891	#[test]
892	fn test_circuit_steps_for_expr_constant() {
893		type F = BinaryField8b;
894
895		let expr = ArithExpr::Const(F::new(5));
896		let (steps, retval) = convert_circuit_steps(&expr.into());
897
898		assert!(steps.is_empty(), "No steps should be generated for a constant");
899		assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
900	}
901
902	#[test]
903	fn test_circuit_steps_for_expr_variable() {
904		type F = BinaryField8b;
905
906		let expr = ArithExpr::<F>::Var(18);
907		let (steps, retval) = convert_circuit_steps(&expr.into());
908
909		assert!(steps.is_empty(), "No steps should be generated for a variable");
910		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
911	}
912
913	#[test]
914	fn test_circuit_steps_for_expr_addition() {
915		type F = BinaryField8b;
916
917		let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
918		let (steps, retval) = convert_circuit_steps(&expr.into());
919
920		assert_eq!(steps.len(), 1, "One addition step should be generated");
921		assert!(matches!(
922			steps[0],
923			CircuitStep::Add(
924				CircuitStepArgument::Expr(CircuitNode::Var(14)),
925				CircuitStepArgument::Expr(CircuitNode::Var(56))
926			)
927		));
928		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
929	}
930
931	#[test]
932	fn test_circuit_steps_for_expr_multiplication() {
933		type F = BinaryField8b;
934
935		let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
936		let (steps, retval) = convert_circuit_steps(&expr.into());
937
938		assert_eq!(steps.len(), 1, "One multiplication step should be generated");
939		assert!(matches!(
940			steps[0],
941			CircuitStep::Mul(
942				CircuitStepArgument::Expr(CircuitNode::Var(36)),
943				CircuitStepArgument::Expr(CircuitNode::Var(26))
944			)
945		));
946		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
947	}
948
949	#[test]
950	fn test_circuit_steps_for_expr_pow_1() {
951		type F = BinaryField8b;
952
953		let expr = ArithExpr::<F>::Var(12).pow(1);
954		let (steps, retval) = convert_circuit_steps(&expr.into());
955
956		// No steps should be generated for x^1
957		assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");
958
959		// The return value should just be the variable itself
960		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
961	}
962
963	#[test]
964	fn test_circuit_steps_for_expr_pow_2() {
965		type F = BinaryField8b;
966
967		let expr = ArithExpr::<F>::Var(10).pow(2);
968		let (steps, retval) = convert_circuit_steps(&expr.into());
969
970		assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
971		assert!(matches!(
972			steps[0],
973			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
974		));
975		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
976	}
977
978	#[test]
979	fn test_circuit_steps_for_expr_pow_3() {
980		type F = BinaryField8b;
981
982		let expr = ArithExpr::<F>::Var(5).pow(3);
983		let (steps, retval) = convert_circuit_steps(&expr.into());
984
985		assert_eq!(
986			steps.len(),
987			2,
988			"Pow(3) should generate one squaring and one multiplication step"
989		);
990		assert!(matches!(
991			steps[0],
992			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
993		));
994		assert!(matches!(
995			steps[1],
996			CircuitStep::Mul(
997				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
998				CircuitStepArgument::Expr(CircuitNode::Var(5))
999			)
1000		));
1001		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1002	}
1003
1004	#[test]
1005	fn test_circuit_steps_for_expr_pow_4() {
1006		type F = BinaryField8b;
1007
1008		let expr = ArithExpr::<F>::Var(7).pow(4);
1009		let (steps, retval) = convert_circuit_steps(&expr.into());
1010
1011		assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
1012		assert!(matches!(
1013			steps[0],
1014			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1015		));
1016
1017		assert!(matches!(
1018			steps[1],
1019			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1020		));
1021
1022		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1023	}
1024
1025	#[test]
1026	fn test_circuit_steps_for_expr_pow_5() {
1027		type F = BinaryField8b;
1028
1029		let expr = ArithExpr::<F>::Var(3).pow(5);
1030		let (steps, retval) = convert_circuit_steps(&expr.into());
1031
1032		assert_eq!(
1033			steps.len(),
1034			3,
1035			"Pow(5) should generate two squaring steps and one multiplication"
1036		);
1037		assert!(matches!(
1038			steps[0],
1039			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
1040		));
1041		assert!(matches!(
1042			steps[1],
1043			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1044		));
1045		assert!(matches!(
1046			steps[2],
1047			CircuitStep::Mul(
1048				CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1049				CircuitStepArgument::Expr(CircuitNode::Var(3))
1050			)
1051		));
1052
1053		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1054	}
1055
1056	#[test]
1057	fn test_circuit_steps_for_expr_pow_8() {
1058		type F = BinaryField8b;
1059
1060		let expr = ArithExpr::<F>::Var(4).pow(8);
1061		let (steps, retval) = convert_circuit_steps(&expr.into());
1062
1063		assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
1064		assert!(matches!(
1065			steps[0],
1066			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
1067		));
1068		assert!(matches!(
1069			steps[1],
1070			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1071		));
1072		assert!(matches!(
1073			steps[2],
1074			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1075		));
1076
1077		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1078	}
1079
1080	#[test]
1081	fn test_circuit_steps_for_expr_pow_9() {
1082		type F = BinaryField8b;
1083
1084		let expr = ArithExpr::<F>::Var(8).pow(9);
1085		let (steps, retval) = convert_circuit_steps(&expr.into());
1086
1087		assert_eq!(
1088			steps.len(),
1089			4,
1090			"Pow(9) should generate three squaring steps and one multiplication"
1091		);
1092		assert!(matches!(
1093			steps[0],
1094			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
1095		));
1096		assert!(matches!(
1097			steps[1],
1098			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1099		));
1100		assert!(matches!(
1101			steps[2],
1102			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1103		));
1104		assert!(matches!(
1105			steps[3],
1106			CircuitStep::Mul(
1107				CircuitStepArgument::Expr(CircuitNode::Slot(2)),
1108				CircuitStepArgument::Expr(CircuitNode::Var(8))
1109			)
1110		));
1111
1112		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1113	}
1114
1115	#[test]
1116	fn test_circuit_steps_for_expr_pow_12() {
1117		type F = BinaryField8b;
1118		let expr = ArithExpr::<F>::Var(6).pow(12);
1119		let (steps, retval) = convert_circuit_steps(&expr.into());
1120
1121		assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");
1122
1123		assert!(matches!(
1124			steps[0],
1125			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
1126		));
1127		assert!(matches!(
1128			steps[1],
1129			CircuitStep::Mul(
1130				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1131				CircuitStepArgument::Expr(CircuitNode::Var(6))
1132			)
1133		));
1134		assert!(matches!(
1135			steps[2],
1136			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1137		));
1138		assert!(matches!(
1139			steps[3],
1140			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1141		));
1142
1143		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1144	}
1145
1146	#[test]
1147	fn test_circuit_steps_for_expr_pow_13() {
1148		type F = BinaryField8b;
1149		let expr = ArithExpr::<F>::Var(7).pow(13);
1150		let (steps, retval) = convert_circuit_steps(&expr.into());
1151
1152		assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
1153		assert!(matches!(
1154			steps[0],
1155			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1156		));
1157		assert!(matches!(
1158			steps[1],
1159			CircuitStep::Mul(
1160				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1161				CircuitStepArgument::Expr(CircuitNode::Var(7))
1162			)
1163		));
1164		assert!(matches!(
1165			steps[2],
1166			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1167		));
1168		assert!(matches!(
1169			steps[3],
1170			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1171		));
1172		assert!(matches!(
1173			steps[4],
1174			CircuitStep::Mul(
1175				CircuitStepArgument::Expr(CircuitNode::Slot(3)),
1176				CircuitStepArgument::Expr(CircuitNode::Var(7))
1177			)
1178		));
1179		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
1180	}
1181
1182	#[test]
1183	fn test_circuit_steps_for_expr_complex() {
1184		type F = BinaryField8b;
1185
1186		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1187			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
1188			- ArithExpr::Var(3);
1189
1190		let (steps, retval) = convert_circuit_steps(&expr.into());
1191
1192		assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");
1193
1194		assert!(
1195			matches!(
1196				steps[0],
1197				CircuitStep::Mul(
1198					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1199					CircuitStepArgument::Expr(CircuitNode::Var(1))
1200				)
1201			),
1202			"First step should be multiplication x0 * x1"
1203		);
1204
1205		assert!(
1206			matches!(
1207				steps[1],
1208				CircuitStep::Add(
1209					CircuitStepArgument::Const(F::ONE),
1210					CircuitStepArgument::Expr(CircuitNode::Var(0))
1211				)
1212			),
1213			"Second step should be (1 - x0)"
1214		);
1215
1216		assert!(
1217			matches!(
1218				steps[2],
1219				CircuitStep::AddMul(
1220					0,
1221					CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1222					CircuitStepArgument::Expr(CircuitNode::Var(2))
1223				)
1224			),
1225			"Third step should be (1 - x0) * x2"
1226		);
1227
1228		assert!(
1229			matches!(
1230				steps[3],
1231				CircuitStep::Add(
1232					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1233					CircuitStepArgument::Expr(CircuitNode::Var(3))
1234				)
1235			),
1236			"Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
1237		);
1238
1239		assert!(
1240			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
1241			"Final result should be stored in Slot(3)"
1242		);
1243	}
1244
1245	#[test]
1246	fn check_deduplication_in_steps() {
1247		type F = BinaryField8b;
1248
1249		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1250			+ (ArithExpr::<F>::Var(0) * ArithExpr::Var(1)) * ArithExpr::Var(2)
1251			- ArithExpr::Var(3);
1252		let expr = ArithCircuit::from(&expr);
1253		let expr = expr.optimize();
1254
1255		let (steps, retval) = convert_circuit_steps(&expr);
1256
1257		assert_eq!(steps.len(), 3, "Expression should generate 3 computation steps");
1258
1259		assert!(
1260			matches!(
1261				steps[0],
1262				CircuitStep::Mul(
1263					CircuitStepArgument::Expr(CircuitNode::Var(0)),
1264					CircuitStepArgument::Expr(CircuitNode::Var(1))
1265				)
1266			),
1267			"First step should be multiplication x0 * x1"
1268		);
1269
1270		assert!(
1271			matches!(
1272				steps[1],
1273				CircuitStep::AddMul(
1274					0,
1275					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1276					CircuitStepArgument::Expr(CircuitNode::Var(2))
1277				)
1278			),
1279			"Second step should be (x0 * x1) * x2"
1280		);
1281
1282		assert!(
1283			matches!(
1284				steps[2],
1285				CircuitStep::Add(
1286					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1287					CircuitStepArgument::Expr(CircuitNode::Var(3))
1288				)
1289			),
1290			"Third step should be x0 * x1 + (x0 * x1) * x2 + x3"
1291		);
1292
1293		assert!(
1294			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))),
1295			"Final result should be stored in Slot(2)"
1296		);
1297	}
1298}