1use std::{fmt::Debug, mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithExpr, CompositionPoly, Error};
7use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
8use stackalloc::{
9 helpers::{slice_assume_init, slice_assume_init_mut},
10 stackalloc_uninit,
11};
12
13fn circuit_steps_for_expr<F: Field>(
15 expr: &ArithExpr<F>,
16) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
17 let mut steps = Vec::new();
18
19 fn to_circuit_inner<F: Field>(
20 expr: &ArithExpr<F>,
21 result: &mut Vec<CircuitStep<F>>,
22 ) -> CircuitStepArgument<F> {
23 match expr {
24 ArithExpr::Const(value) => CircuitStepArgument::Const(*value),
25 ArithExpr::Var(index) => CircuitStepArgument::Expr(CircuitNode::Var(*index)),
26 ArithExpr::Add(left, right) => match &**right {
27 ArithExpr::Mul(mleft, mright) if left.is_composite() => {
28 let left = to_circuit_inner(left, result);
31 let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left else {
32 unreachable!("guaranteed by `is_composite` check above")
33 };
34 let mleft = to_circuit_inner(mleft, result);
35 let mright = to_circuit_inner(mright, result);
36 result.push(CircuitStep::AddMul(left, mleft, mright));
37 CircuitStepArgument::Expr(CircuitNode::Slot(left))
38 }
39 _ => {
40 let left = to_circuit_inner(left, result);
41 let right = to_circuit_inner(right, result);
42 result.push(CircuitStep::Add(left, right));
43 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
44 }
45 },
46 ArithExpr::Mul(left, right) => {
47 let left = to_circuit_inner(left, result);
48 let right = to_circuit_inner(right, result);
49 result.push(CircuitStep::Mul(left, right));
50 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
51 }
52 ArithExpr::Pow(base, exp) => {
53 let mut acc = to_circuit_inner(base, result);
54 let base_expr = acc;
55 let highest_bit = exp.ilog2();
56
57 for i in (0..highest_bit).rev() {
58 result.push(CircuitStep::Square(acc));
59 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
60
61 if (exp >> i) & 1 != 0 {
62 result.push(CircuitStep::Mul(acc, base_expr));
63 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
64 }
65 }
66
67 acc
68 }
69 }
70 }
71
72 let ret = to_circuit_inner(&expr.optimize(), &mut steps);
73 (steps, ret)
74}
75
/// Reference to a value available while evaluating a compiled circuit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitNode {
	/// Input variable with the given index into the evaluation query.
	Var(usize),
	/// Intermediate result held in the evaluation slot with the given index; slot `i`
	/// corresponds to the `i`-th step of the circuit.
	Slot(usize),
}
84
85impl CircuitNode {
86 fn get_sparse_chunk<'a, P: PackedField>(
89 &self,
90 inputs: &[&'a [P]],
91 evals: &'a [P],
92 row_len: usize,
93 ) -> &'a [P] {
94 match self {
95 Self::Var(index) => inputs[*index],
96 Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
97 }
98 }
99}
100
/// Operand of a circuit step: either a node (input variable or earlier slot) or a
/// constant that is broadcast to every packed lane at evaluation time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitStepArgument<F> {
	/// Value of an input variable or of a previously computed slot.
	Expr(CircuitNode),
	/// Field constant.
	Const(F),
}
106
/// One operation of a compiled circuit. Steps are executed in order.
#[derive(Debug)]
enum CircuitStep<F: Field> {
	/// `lhs + rhs`.
	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// `lhs * rhs`.
	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// `arg^2`.
	Square(CircuitStepArgument<F>),
	/// `slot[target] += lhs * rhs`: accumulates in place into an earlier slot; the value
	/// of the expression is then read from `target`, not from this step's own slot.
	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
}
118
/// An arithmetic expression compiled into a flat sequence of circuit steps.
///
/// The source expression is kept alongside the compiled steps; it is what equality,
/// serialization and `CompositionPoly::expression` operate on.
#[derive(Debug, Clone)]
pub struct ArithCircuitPoly<F: Field> {
	/// Expression the circuit was compiled from.
	expr: ArithExpr<F>,
	/// Compiled steps, executed in order; shared behind an `Arc` so clones are cheap.
	steps: Arc<[CircuitStep<F>]>,
	/// Location of the final result: a slot, an input variable, or a constant.
	retval: CircuitStepArgument<F>,
	/// Cached `expr.degree()`.
	degree: usize,
	/// Number of inputs the circuit expects; at least `expr.n_vars()`.
	n_vars: usize,
	/// Cached `expr.binary_tower_level()`.
	tower_level: usize,
}
135
136impl<F: Field> PartialEq for ArithCircuitPoly<F> {
137 fn eq(&self, other: &Self) -> bool {
138 self.n_vars == other.n_vars && self.expr == other.expr
139 }
140}
141
142impl<F: Field> Eq for ArithCircuitPoly<F> {}
143
144impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
145 fn serialize(
146 &self,
147 mut write_buf: impl bytes::BufMut,
148 mode: SerializationMode,
149 ) -> Result<(), SerializationError> {
150 (&self.expr, self.n_vars).serialize(&mut write_buf, mode)
151 }
152}
153
154impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
155 fn deserialize(
156 read_buf: impl bytes::Buf,
157 mode: SerializationMode,
158 ) -> Result<Self, SerializationError>
159 where
160 Self: Sized,
161 {
162 let (expr, n_vars) = <(ArithExpr<F>, usize)>::deserialize(read_buf, mode)?;
163 Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
164 name: "ArithCircuitPoly",
165 })
166 }
167}
168
169impl<F: TowerField> ArithCircuitPoly<F> {
170 pub fn new(expr: ArithExpr<F>) -> Self {
171 let degree = expr.degree();
172 let n_vars = expr.n_vars();
173 let tower_level = expr.binary_tower_level();
174 let (exprs, retval) = circuit_steps_for_expr(&expr);
175
176 Self {
177 expr,
178 steps: exprs.into(),
179 retval,
180 degree,
181 n_vars,
182 tower_level,
183 }
184 }
185
186 pub fn with_n_vars(n_vars: usize, expr: ArithExpr<F>) -> Result<Self, Error> {
191 let degree = expr.degree();
192 let tower_level = expr.binary_tower_level();
193 if n_vars < expr.n_vars() {
194 return Err(Error::IncorrectNumberOfVariables {
195 expected: expr.n_vars(),
196 actual: n_vars,
197 });
198 }
199 let (exprs, retval) = circuit_steps_for_expr(&expr);
200
201 Ok(Self {
202 expr,
203 steps: exprs.into(),
204 retval,
205 n_vars,
206 degree,
207 tower_level,
208 })
209 }
210}
211
212impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
213 for ArithCircuitPoly<F>
214{
215 fn degree(&self) -> usize {
216 self.degree
217 }
218
219 fn n_vars(&self) -> usize {
220 self.n_vars
221 }
222
223 fn binary_tower_level(&self) -> usize {
224 self.tower_level
225 }
226
227 fn expression(&self) -> ArithExpr<P::Scalar> {
228 self.expr.convert_field()
229 }
230
231 fn evaluate(&self, query: &[P]) -> Result<P, Error> {
232 if query.len() != self.n_vars {
233 return Err(Error::IncorrectQuerySize {
234 expected: self.n_vars,
235 });
236 }
237
238 fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
239 unsafe {
242 target.get_unchecked_mut(0).write(value);
243 }
244 }
245
246 stackalloc_uninit::<P, _, _>(self.steps.len().max(1), |evals| {
248 let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
249 CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
251 *query.get_unchecked(index)
252 },
253 CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
255 *evals.get_unchecked(slot)
256 },
257 CircuitStepArgument::Const(value) => P::broadcast(value.into()),
258 };
259
260 for (i, expr) in self.steps.iter().enumerate() {
261 let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
263 let before = unsafe { slice_assume_init_mut(before) };
264 match expr {
265 CircuitStep::Add(x, y) => write_result(
266 after,
267 get_argument_value(*x, before) + get_argument_value(*y, before),
268 ),
269 CircuitStep::AddMul(target_slot, x, y) => {
270 let intermediate =
271 get_argument_value(*x, before) * get_argument_value(*y, before);
272 let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
274 *target_slot += intermediate;
275 }
276 CircuitStep::Mul(x, y) => write_result(
277 after,
278 get_argument_value(*x, before) * get_argument_value(*y, before),
279 ),
280 CircuitStep::Square(x) => {
281 write_result(after, get_argument_value(*x, before).square())
282 }
283 };
284 }
285
286 unsafe {
289 let evals = slice_assume_init(evals);
290 Ok(get_argument_value(self.retval, evals))
291 }
292 })
293 }
294
295 fn batch_evaluate(&self, batch_query: &[&[P]], evals: &mut [P]) -> Result<(), Error> {
296 let row_len = evals.len();
297 if batch_query.iter().any(|row| row.len() != row_len) {
298 return Err(Error::BatchEvaluateSizeMismatch);
299 }
300
301 stackalloc_uninit::<P, (), _>((self.steps.len() * row_len).max(1), |sparse_evals| {
303 for (i, expr) in self.steps.iter().enumerate() {
304 let (before, current) = sparse_evals.split_at_mut(i * row_len);
305
306 let before = unsafe { slice_assume_init_mut(before) };
308 let current = &mut current[..row_len];
309
310 match expr {
311 CircuitStep::Add(left, right) => {
312 apply_binary_op(
313 left,
314 right,
315 batch_query,
316 before,
317 current,
318 |left, right, out| {
319 out.write(left + right);
320 },
321 );
322 }
323 CircuitStep::Mul(left, right) => {
324 apply_binary_op(
325 left,
326 right,
327 batch_query,
328 before,
329 current,
330 |left, right, out| {
331 out.write(left * right);
332 },
333 );
334 }
335 CircuitStep::Square(arg) => {
336 match arg {
337 CircuitStepArgument::Expr(node) => {
338 let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
339 for j in 0..row_len {
340 unsafe {
342 current
343 .get_unchecked_mut(j)
344 .write(id_chunk.get_unchecked(j).square());
345 }
346 }
347 }
348 CircuitStepArgument::Const(value) => {
349 let value: P = P::broadcast((*value).into());
350 let result = value.square();
351 for j in 0..row_len {
352 unsafe {
354 current.get_unchecked_mut(j).write(result);
355 }
356 }
357 }
358 }
359 }
360 CircuitStep::AddMul(target, left, right) => {
361 let target = &before[row_len * target..(target + 1) * row_len];
362 let target: &mut [MaybeUninit<P>] = unsafe {
365 std::slice::from_raw_parts_mut(
366 target.as_ptr() as *mut MaybeUninit<P>,
367 target.len(),
368 )
369 };
370 apply_binary_op(
371 left,
372 right,
373 batch_query,
374 before,
375 target,
376 |left, right, out| unsafe {
379 let out = out.assume_init_mut();
380 *out += left * right;
381 },
382 );
383 }
384 }
385 }
386
387 match self.retval {
388 CircuitStepArgument::Expr(node) => {
389 let sparse_evals = unsafe { slice_assume_init(sparse_evals) };
391 evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
392 }
393 CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
394 }
395 });
396
397 Ok(())
398 }
399}
400
401fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
404 left: &CircuitStepArgument<F>,
405 right: &CircuitStepArgument<F>,
406 batch_query: &[&[P]],
407 evals_before: &[P],
408 current_evals: &mut [MaybeUninit<P>],
409 op: impl Fn(P, P, &mut MaybeUninit<P>),
410) {
411 let row_len = current_evals.len();
412
413 match (left, right) {
414 (CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
415 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
416 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
417 for j in 0..row_len {
418 unsafe {
420 op(
421 *left.get_unchecked(j),
422 *right.get_unchecked(j),
423 current_evals.get_unchecked_mut(j),
424 )
425 }
426 }
427 }
428 (CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
429 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
430 let right = P::broadcast((*right).into());
431 for j in 0..row_len {
432 unsafe {
434 op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
435 }
436 }
437 }
438 (CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
439 let left = P::broadcast((*left).into());
440 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
441 for j in 0..row_len {
442 unsafe {
444 op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
445 }
446 }
447 }
448 (CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
449 let left = P::broadcast((*left).into());
450 let right = P::broadcast((*right).into());
451 let mut result = MaybeUninit::uninit();
452 op(left, right, &mut result);
453 for j in 0..row_len {
454 unsafe {
458 current_evals
459 .get_unchecked_mut(j)
460 .write(result.assume_init());
461 }
462 }
463 }
464 }
465}
466
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField,
	};
	use binius_math::CompositionPoly;
	use binius_utils::felts;

	use super::*;

	/// A constant expression compiles to no steps and evaluates to the broadcast constant.
	#[test]
	fn test_constant() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Const(F::new(123));
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
		assert_eq!(typed_circuit.degree(), 0);
		assert_eq!(typed_circuit.n_vars(), 0);

		assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));

		let mut evals = [P::default()];
		typed_circuit.batch_evaluate(&[], &mut evals).unwrap();
		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
	}

	/// A single variable evaluates to its input in both evaluation paths.
	#[test]
	fn test_identity() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Var(0);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 0);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			typed_circuit
				.evaluate(&[P::broadcast(F::new(123).into())])
				.unwrap(),
			P::broadcast(F::new(123).into())
		);

		let mut evals = [P::default()];
		typed_circuit
			.batch_evaluate(&[&[P::broadcast(F::new(123).into())]], &mut evals)
			.unwrap();
		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
	}

	#[test]
	fn test_add() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 3);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
			P::broadcast(F::new(123).into())
		);
	}

	#[test]
	fn test_mul() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 3);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
		);
	}

	#[test]
	fn test_pow() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Var(0).pow(13);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 0);
		assert_eq!(typed_circuit.degree(), 13);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
		);
	}

	/// Mixed square / add / multiply, checked through both `evaluate` and `batch_evaluate`.
	#[test]
	fn test_mixed() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 3);
		assert_eq!(typed_circuit.degree(), 3);
		assert_eq!(typed_circuit.n_vars(), 2);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[
					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
				]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
		);

		let query1 = &[
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
		];
		let query2 = &[
			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
		];
		let query3 = &[
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
		];
		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
		let expected2 =
			P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
		let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));

		let mut batch_result = vec![P::zero(); 3];
		CompositionPoly::batch_evaluate(
			&circuit,
			&[
				&[query1[0], query2[0], query3[0]],
				&[query1[1], query2[1], query3[1]],
			],
			&mut batch_result,
		)
		.unwrap();
		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
	}

	/// Constant sub-expressions are folded at compile time, leaving only two steps.
	#[test]
	fn test_const_fold() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Var(0)
			* ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
				+ (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
			+ ArithExpr::Var(1);
		let circuit = ArithCircuitPoly::<F>::new(expr);
		assert_eq!(circuit.steps.len(), 2);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 2);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[
					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
				]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
		);

		let query1 = &[
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
		];
		let query2 = &[
			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
		];
		let query3 = &[
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
		];
		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
		let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
		let expected3 =
			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));

		let mut batch_result = vec![P::zero(); 3];
		CompositionPoly::batch_evaluate(
			&circuit,
			&[
				&[query1[0], query2[0], query3[0]],
				&[query1[1], query2[1], query3[1]],
			],
			&mut batch_result,
		)
		.unwrap();
		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
	}

	/// A constant raised to a power is folded, leaving only the addition step.
	#[test]
	fn test_pow_const_fold() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
		let circuit = ArithCircuitPoly::<F>::new(expr);
		assert_eq!(circuit.steps.len(), 1);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 1);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
		);
	}

	/// Nested powers collapse into a single power (x^24) before lowering.
	#[test]
	fn test_pow_nested() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
		let circuit = ArithCircuitPoly::<F>::new(expr);
		assert_eq!(circuit.steps.len(), 5);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 0);
		assert_eq!(typed_circuit.degree(), 24);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
		);
	}

	/// Exercises the fused `AddMul` step through both `evaluate` and `batch_evaluate`.
	///
	/// With `x0` fixed to 0 or 1, `x0*x1 + (1 - x0)*x2 - x3` reduces to `x2 + x3` or
	/// `x1 + x3`. Addition in a binary field is bitwise XOR, so the expected values can be
	/// written down directly.
	#[test]
	fn test_add_mul_evaluation() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
			- ArithExpr::Var(3);
		let circuit = ArithCircuitPoly::<F>::new(expr);
		assert!(
			circuit
				.steps
				.iter()
				.any(|step| matches!(step, CircuitStep::AddMul(..))),
			"circuit should contain a fused AddMul step"
		);

		let zeros = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
		let ones = P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1]));
		let x1 = P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7]));
		let x2 = P::from_scalars(felts!(BinaryField16b[8, 9, 10, 11, 12, 13, 14, 15]));
		let x3 = P::from_scalars(felts!(BinaryField16b[3, 3, 3, 3, 3, 3, 3, 3]));

		// x0 = 1: result is x1 + x3.
		let expected_ones = P::from_scalars(felts!(BinaryField16b[3, 2, 1, 0, 7, 6, 5, 4]));
		// x0 = 0: result is x2 + x3.
		let expected_zeros =
			P::from_scalars(felts!(BinaryField16b[11, 10, 9, 8, 15, 14, 13, 12]));

		assert_eq!(
			CompositionPoly::evaluate(&circuit, &[ones, x1, x2, x3]).unwrap(),
			expected_ones
		);
		assert_eq!(
			CompositionPoly::evaluate(&circuit, &[zeros, x1, x2, x3]).unwrap(),
			expected_zeros
		);

		let mut batch_result = vec![P::zero(); 2];
		CompositionPoly::batch_evaluate(
			&circuit,
			&[&[ones, zeros], &[x1, x1], &[x2, x2], &[x3, x3]],
			&mut batch_result,
		)
		.unwrap();
		assert_eq!(&batch_result, &[expected_ones, expected_zeros]);
	}

	// The tests below check the exact step sequence produced by `circuit_steps_for_expr`.

	#[test]
	fn test_circuit_steps_for_expr_constant() {
		type F = BinaryField8b;

		let expr = ArithExpr::Const(F::new(5));
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert!(steps.is_empty(), "No steps should be generated for a constant");
		assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
	}

	#[test]
	fn test_circuit_steps_for_expr_variable() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(18);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert!(steps.is_empty(), "No steps should be generated for a variable");
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
	}

	#[test]
	fn test_circuit_steps_for_expr_addition() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 1, "One addition step should be generated");
		assert!(matches!(
			steps[0],
			CircuitStep::Add(
				CircuitStepArgument::Expr(CircuitNode::Var(14)),
				CircuitStepArgument::Expr(CircuitNode::Var(56))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
	}

	#[test]
	fn test_circuit_steps_for_expr_multiplication() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 1, "One multiplication step should be generated");
		assert!(matches!(
			steps[0],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Var(36)),
				CircuitStepArgument::Expr(CircuitNode::Var(26))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_1() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(12).pow(1);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_2() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(10).pow(2);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_3() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(5).pow(3);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(
			steps.len(),
			2,
			"Pow(3) should generate one squaring and one multiplication step"
		);
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
				CircuitStepArgument::Expr(CircuitNode::Var(5))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_4() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(7).pow(4);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
		));

		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_5() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(3).pow(5);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(
			steps.len(),
			3,
			"Pow(5) should generate two squaring steps and one multiplication"
		);
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(1)),
				CircuitStepArgument::Expr(CircuitNode::Var(3))
			)
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_8() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(4).pow(8);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_9() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(8).pow(9);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(
			steps.len(),
			4,
			"Pow(9) should generate three squaring steps and one multiplication"
		);
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));
		assert!(matches!(
			steps[3],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(2)),
				CircuitStepArgument::Expr(CircuitNode::Var(8))
			)
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_12() {
		type F = BinaryField8b;
		let expr = ArithExpr::<F>::Var(6).pow(12);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");

		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
				CircuitStepArgument::Expr(CircuitNode::Var(6))
			)
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));
		assert!(matches!(
			steps[3],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
	}

	#[test]
	fn test_circuit_steps_for_expr_pow_13() {
		type F = BinaryField8b;
		let expr = ArithExpr::<F>::Var(7).pow(13);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
				CircuitStepArgument::Expr(CircuitNode::Var(7))
			)
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));
		assert!(matches!(
			steps[3],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
		));
		assert!(matches!(
			steps[4],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(3)),
				CircuitStepArgument::Expr(CircuitNode::Var(7))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
	}

	/// `a*b + c*d - e` lowers to Mul, Add (for `1 - x0`), a fused AddMul into slot 0,
	/// and a final Add.
	#[test]
	fn test_circuit_steps_for_expr_complex() {
		type F = BinaryField8b;

		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
			- ArithExpr::Var(3);

		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");

		assert!(
			matches!(
				steps[0],
				CircuitStep::Mul(
					CircuitStepArgument::Expr(CircuitNode::Var(0)),
					CircuitStepArgument::Expr(CircuitNode::Var(1))
				)
			),
			"First step should be multiplication x0 * x1"
		);

		assert!(
			matches!(
				steps[1],
				CircuitStep::Add(
					CircuitStepArgument::Const(F::ONE),
					CircuitStepArgument::Expr(CircuitNode::Var(0))
				)
			),
			"Second step should be (1 - x0)"
		);

		assert!(
			matches!(
				steps[2],
				CircuitStep::AddMul(
					0,
					CircuitStepArgument::Expr(CircuitNode::Slot(1)),
					CircuitStepArgument::Expr(CircuitNode::Var(2))
				)
			),
			"Third step should be (1 - x0) * x2"
		);

		assert!(
			matches!(
				steps[3],
				CircuitStep::Add(
					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
					CircuitStepArgument::Expr(CircuitNode::Var(3))
				)
			),
			"Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
		);

		assert!(
			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
			"Final result should be stored in Slot(3)"
		);
	}
}