1use std::{mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithCircuit, ArithCircuitStep, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{
8 DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail,
9 mem::{slice_assume_init_mut, slice_assume_init_ref},
10};
11
12fn convert_circuit_steps<F: Field>(
14 expr: &ArithCircuit<F>,
15) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
16 struct StepsMapping {
18 original_to_converted: Vec<Option<usize>>,
19 converted_to_original: Vec<Option<usize>>,
20 }
21
22 impl StepsMapping {
23 fn new(original_size: usize) -> Self {
24 Self {
25 original_to_converted: vec![None; original_size],
26 converted_to_original: vec![None; original_size],
28 }
29 }
30
31 fn register(&mut self, original_step: usize, converted_step: usize) {
32 self.original_to_converted[original_step] = Some(converted_step);
33
34 if converted_step >= self.converted_to_original.len() {
35 self.converted_to_original.resize(converted_step + 1, None);
36 }
37 self.converted_to_original[converted_step] = Some(original_step);
38 }
39
40 fn clear_step(&mut self, converted_step: usize) {
41 if self.converted_to_original.len() <= converted_step {
42 return;
43 }
44
45 if let Some(original_step) = self.converted_to_original[converted_step].take() {
46 self.original_to_converted[original_step] = None;
47 }
48 }
49
50 fn get_converted_step(&self, original_step: usize) -> Option<usize> {
51 self.original_to_converted[original_step]
52 }
53 }
54
55 fn convert_step<F: Field>(
56 original_step: usize,
57 original_steps: &[ArithCircuitStep<F>],
58 result: &mut Vec<CircuitStep<F>>,
59 node_to_step: &mut StepsMapping,
60 ) -> CircuitStepArgument<F> {
61 if let Some(converted_step) = node_to_step.get_converted_step(original_step) {
62 return CircuitStepArgument::Expr(CircuitNode::Slot(converted_step));
63 }
64
65 match &original_steps[original_step] {
66 ArithCircuitStep::Const(constant) => CircuitStepArgument::Const(*constant),
67 ArithCircuitStep::Var(var) => CircuitStepArgument::Expr(CircuitNode::Var(*var)),
68 ArithCircuitStep::Add(left, right) => {
69 let left = convert_step(*left, original_steps, result, node_to_step);
70
71 if let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left {
72 if let ArithCircuitStep::Mul(mleft, mright) = &original_steps[*right] {
73 let mleft = convert_step(*mleft, original_steps, result, node_to_step);
77 let mright = convert_step(*mright, original_steps, result, node_to_step);
78
79 node_to_step.clear_step(left);
82 node_to_step.register(original_step, left);
83 result.push(CircuitStep::AddMul(left, mleft, mright));
84
85 return CircuitStepArgument::Expr(CircuitNode::Slot(left));
86 }
87 }
88
89 let right = convert_step(*right, original_steps, result, node_to_step);
90
91 node_to_step.register(original_step, result.len());
92 result.push(CircuitStep::Add(left, right));
93 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
94 }
95 ArithCircuitStep::Mul(left, right) => {
96 let left = convert_step(*left, original_steps, result, node_to_step);
97 let right = convert_step(*right, original_steps, result, node_to_step);
98
99 node_to_step.register(original_step, result.len());
100 result.push(CircuitStep::Mul(left, right));
101 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
102 }
103 ArithCircuitStep::Pow(base, exp) => {
104 let mut acc = convert_step(*base, original_steps, result, node_to_step);
105 let base_expr = acc;
106 let highest_bit = exp.ilog2();
107
108 for i in (0..highest_bit).rev() {
109 if i == 0 {
110 node_to_step.register(original_step, result.len());
111 }
112 result.push(CircuitStep::Square(acc));
113 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
114
115 if (exp >> i) & 1 != 0 {
116 result.push(CircuitStep::Mul(acc, base_expr));
117 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
118 }
119 }
120
121 acc
122 }
123 }
124 }
125
126 let mut steps = Vec::new();
127 let mut steps_mapping = StepsMapping::new(expr.steps().len());
128 let ret = convert_step(expr.steps().len() - 1, expr.steps(), &mut steps, &mut steps_mapping);
129 (steps, ret)
130}
131
/// A reference to a value that a [`CircuitStep`] consumes.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CircuitNode {
	/// The input variable with this index.
	Var(usize),
	/// The scratch slot with this index, holding the result of an earlier step.
	Slot(usize),
}
140
141impl CircuitNode {
142 fn get_sparse_chunk<'a, P: PackedField>(
145 &'a self,
146 inputs: &'a RowsBatchRef<'a, P>,
147 evals: &'a [P],
148 row_len: usize,
149 ) -> &'a [P] {
150 match self {
151 Self::Var(index) => inputs.row(*index),
152 Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
153 }
154 }
155}
156
157#[derive(Debug, Clone, Copy, PartialEq, Eq)]
158enum CircuitStepArgument<F> {
159 Expr(CircuitNode),
160 Const(F),
161}
162
163#[derive(Debug)]
169enum CircuitStep<F: Field> {
170 Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
171 Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
172 Square(CircuitStepArgument<F>),
173 AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
174}
175
176#[derive(Debug, Clone)]
184pub struct ArithCircuitPoly<F: Field> {
185 expr: ArithCircuit<F>,
186 steps: Arc<[CircuitStep<F>]>,
187 retval: CircuitStepArgument<F>,
189 degree: usize,
190 n_vars: usize,
191 tower_level: usize,
192}
193
194impl<F: Field> PartialEq for ArithCircuitPoly<F> {
195 fn eq(&self, other: &Self) -> bool {
196 self.n_vars == other.n_vars && self.expr == other.expr
197 }
198}
199
200impl<F: Field> Eq for ArithCircuitPoly<F> {}
201
202impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
203 fn serialize(
204 &self,
205 mut write_buf: impl bytes::BufMut,
206 mode: SerializationMode,
207 ) -> Result<(), SerializationError> {
208 (&self.expr, self.n_vars).serialize(&mut write_buf, mode)
209 }
210}
211
212impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
213 fn deserialize(
214 read_buf: impl bytes::Buf,
215 mode: SerializationMode,
216 ) -> Result<Self, SerializationError>
217 where
218 Self: Sized,
219 {
220 let (expr, n_vars) = <(ArithCircuit<F>, usize)>::deserialize(read_buf, mode)?;
221 Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
222 name: "ArithCircuitPoly",
223 })
224 }
225}
226
227impl<F: TowerField> ArithCircuitPoly<F> {
228 pub fn new(mut expr: ArithCircuit<F>) -> Self {
229 expr.optimize_in_place();
230
231 let degree = expr.degree();
232 let n_vars = expr.n_vars();
233 let tower_level = expr.binary_tower_level();
234 let (exprs, retval) = convert_circuit_steps(&expr);
235
236 Self {
237 expr,
238 steps: exprs.into(),
239 retval,
240 degree,
241 n_vars,
242 tower_level,
243 }
244 }
245 pub fn with_n_vars(n_vars: usize, mut expr: ArithCircuit<F>) -> Result<Self, Error> {
250 expr.optimize_in_place();
251
252 let degree = expr.degree();
253 let tower_level = expr.binary_tower_level();
254 if n_vars < expr.n_vars() {
255 return Err(Error::IncorrectNumberOfVariables {
256 expected: expr.n_vars(),
257 actual: n_vars,
258 });
259 }
260 let (steps, retval) = convert_circuit_steps(&expr);
261
262 Ok(Self {
263 expr,
264 steps: steps.into(),
265 retval,
266 n_vars,
267 degree,
268 tower_level,
269 })
270 }
271
272 pub fn expr(&self) -> &ArithCircuit<F> {
274 &self.expr
275 }
276}
277
278impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
279 for ArithCircuitPoly<F>
280{
281 fn degree(&self) -> usize {
282 self.degree
283 }
284
285 fn n_vars(&self) -> usize {
286 self.n_vars
287 }
288
289 fn binary_tower_level(&self) -> usize {
290 self.tower_level
291 }
292
293 fn expression(&self) -> ArithCircuit<P::Scalar> {
294 self.expr.convert_field()
295 }
296
297 fn evaluate(&self, query: &[P]) -> Result<P, Error> {
298 if query.len() != self.n_vars {
299 return Err(Error::IncorrectQuerySize {
300 expected: self.n_vars,
301 actual: query.len(),
302 });
303 }
304
305 fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
306 unsafe {
309 target.get_unchecked_mut(0).write(value);
310 }
311 }
312
313 alloc_scratch_space::<P, _, _>(self.steps.len(), |evals| {
314 let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
315 CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
318 *query.get_unchecked(index)
319 },
320 CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
323 *evals.get_unchecked(slot)
324 },
325 CircuitStepArgument::Const(value) => P::broadcast(value.into()),
326 };
327
328 for (i, expr) in self.steps.iter().enumerate() {
329 let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
332 let before = unsafe { slice_assume_init_mut(before) };
333 match expr {
334 CircuitStep::Add(x, y) => write_result(
335 after,
336 get_argument_value(*x, before) + get_argument_value(*y, before),
337 ),
338 CircuitStep::AddMul(target_slot, x, y) => {
339 let intermediate =
340 get_argument_value(*x, before) * get_argument_value(*y, before);
341 let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
344 *target_slot += intermediate;
345 }
346 CircuitStep::Mul(x, y) => write_result(
347 after,
348 get_argument_value(*x, before) * get_argument_value(*y, before),
349 ),
350 CircuitStep::Square(x) => {
351 write_result(after, get_argument_value(*x, before).square())
352 }
353 };
354 }
355
356 unsafe {
359 let evals = slice_assume_init_ref(evals);
360 Ok(get_argument_value(self.retval, evals))
361 }
362 })
363 }
364
365 fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
366 let row_len = evals.len();
367 if batch_query.row_len() != row_len {
368 bail!(Error::BatchEvaluateSizeMismatch {
369 expected: row_len,
370 actual: batch_query.row_len(),
371 });
372 }
373
374 alloc_scratch_space::<P, (), _>(self.steps.len() * row_len, |sparse_evals| {
375 for (i, expr) in self.steps.iter().enumerate() {
376 let (before, current) = sparse_evals.split_at_mut(i * row_len);
377
378 let before = unsafe { slice_assume_init_mut(before) };
381 let current = &mut current[..row_len];
382
383 match expr {
384 CircuitStep::Add(left, right) => {
385 apply_binary_op(
386 left,
387 right,
388 batch_query,
389 before,
390 current,
391 |left, right, out| {
392 out.write(left + right);
393 },
394 );
395 }
396 CircuitStep::Mul(left, right) => {
397 apply_binary_op(
398 left,
399 right,
400 batch_query,
401 before,
402 current,
403 |left, right, out| {
404 out.write(left * right);
405 },
406 );
407 }
408 CircuitStep::Square(arg) => {
409 match arg {
410 CircuitStepArgument::Expr(node) => {
411 let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
412 for j in 0..row_len {
413 unsafe {
416 current
417 .get_unchecked_mut(j)
418 .write(id_chunk.get_unchecked(j).square());
419 }
420 }
421 }
422 CircuitStepArgument::Const(value) => {
423 let value: P = P::broadcast((*value).into());
424 let result = value.square();
425 for j in 0..row_len {
426 unsafe {
428 current.get_unchecked_mut(j).write(result);
429 }
430 }
431 }
432 }
433 }
434 CircuitStep::AddMul(target, left, right) => {
435 let target = &mut before[row_len * target..(target + 1) * row_len];
436 let target: &mut [MaybeUninit<P>] = unsafe {
439 std::slice::from_raw_parts_mut(
440 target.as_mut_ptr() as *mut MaybeUninit<P>,
441 target.len(),
442 )
443 };
444 apply_binary_op(
445 left,
446 right,
447 batch_query,
448 before,
449 target,
450 |left, right, out| unsafe {
453 let out = out.assume_init_mut();
454 *out += left * right;
455 },
456 );
457 }
458 }
459 }
460
461 match self.retval {
462 CircuitStepArgument::Expr(node) => {
463 let sparse_evals = unsafe { slice_assume_init_ref(sparse_evals) };
465 evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
466 }
467 CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
468 }
469 });
470
471 Ok(())
472 }
473}
474
475fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
479 left: &CircuitStepArgument<F>,
480 right: &CircuitStepArgument<F>,
481 batch_query: &RowsBatchRef<P>,
482 evals_before: &[P],
483 current_evals: &mut [MaybeUninit<P>],
484 op: impl Fn(P, P, &mut MaybeUninit<P>),
485) {
486 let row_len = current_evals.len();
487
488 match (left, right) {
489 (CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
490 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
491 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
492 for j in 0..row_len {
493 unsafe {
495 op(
496 *left.get_unchecked(j),
497 *right.get_unchecked(j),
498 current_evals.get_unchecked_mut(j),
499 )
500 }
501 }
502 }
503 (CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
504 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
505 let right = P::broadcast((*right).into());
506 for j in 0..row_len {
507 unsafe {
509 op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
510 }
511 }
512 }
513 (CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
514 let left = P::broadcast((*left).into());
515 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
516 for j in 0..row_len {
517 unsafe {
519 op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
520 }
521 }
522 }
523 (CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
524 let left = P::broadcast((*left).into());
525 let right = P::broadcast((*right).into());
526 let mut result = MaybeUninit::uninit();
527 op(left, right, &mut result);
528 for j in 0..row_len {
529 unsafe {
533 current_evals
534 .get_unchecked_mut(j)
535 .write(result.assume_init());
536 }
537 }
538 }
539 }
540}
541
542fn alloc_scratch_space<T, U, F>(size: usize, callback: F) -> U
543where
544 F: FnOnce(&mut [MaybeUninit<T>]) -> U,
545{
546 use std::mem;
547 assert!(!mem::needs_drop::<T>());
549
550 #[cfg(miri)]
551 {
552 let mut scratch_space = Vec::<T>::with_capacity(size);
553 let out = callback(scratch_space.spare_capacity_mut());
554 drop(scratch_space);
555 out
556 }
557 #[cfg(not(miri))]
558 {
559 let size = size.max(1);
561 stackalloc::stackalloc_uninit(size, callback)
562 }
563}
564
565#[cfg(test)]
566mod tests {
567 use binius_field::{
568 BinaryField8b, BinaryField16b, PackedBinaryField8x16b, PackedField, TowerField,
569 };
570 use binius_math::{ArithExpr, CompositionPoly, RowsBatch};
571 use binius_utils::felts;
572
573 use super::*;
574
575 #[test]
576 fn test_constant() {
577 type F = BinaryField8b;
578 type P = PackedBinaryField8x16b;
579
580 let expr = ArithExpr::Const(F::new(123));
581 let circuit = ArithCircuitPoly::<F>::new(expr.into());
582
583 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
584 assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
585 assert_eq!(typed_circuit.degree(), 0);
586 assert_eq!(typed_circuit.n_vars(), 0);
587
588 assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));
589
590 let mut evals = [P::default()];
591 typed_circuit
592 .batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
593 .unwrap();
594 assert_eq!(evals, [P::broadcast(F::new(123).into())]);
595 }
596
597 #[test]
598 fn test_identity() {
599 type F = BinaryField8b;
600 type P = PackedBinaryField8x16b;
601
602 let expr = ArithExpr::Var(0);
604 let circuit = ArithCircuitPoly::<F>::new(expr.into());
605
606 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
607 assert_eq!(typed_circuit.binary_tower_level(), 0);
608 assert_eq!(typed_circuit.degree(), 1);
609 assert_eq!(typed_circuit.n_vars(), 1);
610
611 assert_eq!(
612 typed_circuit
613 .evaluate(&[P::broadcast(F::new(123).into())])
614 .unwrap(),
615 P::broadcast(F::new(123).into())
616 );
617
618 let mut evals = [P::default()];
619 let batch_query = [[P::broadcast(F::new(123).into())]; 1];
620 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
621 typed_circuit
622 .batch_evaluate(&batch_query.get_ref(), &mut evals)
623 .unwrap();
624 assert_eq!(evals, [P::broadcast(F::new(123).into())]);
625 }
626
627 #[test]
628 fn test_add() {
629 type F = BinaryField8b;
630 type P = PackedBinaryField8x16b;
631
632 let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
634 let circuit = ArithCircuitPoly::<F>::new(expr.into());
635
636 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
637 assert_eq!(typed_circuit.binary_tower_level(), 3);
638 assert_eq!(typed_circuit.degree(), 1);
639 assert_eq!(typed_circuit.n_vars(), 1);
640
641 assert_eq!(
642 CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
643 P::broadcast(F::new(123).into())
644 );
645 }
646
647 #[test]
648 fn test_mul() {
649 type F = BinaryField8b;
650 type P = PackedBinaryField8x16b;
651
652 let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
654 let circuit = ArithCircuitPoly::<F>::new(expr.into());
655
656 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
657 assert_eq!(typed_circuit.binary_tower_level(), 3);
658 assert_eq!(typed_circuit.degree(), 1);
659 assert_eq!(typed_circuit.n_vars(), 1);
660
661 assert_eq!(
662 CompositionPoly::evaluate(
663 &circuit,
664 &[P::from_scalars(
665 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
666 )]
667 )
668 .unwrap(),
669 P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
670 );
671 }
672
673 #[test]
674 fn test_pow() {
675 type F = BinaryField8b;
676 type P = PackedBinaryField8x16b;
677
678 let expr = ArithExpr::Var(0).pow(13);
680 let circuit = ArithCircuitPoly::<F>::new(expr.into());
681
682 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
683 assert_eq!(typed_circuit.binary_tower_level(), 0);
684 assert_eq!(typed_circuit.degree(), 13);
685 assert_eq!(typed_circuit.n_vars(), 1);
686
687 assert_eq!(
688 CompositionPoly::evaluate(
689 &circuit,
690 &[P::from_scalars(
691 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
692 )]
693 )
694 .unwrap(),
695 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
696 );
697 }
698
699 #[test]
700 fn test_mixed() {
701 type F = BinaryField8b;
702 type P = PackedBinaryField8x16b;
703
704 let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
706 let circuit = ArithCircuitPoly::<F>::new(expr.into());
707
708 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
709 assert_eq!(typed_circuit.binary_tower_level(), 3);
710 assert_eq!(typed_circuit.degree(), 3);
711 assert_eq!(typed_circuit.n_vars(), 2);
712
713 assert_eq!(
715 CompositionPoly::evaluate(
716 &circuit,
717 &[
718 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
719 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
720 ]
721 )
722 .unwrap(),
723 P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
724 );
725
726 let query1 = &[
728 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
729 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
730 ];
731 let query2 = &[
732 P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
733 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
734 ];
735 let query3 = &[
736 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
737 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
738 ];
739 let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
740 let expected2 =
741 P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
742 let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));
743
744 let mut batch_result = vec![P::zero(); 3];
745 let batch_query = &[
746 &[query1[0], query2[0], query3[0]],
747 &[query1[1], query2[1], query3[1]],
748 ];
749 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
750
751 CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
752 .unwrap();
753 assert_eq!(&batch_result, &[expected1, expected2, expected3]);
754 }
755
756 #[test]
757 fn batch_evaluate_add_mul() {
758 type F = BinaryField8b;
762 type P = PackedBinaryField8x16b;
763
764 let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(0))
765 + (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(0)
766 - ArithExpr::Var(0);
767 let circuit = ArithCircuitPoly::<F>::new(expr.into());
768
769 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
770 assert_eq!(typed_circuit.binary_tower_level(), 0);
771 assert_eq!(typed_circuit.degree(), 2);
772 assert_eq!(typed_circuit.n_vars(), 1);
773
774 let mut evals = [P::default(); 1];
775 let batch_query = [[P::broadcast(F::new(1).into())]; 1];
776 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
777 typed_circuit
778 .batch_evaluate(&batch_query.get_ref(), &mut evals)
779 .unwrap();
780 }
781
782 #[test]
783 fn test_const_fold() {
784 type F = BinaryField8b;
785 type P = PackedBinaryField8x16b;
786
787 let expr = ArithExpr::Var(0)
789 * ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
790 + (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
791 + ArithExpr::Var(1);
792 let circuit = ArithCircuitPoly::<F>::new(expr.into());
793 assert_eq!(circuit.steps.len(), 2);
794
795 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
796 assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
797 assert_eq!(typed_circuit.degree(), 1);
798 assert_eq!(typed_circuit.n_vars(), 2);
799
800 assert_eq!(
802 CompositionPoly::evaluate(
803 &circuit,
804 &[
805 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
806 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
807 ]
808 )
809 .unwrap(),
810 P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
811 );
812
813 let query1 = &[
815 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
816 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
817 ];
818 let query2 = &[
819 P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
820 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
821 ];
822 let query3 = &[
823 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
824 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
825 ];
826 let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
827 let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
828 let expected3 =
829 P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));
830
831 let mut batch_result = vec![P::zero(); 3];
832 let batch_query = &[
833 &[query1[0], query2[0], query3[0]],
834 &[query1[1], query2[1], query3[1]],
835 ];
836 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
837 CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
838 .unwrap();
839 assert_eq!(&batch_result, &[expected1, expected2, expected3]);
840 }
841
842 #[test]
843 fn test_pow_const_fold() {
844 type F = BinaryField8b;
845 type P = PackedBinaryField8x16b;
846
847 let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
849 let circuit = ArithCircuitPoly::<F>::new(expr.into());
850 assert_eq!(circuit.steps.len(), 1);
851
852 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
853 assert_eq!(typed_circuit.binary_tower_level(), 1);
854 assert_eq!(typed_circuit.degree(), 1);
855 assert_eq!(typed_circuit.n_vars(), 1);
856
857 assert_eq!(
858 CompositionPoly::evaluate(
859 &circuit,
860 &[P::from_scalars(
861 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
862 )]
863 )
864 .unwrap(),
865 P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
866 );
867 }
868
869 #[test]
870 fn test_pow_nested() {
871 type F = BinaryField8b;
872 type P = PackedBinaryField8x16b;
873
874 let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
876 let circuit = ArithCircuitPoly::<F>::new(expr.into());
877 assert_eq!(circuit.steps.len(), 5);
878
879 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
880 assert_eq!(typed_circuit.binary_tower_level(), 0);
881 assert_eq!(typed_circuit.degree(), 24);
882 assert_eq!(typed_circuit.n_vars(), 1);
883
884 assert_eq!(
885 CompositionPoly::evaluate(
886 &circuit,
887 &[P::from_scalars(
888 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
889 )]
890 )
891 .unwrap(),
892 P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
893 );
894 }
895
896 #[test]
897 fn test_circuit_steps_for_expr_constant() {
898 type F = BinaryField8b;
899
900 let expr = ArithExpr::Const(F::new(5));
901 let (steps, retval) = convert_circuit_steps(&expr.into());
902
903 assert!(steps.is_empty(), "No steps should be generated for a constant");
904 assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
905 }
906
907 #[test]
908 fn test_circuit_steps_for_expr_variable() {
909 type F = BinaryField8b;
910
911 let expr = ArithExpr::<F>::Var(18);
912 let (steps, retval) = convert_circuit_steps(&expr.into());
913
914 assert!(steps.is_empty(), "No steps should be generated for a variable");
915 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
916 }
917
918 #[test]
919 fn test_circuit_steps_for_expr_addition() {
920 type F = BinaryField8b;
921
922 let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
923 let (steps, retval) = convert_circuit_steps(&expr.into());
924
925 assert_eq!(steps.len(), 1, "One addition step should be generated");
926 assert!(matches!(
927 steps[0],
928 CircuitStep::Add(
929 CircuitStepArgument::Expr(CircuitNode::Var(14)),
930 CircuitStepArgument::Expr(CircuitNode::Var(56))
931 )
932 ));
933 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
934 }
935
936 #[test]
937 fn test_circuit_steps_for_expr_multiplication() {
938 type F = BinaryField8b;
939
940 let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
941 let (steps, retval) = convert_circuit_steps(&expr.into());
942
943 assert_eq!(steps.len(), 1, "One multiplication step should be generated");
944 assert!(matches!(
945 steps[0],
946 CircuitStep::Mul(
947 CircuitStepArgument::Expr(CircuitNode::Var(36)),
948 CircuitStepArgument::Expr(CircuitNode::Var(26))
949 )
950 ));
951 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
952 }
953
954 #[test]
955 fn test_circuit_steps_for_expr_pow_1() {
956 type F = BinaryField8b;
957
958 let expr = ArithExpr::<F>::Var(12).pow(1);
959 let (steps, retval) = convert_circuit_steps(&expr.into());
960
961 assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");
963
964 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
966 }
967
968 #[test]
969 fn test_circuit_steps_for_expr_pow_2() {
970 type F = BinaryField8b;
971
972 let expr = ArithExpr::<F>::Var(10).pow(2);
973 let (steps, retval) = convert_circuit_steps(&expr.into());
974
975 assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
976 assert!(matches!(
977 steps[0],
978 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
979 ));
980 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
981 }
982
983 #[test]
984 fn test_circuit_steps_for_expr_pow_3() {
985 type F = BinaryField8b;
986
987 let expr = ArithExpr::<F>::Var(5).pow(3);
988 let (steps, retval) = convert_circuit_steps(&expr.into());
989
990 assert_eq!(
991 steps.len(),
992 2,
993 "Pow(3) should generate one squaring and one multiplication step"
994 );
995 assert!(matches!(
996 steps[0],
997 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
998 ));
999 assert!(matches!(
1000 steps[1],
1001 CircuitStep::Mul(
1002 CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1003 CircuitStepArgument::Expr(CircuitNode::Var(5))
1004 )
1005 ));
1006 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1007 }
1008
1009 #[test]
1010 fn test_circuit_steps_for_expr_pow_4() {
1011 type F = BinaryField8b;
1012
1013 let expr = ArithExpr::<F>::Var(7).pow(4);
1014 let (steps, retval) = convert_circuit_steps(&expr.into());
1015
1016 assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
1017 assert!(matches!(
1018 steps[0],
1019 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1020 ));
1021
1022 assert!(matches!(
1023 steps[1],
1024 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1025 ));
1026
1027 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1028 }
1029
1030 #[test]
1031 fn test_circuit_steps_for_expr_pow_5() {
1032 type F = BinaryField8b;
1033
1034 let expr = ArithExpr::<F>::Var(3).pow(5);
1035 let (steps, retval) = convert_circuit_steps(&expr.into());
1036
1037 assert_eq!(
1038 steps.len(),
1039 3,
1040 "Pow(5) should generate two squaring steps and one multiplication"
1041 );
1042 assert!(matches!(
1043 steps[0],
1044 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
1045 ));
1046 assert!(matches!(
1047 steps[1],
1048 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1049 ));
1050 assert!(matches!(
1051 steps[2],
1052 CircuitStep::Mul(
1053 CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1054 CircuitStepArgument::Expr(CircuitNode::Var(3))
1055 )
1056 ));
1057
1058 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1059 }
1060
1061 #[test]
1062 fn test_circuit_steps_for_expr_pow_8() {
1063 type F = BinaryField8b;
1064
1065 let expr = ArithExpr::<F>::Var(4).pow(8);
1066 let (steps, retval) = convert_circuit_steps(&expr.into());
1067
1068 assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
1069 assert!(matches!(
1070 steps[0],
1071 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
1072 ));
1073 assert!(matches!(
1074 steps[1],
1075 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1076 ));
1077 assert!(matches!(
1078 steps[2],
1079 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1080 ));
1081
1082 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1083 }
1084
1085 #[test]
1086 fn test_circuit_steps_for_expr_pow_9() {
1087 type F = BinaryField8b;
1088
1089 let expr = ArithExpr::<F>::Var(8).pow(9);
1090 let (steps, retval) = convert_circuit_steps(&expr.into());
1091
1092 assert_eq!(
1093 steps.len(),
1094 4,
1095 "Pow(9) should generate three squaring steps and one multiplication"
1096 );
1097 assert!(matches!(
1098 steps[0],
1099 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
1100 ));
1101 assert!(matches!(
1102 steps[1],
1103 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1104 ));
1105 assert!(matches!(
1106 steps[2],
1107 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1108 ));
1109 assert!(matches!(
1110 steps[3],
1111 CircuitStep::Mul(
1112 CircuitStepArgument::Expr(CircuitNode::Slot(2)),
1113 CircuitStepArgument::Expr(CircuitNode::Var(8))
1114 )
1115 ));
1116
1117 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1118 }
1119
1120 #[test]
1121 fn test_circuit_steps_for_expr_pow_12() {
1122 type F = BinaryField8b;
1123 let expr = ArithExpr::<F>::Var(6).pow(12);
1124 let (steps, retval) = convert_circuit_steps(&expr.into());
1125
1126 assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");
1127
1128 assert!(matches!(
1129 steps[0],
1130 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
1131 ));
1132 assert!(matches!(
1133 steps[1],
1134 CircuitStep::Mul(
1135 CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1136 CircuitStepArgument::Expr(CircuitNode::Var(6))
1137 )
1138 ));
1139 assert!(matches!(
1140 steps[2],
1141 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1142 ));
1143 assert!(matches!(
1144 steps[3],
1145 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1146 ));
1147
1148 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1149 }
1150
1151 #[test]
1152 fn test_circuit_steps_for_expr_pow_13() {
1153 type F = BinaryField8b;
1154 let expr = ArithExpr::<F>::Var(7).pow(13);
1155 let (steps, retval) = convert_circuit_steps(&expr.into());
1156
1157 assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
1158 assert!(matches!(
1159 steps[0],
1160 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1161 ));
1162 assert!(matches!(
1163 steps[1],
1164 CircuitStep::Mul(
1165 CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1166 CircuitStepArgument::Expr(CircuitNode::Var(7))
1167 )
1168 ));
1169 assert!(matches!(
1170 steps[2],
1171 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1172 ));
1173 assert!(matches!(
1174 steps[3],
1175 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
1176 ));
1177 assert!(matches!(
1178 steps[4],
1179 CircuitStep::Mul(
1180 CircuitStepArgument::Expr(CircuitNode::Slot(3)),
1181 CircuitStepArgument::Expr(CircuitNode::Var(7))
1182 )
1183 ));
1184 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
1185 }
1186
1187 #[test]
1188 fn test_circuit_steps_for_expr_complex() {
1189 type F = BinaryField8b;
1190
1191 let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1192 + (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
1193 - ArithExpr::Var(3);
1194
1195 let (steps, retval) = convert_circuit_steps(&expr.into());
1196
1197 assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");
1198
1199 assert!(
1200 matches!(
1201 steps[0],
1202 CircuitStep::Mul(
1203 CircuitStepArgument::Expr(CircuitNode::Var(0)),
1204 CircuitStepArgument::Expr(CircuitNode::Var(1))
1205 )
1206 ),
1207 "First step should be multiplication x0 * x1"
1208 );
1209
1210 assert!(
1211 matches!(
1212 steps[1],
1213 CircuitStep::Add(
1214 CircuitStepArgument::Const(F::ONE),
1215 CircuitStepArgument::Expr(CircuitNode::Var(0))
1216 )
1217 ),
1218 "Second step should be (1 - x0)"
1219 );
1220
1221 assert!(
1222 matches!(
1223 steps[2],
1224 CircuitStep::AddMul(
1225 0,
1226 CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1227 CircuitStepArgument::Expr(CircuitNode::Var(2))
1228 )
1229 ),
1230 "Third step should be (1 - x0) * x2"
1231 );
1232
1233 assert!(
1234 matches!(
1235 steps[3],
1236 CircuitStep::Add(
1237 CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1238 CircuitStepArgument::Expr(CircuitNode::Var(3))
1239 )
1240 ),
1241 "Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
1242 );
1243
1244 assert!(
1245 matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
1246 "Final result should be stored in Slot(3)"
1247 );
1248 }
1249
1250 #[test]
1251 fn check_deduplication_in_steps() {
1252 type F = BinaryField8b;
1253
1254 let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
1255 + (ArithExpr::<F>::Var(0) * ArithExpr::Var(1)) * ArithExpr::Var(2)
1256 - ArithExpr::Var(3);
1257 let expr = ArithCircuit::from(&expr);
1258 let expr = expr.optimize();
1259
1260 let (steps, retval) = convert_circuit_steps(&expr);
1261
1262 assert_eq!(steps.len(), 3, "Expression should generate 3 computation steps");
1263
1264 assert!(
1265 matches!(
1266 steps[0],
1267 CircuitStep::Mul(
1268 CircuitStepArgument::Expr(CircuitNode::Var(0)),
1269 CircuitStepArgument::Expr(CircuitNode::Var(1))
1270 )
1271 ),
1272 "First step should be multiplication x0 * x1"
1273 );
1274
1275 assert!(
1276 matches!(
1277 steps[1],
1278 CircuitStep::AddMul(
1279 0,
1280 CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1281 CircuitStepArgument::Expr(CircuitNode::Var(2))
1282 )
1283 ),
1284 "Second step should be (x0 * x1) * x2"
1285 );
1286
1287 assert!(
1288 matches!(
1289 steps[2],
1290 CircuitStep::Add(
1291 CircuitStepArgument::Expr(CircuitNode::Slot(0)),
1292 CircuitStepArgument::Expr(CircuitNode::Var(3))
1293 )
1294 ),
1295 "Third step should be x0 * x1 + (x0 * x1) * x2 + x3"
1296 );
1297
1298 assert!(
1299 matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))),
1300 "Final result should be stored in Slot(2)"
1301 );
1302 }
1303}