1use std::{mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithCircuit, ArithCircuitStep, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{
8 DeserializeBytes, SerializationError, SerializationMode, SerializeBytes, bail,
9 mem::{slice_assume_init_mut, slice_assume_init_ref},
10};
11
12fn convert_circuit_steps<F: Field>(
14 expr: &ArithCircuit<F>,
15) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
16 struct StepsMapping {
18 original_to_converted: Vec<Option<usize>>,
19 converted_to_original: Vec<Option<usize>>,
20 }
21
22 impl StepsMapping {
23 fn new(original_size: usize) -> Self {
24 Self {
25 original_to_converted: vec![None; original_size],
26 converted_to_original: vec![None; original_size],
28 }
29 }
30
31 fn register(&mut self, original_step: usize, converted_step: usize) {
32 self.original_to_converted[original_step] = Some(converted_step);
33
34 if converted_step >= self.converted_to_original.len() {
35 self.converted_to_original.resize(converted_step + 1, None);
36 }
37 self.converted_to_original[converted_step] = Some(original_step);
38 }
39
40 fn clear_step(&mut self, converted_step: usize) {
41 if self.converted_to_original.len() <= converted_step {
42 return;
43 }
44
45 if let Some(original_step) = self.converted_to_original[converted_step].take() {
46 self.original_to_converted[original_step] = None;
47 }
48 }
49
50 fn get_converted_step(&self, original_step: usize) -> Option<usize> {
51 self.original_to_converted[original_step]
52 }
53 }
54
55 fn convert_step<F: Field>(
56 original_step: usize,
57 original_steps: &[ArithCircuitStep<F>],
58 result: &mut Vec<CircuitStep<F>>,
59 node_to_step: &mut StepsMapping,
60 ) -> CircuitStepArgument<F> {
61 if let Some(converted_step) = node_to_step.get_converted_step(original_step) {
62 return CircuitStepArgument::Expr(CircuitNode::Slot(converted_step));
63 }
64
65 match &original_steps[original_step] {
66 ArithCircuitStep::Const(constant) => CircuitStepArgument::Const(*constant),
67 ArithCircuitStep::Var(var) => CircuitStepArgument::Expr(CircuitNode::Var(*var)),
68 ArithCircuitStep::Add(left, right) => {
69 let left = convert_step(*left, original_steps, result, node_to_step);
70
71 if let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left {
72 if let ArithCircuitStep::Mul(mleft, mright) = &original_steps[*right] {
73 let mleft = convert_step(*mleft, original_steps, result, node_to_step);
77 let mright = convert_step(*mright, original_steps, result, node_to_step);
78
79 node_to_step.clear_step(left);
82 node_to_step.register(original_step, left);
83 result.push(CircuitStep::AddMul(left, mleft, mright));
84
85 return CircuitStepArgument::Expr(CircuitNode::Slot(left));
86 }
87 }
88
89 let right = convert_step(*right, original_steps, result, node_to_step);
90
91 node_to_step.register(original_step, result.len());
92 result.push(CircuitStep::Add(left, right));
93 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
94 }
95 ArithCircuitStep::Mul(left, right) => {
96 let left = convert_step(*left, original_steps, result, node_to_step);
97 let right = convert_step(*right, original_steps, result, node_to_step);
98
99 node_to_step.register(original_step, result.len());
100 result.push(CircuitStep::Mul(left, right));
101 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
102 }
103 ArithCircuitStep::Pow(base, exp) => {
104 let mut acc = convert_step(*base, original_steps, result, node_to_step);
105 let base_expr = acc;
106 let highest_bit = exp.ilog2();
107
108 for i in (0..highest_bit).rev() {
109 if i == 0 {
110 node_to_step.register(original_step, result.len());
111 }
112 result.push(CircuitStep::Square(acc));
113 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
114
115 if (exp >> i) & 1 != 0 {
116 result.push(CircuitStep::Mul(acc, base_expr));
117 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
118 }
119 }
120
121 acc
122 }
123 }
124 }
125
126 let mut steps = Vec::new();
127 let mut steps_mapping = StepsMapping::new(expr.steps().len());
128 let ret = convert_step(expr.steps().len() - 1, expr.steps(), &mut steps, &mut steps_mapping);
129 (steps, ret)
130}
131
/// A reference to a value consumed by a [`CircuitStep`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitNode {
	/// Input variable with the given index: an element of the query in `evaluate`, or a row of
	/// the input batch in `batch_evaluate`.
	Var(usize),
	/// Scratch slot with the given index, written by a previously evaluated step.
	Slot(usize),
}
140
141impl CircuitNode {
142 fn get_sparse_chunk<'a, P: PackedField>(
145 &'a self,
146 inputs: &'a RowsBatchRef<'a, P>,
147 evals: &'a [P],
148 row_len: usize,
149 ) -> &'a [P] {
150 match self {
151 Self::Var(index) => inputs.row(*index),
152 Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
153 }
154 }
155}
156
/// An operand of a [`CircuitStep`]: either a node reference or an inline constant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitStepArgument<F> {
	/// Value read from an input variable or a scratch slot.
	Expr(CircuitNode),
	/// Constant; evaluation broadcasts it to every lane of the packed field.
	Const(F),
}
162
/// A single evaluation step produced by `convert_circuit_steps`.
///
/// Steps run in order. `Add`, `Mul` and `Square` produce the value of the slot whose index is
/// the step's own position; `AddMul` instead updates an earlier slot in place.
#[derive(Debug)]
enum CircuitStep<F: Field> {
	/// `a + b`
	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// `a * b`
	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// `a * a`
	Square(CircuitStepArgument<F>),
	/// `slot[target] += a * b`, where `target` (the first field) is an earlier slot index.
	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
}
175
/// A composition polynomial backed by an [`ArithCircuit`] that is pre-compiled into a flat
/// list of evaluation steps.
#[derive(Debug, Clone)]
pub struct ArithCircuitPoly<F: Field> {
	/// Source circuit, optimized by the constructors before being stored.
	expr: ArithCircuit<F>,
	/// Evaluation steps derived from `expr`.
	steps: Arc<[CircuitStep<F>]>,
	/// Operand holding the polynomial's value once all `steps` have run.
	retval: CircuitStepArgument<F>,
	/// Degree as reported by `expr.degree()`.
	degree: usize,
	/// Number of variables a query must supply; never less than `expr.n_vars()`.
	n_vars: usize,
	/// Tower level as reported by `expr.binary_tower_level()`.
	tower_level: usize,
}
193
194impl<F: Field> PartialEq for ArithCircuitPoly<F> {
195 fn eq(&self, other: &Self) -> bool {
196 self.n_vars == other.n_vars && self.expr == other.expr
197 }
198}
199
200impl<F: Field> Eq for ArithCircuitPoly<F> {}
201
202impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
203 fn serialize(
204 &self,
205 mut write_buf: impl bytes::BufMut,
206 mode: SerializationMode,
207 ) -> Result<(), SerializationError> {
208 (&self.expr, self.n_vars).serialize(&mut write_buf, mode)
209 }
210}
211
212impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
213 fn deserialize(
214 read_buf: impl bytes::Buf,
215 mode: SerializationMode,
216 ) -> Result<Self, SerializationError>
217 where
218 Self: Sized,
219 {
220 let (expr, n_vars) = <(ArithCircuit<F>, usize)>::deserialize(read_buf, mode)?;
221 Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
222 name: "ArithCircuitPoly",
223 })
224 }
225}
226
227impl<F: TowerField> ArithCircuitPoly<F> {
228 pub fn new(mut expr: ArithCircuit<F>) -> Self {
229 expr.optimize_in_place();
230
231 let degree = expr.degree();
232 let n_vars = expr.n_vars();
233 let tower_level = expr.binary_tower_level();
234 let (exprs, retval) = convert_circuit_steps(&expr);
235
236 Self {
237 expr,
238 steps: exprs.into(),
239 retval,
240 degree,
241 n_vars,
242 tower_level,
243 }
244 }
245 pub fn with_n_vars(n_vars: usize, mut expr: ArithCircuit<F>) -> Result<Self, Error> {
250 expr.optimize_in_place();
251
252 let degree = expr.degree();
253 let tower_level = expr.binary_tower_level();
254 if n_vars < expr.n_vars() {
255 return Err(Error::IncorrectNumberOfVariables {
256 expected: expr.n_vars(),
257 actual: n_vars,
258 });
259 }
260 let (steps, retval) = convert_circuit_steps(&expr);
261
262 Ok(Self {
263 expr,
264 steps: steps.into(),
265 retval,
266 n_vars,
267 degree,
268 tower_level,
269 })
270 }
271}
272
273impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
274 for ArithCircuitPoly<F>
275{
276 fn degree(&self) -> usize {
277 self.degree
278 }
279
280 fn n_vars(&self) -> usize {
281 self.n_vars
282 }
283
284 fn binary_tower_level(&self) -> usize {
285 self.tower_level
286 }
287
288 fn expression(&self) -> ArithCircuit<P::Scalar> {
289 self.expr.convert_field()
290 }
291
292 fn evaluate(&self, query: &[P]) -> Result<P, Error> {
293 if query.len() != self.n_vars {
294 return Err(Error::IncorrectQuerySize {
295 expected: self.n_vars,
296 actual: query.len(),
297 });
298 }
299
300 fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
301 unsafe {
304 target.get_unchecked_mut(0).write(value);
305 }
306 }
307
308 alloc_scratch_space::<P, _, _>(self.steps.len(), |evals| {
309 let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
310 CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
313 *query.get_unchecked(index)
314 },
315 CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
318 *evals.get_unchecked(slot)
319 },
320 CircuitStepArgument::Const(value) => P::broadcast(value.into()),
321 };
322
323 for (i, expr) in self.steps.iter().enumerate() {
324 let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
327 let before = unsafe { slice_assume_init_mut(before) };
328 match expr {
329 CircuitStep::Add(x, y) => write_result(
330 after,
331 get_argument_value(*x, before) + get_argument_value(*y, before),
332 ),
333 CircuitStep::AddMul(target_slot, x, y) => {
334 let intermediate =
335 get_argument_value(*x, before) * get_argument_value(*y, before);
336 let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
339 *target_slot += intermediate;
340 }
341 CircuitStep::Mul(x, y) => write_result(
342 after,
343 get_argument_value(*x, before) * get_argument_value(*y, before),
344 ),
345 CircuitStep::Square(x) => {
346 write_result(after, get_argument_value(*x, before).square())
347 }
348 };
349 }
350
351 unsafe {
354 let evals = slice_assume_init_ref(evals);
355 Ok(get_argument_value(self.retval, evals))
356 }
357 })
358 }
359
360 fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
361 let row_len = evals.len();
362 if batch_query.row_len() != row_len {
363 bail!(Error::BatchEvaluateSizeMismatch {
364 expected: row_len,
365 actual: batch_query.row_len(),
366 });
367 }
368
369 alloc_scratch_space::<P, (), _>(self.steps.len() * row_len, |sparse_evals| {
370 for (i, expr) in self.steps.iter().enumerate() {
371 let (before, current) = sparse_evals.split_at_mut(i * row_len);
372
373 let before = unsafe { slice_assume_init_mut(before) };
376 let current = &mut current[..row_len];
377
378 match expr {
379 CircuitStep::Add(left, right) => {
380 apply_binary_op(
381 left,
382 right,
383 batch_query,
384 before,
385 current,
386 |left, right, out| {
387 out.write(left + right);
388 },
389 );
390 }
391 CircuitStep::Mul(left, right) => {
392 apply_binary_op(
393 left,
394 right,
395 batch_query,
396 before,
397 current,
398 |left, right, out| {
399 out.write(left * right);
400 },
401 );
402 }
403 CircuitStep::Square(arg) => {
404 match arg {
405 CircuitStepArgument::Expr(node) => {
406 let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
407 for j in 0..row_len {
408 unsafe {
411 current
412 .get_unchecked_mut(j)
413 .write(id_chunk.get_unchecked(j).square());
414 }
415 }
416 }
417 CircuitStepArgument::Const(value) => {
418 let value: P = P::broadcast((*value).into());
419 let result = value.square();
420 for j in 0..row_len {
421 unsafe {
423 current.get_unchecked_mut(j).write(result);
424 }
425 }
426 }
427 }
428 }
429 CircuitStep::AddMul(target, left, right) => {
430 let target = &mut before[row_len * target..(target + 1) * row_len];
431 let target: &mut [MaybeUninit<P>] = unsafe {
434 std::slice::from_raw_parts_mut(
435 target.as_mut_ptr() as *mut MaybeUninit<P>,
436 target.len(),
437 )
438 };
439 apply_binary_op(
440 left,
441 right,
442 batch_query,
443 before,
444 target,
445 |left, right, out| unsafe {
448 let out = out.assume_init_mut();
449 *out += left * right;
450 },
451 );
452 }
453 }
454 }
455
456 match self.retval {
457 CircuitStepArgument::Expr(node) => {
458 let sparse_evals = unsafe { slice_assume_init_ref(sparse_evals) };
460 evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
461 }
462 CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
463 }
464 });
465
466 Ok(())
467 }
468}
469
470fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
474 left: &CircuitStepArgument<F>,
475 right: &CircuitStepArgument<F>,
476 batch_query: &RowsBatchRef<P>,
477 evals_before: &[P],
478 current_evals: &mut [MaybeUninit<P>],
479 op: impl Fn(P, P, &mut MaybeUninit<P>),
480) {
481 let row_len = current_evals.len();
482
483 match (left, right) {
484 (CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
485 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
486 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
487 for j in 0..row_len {
488 unsafe {
490 op(
491 *left.get_unchecked(j),
492 *right.get_unchecked(j),
493 current_evals.get_unchecked_mut(j),
494 )
495 }
496 }
497 }
498 (CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
499 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
500 let right = P::broadcast((*right).into());
501 for j in 0..row_len {
502 unsafe {
504 op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
505 }
506 }
507 }
508 (CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
509 let left = P::broadcast((*left).into());
510 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
511 for j in 0..row_len {
512 unsafe {
514 op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
515 }
516 }
517 }
518 (CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
519 let left = P::broadcast((*left).into());
520 let right = P::broadcast((*right).into());
521 let mut result = MaybeUninit::uninit();
522 op(left, right, &mut result);
523 for j in 0..row_len {
524 unsafe {
528 current_evals
529 .get_unchecked_mut(j)
530 .write(result.assume_init());
531 }
532 }
533 }
534 }
535}
536
537fn alloc_scratch_space<T, U, F>(size: usize, callback: F) -> U
538where
539 F: FnOnce(&mut [MaybeUninit<T>]) -> U,
540{
541 use std::mem;
542 assert!(!mem::needs_drop::<T>());
544
545 #[cfg(miri)]
546 {
547 let mut scratch_space = Vec::<T>::with_capacity(size);
548 let out = callback(scratch_space.spare_capacity_mut());
549 drop(scratch_space);
550 out
551 }
552 #[cfg(not(miri))]
553 {
554 let size = size.max(1);
556 stackalloc::stackalloc_uninit(size, callback)
557 }
558}
559
560#[cfg(test)]
561mod tests {
562 use binius_field::{
563 BinaryField8b, BinaryField16b, PackedBinaryField8x16b, PackedField, TowerField,
564 };
565 use binius_math::{ArithExpr, CompositionPoly, RowsBatch};
566 use binius_utils::felts;
567
568 use super::*;
569
570 #[test]
571 fn test_constant() {
572 type F = BinaryField8b;
573 type P = PackedBinaryField8x16b;
574
575 let expr = ArithExpr::Const(F::new(123));
576 let circuit = ArithCircuitPoly::<F>::new(expr.into());
577
578 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
579 assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
580 assert_eq!(typed_circuit.degree(), 0);
581 assert_eq!(typed_circuit.n_vars(), 0);
582
583 assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));
584
585 let mut evals = [P::default()];
586 typed_circuit
587 .batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
588 .unwrap();
589 assert_eq!(evals, [P::broadcast(F::new(123).into())]);
590 }
591
592 #[test]
593 fn test_identity() {
594 type F = BinaryField8b;
595 type P = PackedBinaryField8x16b;
596
597 let expr = ArithExpr::Var(0);
599 let circuit = ArithCircuitPoly::<F>::new(expr.into());
600
601 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
602 assert_eq!(typed_circuit.binary_tower_level(), 0);
603 assert_eq!(typed_circuit.degree(), 1);
604 assert_eq!(typed_circuit.n_vars(), 1);
605
606 assert_eq!(
607 typed_circuit
608 .evaluate(&[P::broadcast(F::new(123).into())])
609 .unwrap(),
610 P::broadcast(F::new(123).into())
611 );
612
613 let mut evals = [P::default()];
614 let batch_query = [[P::broadcast(F::new(123).into())]; 1];
615 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
616 typed_circuit
617 .batch_evaluate(&batch_query.get_ref(), &mut evals)
618 .unwrap();
619 assert_eq!(evals, [P::broadcast(F::new(123).into())]);
620 }
621
622 #[test]
623 fn test_add() {
624 type F = BinaryField8b;
625 type P = PackedBinaryField8x16b;
626
627 let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
629 let circuit = ArithCircuitPoly::<F>::new(expr.into());
630
631 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
632 assert_eq!(typed_circuit.binary_tower_level(), 3);
633 assert_eq!(typed_circuit.degree(), 1);
634 assert_eq!(typed_circuit.n_vars(), 1);
635
636 assert_eq!(
637 CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
638 P::broadcast(F::new(123).into())
639 );
640 }
641
642 #[test]
643 fn test_mul() {
644 type F = BinaryField8b;
645 type P = PackedBinaryField8x16b;
646
647 let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
649 let circuit = ArithCircuitPoly::<F>::new(expr.into());
650
651 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
652 assert_eq!(typed_circuit.binary_tower_level(), 3);
653 assert_eq!(typed_circuit.degree(), 1);
654 assert_eq!(typed_circuit.n_vars(), 1);
655
656 assert_eq!(
657 CompositionPoly::evaluate(
658 &circuit,
659 &[P::from_scalars(
660 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
661 )]
662 )
663 .unwrap(),
664 P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
665 );
666 }
667
668 #[test]
669 fn test_pow() {
670 type F = BinaryField8b;
671 type P = PackedBinaryField8x16b;
672
673 let expr = ArithExpr::Var(0).pow(13);
675 let circuit = ArithCircuitPoly::<F>::new(expr.into());
676
677 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
678 assert_eq!(typed_circuit.binary_tower_level(), 0);
679 assert_eq!(typed_circuit.degree(), 13);
680 assert_eq!(typed_circuit.n_vars(), 1);
681
682 assert_eq!(
683 CompositionPoly::evaluate(
684 &circuit,
685 &[P::from_scalars(
686 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
687 )]
688 )
689 .unwrap(),
690 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
691 );
692 }
693
694 #[test]
695 fn test_mixed() {
696 type F = BinaryField8b;
697 type P = PackedBinaryField8x16b;
698
699 let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
701 let circuit = ArithCircuitPoly::<F>::new(expr.into());
702
703 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
704 assert_eq!(typed_circuit.binary_tower_level(), 3);
705 assert_eq!(typed_circuit.degree(), 3);
706 assert_eq!(typed_circuit.n_vars(), 2);
707
708 assert_eq!(
710 CompositionPoly::evaluate(
711 &circuit,
712 &[
713 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
714 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
715 ]
716 )
717 .unwrap(),
718 P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
719 );
720
721 let query1 = &[
723 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
724 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
725 ];
726 let query2 = &[
727 P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
728 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
729 ];
730 let query3 = &[
731 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
732 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
733 ];
734 let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
735 let expected2 =
736 P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
737 let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));
738
739 let mut batch_result = vec![P::zero(); 3];
740 let batch_query = &[
741 &[query1[0], query2[0], query3[0]],
742 &[query1[1], query2[1], query3[1]],
743 ];
744 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
745
746 CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
747 .unwrap();
748 assert_eq!(&batch_result, &[expected1, expected2, expected3]);
749 }
750
751 #[test]
752 fn batch_evaluate_add_mul() {
753 type F = BinaryField8b;
757 type P = PackedBinaryField8x16b;
758
759 let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(0))
760 + (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(0)
761 - ArithExpr::Var(0);
762 let circuit = ArithCircuitPoly::<F>::new(expr.into());
763
764 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
765 assert_eq!(typed_circuit.binary_tower_level(), 0);
766 assert_eq!(typed_circuit.degree(), 2);
767 assert_eq!(typed_circuit.n_vars(), 1);
768
769 let mut evals = [P::default(); 1];
770 let batch_query = [[P::broadcast(F::new(1).into())]; 1];
771 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
772 typed_circuit
773 .batch_evaluate(&batch_query.get_ref(), &mut evals)
774 .unwrap();
775 }
776
777 #[test]
778 fn test_const_fold() {
779 type F = BinaryField8b;
780 type P = PackedBinaryField8x16b;
781
782 let expr = ArithExpr::Var(0)
784 * ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
785 + (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
786 + ArithExpr::Var(1);
787 let circuit = ArithCircuitPoly::<F>::new(expr.into());
788 assert_eq!(circuit.steps.len(), 2);
789
790 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
791 assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
792 assert_eq!(typed_circuit.degree(), 1);
793 assert_eq!(typed_circuit.n_vars(), 2);
794
795 assert_eq!(
797 CompositionPoly::evaluate(
798 &circuit,
799 &[
800 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
801 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
802 ]
803 )
804 .unwrap(),
805 P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
806 );
807
808 let query1 = &[
810 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
811 P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
812 ];
813 let query2 = &[
814 P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
815 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
816 ];
817 let query3 = &[
818 P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
819 P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
820 ];
821 let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
822 let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
823 let expected3 =
824 P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));
825
826 let mut batch_result = vec![P::zero(); 3];
827 let batch_query = &[
828 &[query1[0], query2[0], query3[0]],
829 &[query1[1], query2[1], query3[1]],
830 ];
831 let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
832 CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
833 .unwrap();
834 assert_eq!(&batch_result, &[expected1, expected2, expected3]);
835 }
836
837 #[test]
838 fn test_pow_const_fold() {
839 type F = BinaryField8b;
840 type P = PackedBinaryField8x16b;
841
842 let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
844 let circuit = ArithCircuitPoly::<F>::new(expr.into());
845 assert_eq!(circuit.steps.len(), 1);
846
847 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
848 assert_eq!(typed_circuit.binary_tower_level(), 1);
849 assert_eq!(typed_circuit.degree(), 1);
850 assert_eq!(typed_circuit.n_vars(), 1);
851
852 assert_eq!(
853 CompositionPoly::evaluate(
854 &circuit,
855 &[P::from_scalars(
856 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
857 )]
858 )
859 .unwrap(),
860 P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
861 );
862 }
863
864 #[test]
865 fn test_pow_nested() {
866 type F = BinaryField8b;
867 type P = PackedBinaryField8x16b;
868
869 let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
871 let circuit = ArithCircuitPoly::<F>::new(expr.into());
872 assert_eq!(circuit.steps.len(), 5);
873
874 let typed_circuit: &dyn CompositionPoly<P> = &circuit;
875 assert_eq!(typed_circuit.binary_tower_level(), 0);
876 assert_eq!(typed_circuit.degree(), 24);
877 assert_eq!(typed_circuit.n_vars(), 1);
878
879 assert_eq!(
880 CompositionPoly::evaluate(
881 &circuit,
882 &[P::from_scalars(
883 felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
884 )]
885 )
886 .unwrap(),
887 P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
888 );
889 }
890
891 #[test]
892 fn test_circuit_steps_for_expr_constant() {
893 type F = BinaryField8b;
894
895 let expr = ArithExpr::Const(F::new(5));
896 let (steps, retval) = convert_circuit_steps(&expr.into());
897
898 assert!(steps.is_empty(), "No steps should be generated for a constant");
899 assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
900 }
901
902 #[test]
903 fn test_circuit_steps_for_expr_variable() {
904 type F = BinaryField8b;
905
906 let expr = ArithExpr::<F>::Var(18);
907 let (steps, retval) = convert_circuit_steps(&expr.into());
908
909 assert!(steps.is_empty(), "No steps should be generated for a variable");
910 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
911 }
912
913 #[test]
914 fn test_circuit_steps_for_expr_addition() {
915 type F = BinaryField8b;
916
917 let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
918 let (steps, retval) = convert_circuit_steps(&expr.into());
919
920 assert_eq!(steps.len(), 1, "One addition step should be generated");
921 assert!(matches!(
922 steps[0],
923 CircuitStep::Add(
924 CircuitStepArgument::Expr(CircuitNode::Var(14)),
925 CircuitStepArgument::Expr(CircuitNode::Var(56))
926 )
927 ));
928 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
929 }
930
931 #[test]
932 fn test_circuit_steps_for_expr_multiplication() {
933 type F = BinaryField8b;
934
935 let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
936 let (steps, retval) = convert_circuit_steps(&expr.into());
937
938 assert_eq!(steps.len(), 1, "One multiplication step should be generated");
939 assert!(matches!(
940 steps[0],
941 CircuitStep::Mul(
942 CircuitStepArgument::Expr(CircuitNode::Var(36)),
943 CircuitStepArgument::Expr(CircuitNode::Var(26))
944 )
945 ));
946 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
947 }
948
949 #[test]
950 fn test_circuit_steps_for_expr_pow_1() {
951 type F = BinaryField8b;
952
953 let expr = ArithExpr::<F>::Var(12).pow(1);
954 let (steps, retval) = convert_circuit_steps(&expr.into());
955
956 assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");
958
959 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
961 }
962
963 #[test]
964 fn test_circuit_steps_for_expr_pow_2() {
965 type F = BinaryField8b;
966
967 let expr = ArithExpr::<F>::Var(10).pow(2);
968 let (steps, retval) = convert_circuit_steps(&expr.into());
969
970 assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
971 assert!(matches!(
972 steps[0],
973 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
974 ));
975 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
976 }
977
978 #[test]
979 fn test_circuit_steps_for_expr_pow_3() {
980 type F = BinaryField8b;
981
982 let expr = ArithExpr::<F>::Var(5).pow(3);
983 let (steps, retval) = convert_circuit_steps(&expr.into());
984
985 assert_eq!(
986 steps.len(),
987 2,
988 "Pow(3) should generate one squaring and one multiplication step"
989 );
990 assert!(matches!(
991 steps[0],
992 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
993 ));
994 assert!(matches!(
995 steps[1],
996 CircuitStep::Mul(
997 CircuitStepArgument::Expr(CircuitNode::Slot(0)),
998 CircuitStepArgument::Expr(CircuitNode::Var(5))
999 )
1000 ));
1001 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1002 }
1003
1004 #[test]
1005 fn test_circuit_steps_for_expr_pow_4() {
1006 type F = BinaryField8b;
1007
1008 let expr = ArithExpr::<F>::Var(7).pow(4);
1009 let (steps, retval) = convert_circuit_steps(&expr.into());
1010
1011 assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
1012 assert!(matches!(
1013 steps[0],
1014 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
1015 ));
1016
1017 assert!(matches!(
1018 steps[1],
1019 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1020 ));
1021
1022 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
1023 }
1024
1025 #[test]
1026 fn test_circuit_steps_for_expr_pow_5() {
1027 type F = BinaryField8b;
1028
1029 let expr = ArithExpr::<F>::Var(3).pow(5);
1030 let (steps, retval) = convert_circuit_steps(&expr.into());
1031
1032 assert_eq!(
1033 steps.len(),
1034 3,
1035 "Pow(5) should generate two squaring steps and one multiplication"
1036 );
1037 assert!(matches!(
1038 steps[0],
1039 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
1040 ));
1041 assert!(matches!(
1042 steps[1],
1043 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1044 ));
1045 assert!(matches!(
1046 steps[2],
1047 CircuitStep::Mul(
1048 CircuitStepArgument::Expr(CircuitNode::Slot(1)),
1049 CircuitStepArgument::Expr(CircuitNode::Var(3))
1050 )
1051 ));
1052
1053 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1054 }
1055
1056 #[test]
1057 fn test_circuit_steps_for_expr_pow_8() {
1058 type F = BinaryField8b;
1059
1060 let expr = ArithExpr::<F>::Var(4).pow(8);
1061 let (steps, retval) = convert_circuit_steps(&expr.into());
1062
1063 assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
1064 assert!(matches!(
1065 steps[0],
1066 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
1067 ));
1068 assert!(matches!(
1069 steps[1],
1070 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1071 ));
1072 assert!(matches!(
1073 steps[2],
1074 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1075 ));
1076
1077 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
1078 }
1079
1080 #[test]
1081 fn test_circuit_steps_for_expr_pow_9() {
1082 type F = BinaryField8b;
1083
1084 let expr = ArithExpr::<F>::Var(8).pow(9);
1085 let (steps, retval) = convert_circuit_steps(&expr.into());
1086
1087 assert_eq!(
1088 steps.len(),
1089 4,
1090 "Pow(9) should generate three squaring steps and one multiplication"
1091 );
1092 assert!(matches!(
1093 steps[0],
1094 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
1095 ));
1096 assert!(matches!(
1097 steps[1],
1098 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
1099 ));
1100 assert!(matches!(
1101 steps[2],
1102 CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
1103 ));
1104 assert!(matches!(
1105 steps[3],
1106 CircuitStep::Mul(
1107 CircuitStepArgument::Expr(CircuitNode::Slot(2)),
1108 CircuitStepArgument::Expr(CircuitNode::Var(8))
1109 )
1110 ));
1111
1112 assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
1113 }
1114
/// `x6^12` is expected to lower via square-and-multiply: 12 = 0b1100, giving
/// `x6^2` (square), `x6^3` (multiply by `x6`), `x6^6` (square), `x6^12` (square).
#[test]
fn test_circuit_steps_for_expr_pow_12() {
    type F = BinaryField8b;

    let circuit: ArithCircuit<F> = ArithExpr::<F>::Var(6).pow(12).into();
    let (steps, retval) = convert_circuit_steps(&circuit);

    assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");

    // The length is pinned above, so this destructuring cannot fail.
    let [x_pow2, x_pow3, x_pow6, x_pow12] = steps.as_slice() else {
        unreachable!()
    };

    assert!(matches!(x_pow2, CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))));
    assert!(matches!(
        x_pow3,
        CircuitStep::Mul(
            CircuitStepArgument::Expr(CircuitNode::Slot(0)),
            CircuitStepArgument::Expr(CircuitNode::Var(6))
        )
    ));
    assert!(matches!(x_pow6, CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))));
    assert!(matches!(x_pow12, CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))));

    // The last squaring's slot holds the result.
    assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
}
1145
/// `x7^13` is expected to lower via square-and-multiply: 13 = 0b1101, giving
/// `x7^2`, `x7^3`, `x7^6`, `x7^12`, and finally `x7^13` (multiply by `x7`).
#[test]
fn test_circuit_steps_for_expr_pow_13() {
    type F = BinaryField8b;

    let circuit: ArithCircuit<F> = ArithExpr::<F>::Var(7).pow(13).into();
    let (steps, retval) = convert_circuit_steps(&circuit);

    assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");

    // The length is pinned above, so this destructuring cannot fail.
    let [x_pow2, x_pow3, x_pow6, x_pow12, x_pow13] = steps.as_slice() else {
        unreachable!()
    };

    assert!(matches!(x_pow2, CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))));
    assert!(matches!(
        x_pow3,
        CircuitStep::Mul(
            CircuitStepArgument::Expr(CircuitNode::Slot(0)),
            CircuitStepArgument::Expr(CircuitNode::Var(7))
        )
    ));
    assert!(matches!(x_pow6, CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))));
    assert!(matches!(x_pow12, CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))));
    assert!(matches!(
        x_pow13,
        CircuitStep::Mul(
            CircuitStepArgument::Expr(CircuitNode::Slot(3)),
            CircuitStepArgument::Expr(CircuitNode::Var(7))
        )
    ));

    // The final multiplication's slot holds the result.
    assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
}
1181
/// Checks the lowering of `x0 * x1 + (1 - x0) * x2 - x3` into four steps.
///
/// Both subtractions are expected to appear as `CircuitStep::Add`. This is
/// presumably because `BinaryField8b` has characteristic 2; TODO(review)
/// confirm. The second product is expected as an `AddMul` whose first field is
/// `0`, and the final `Add` reads `Slot(0)`.
#[test]
fn test_circuit_steps_for_expr_complex() {
    type F = BinaryField8b;

    let x0_times_x1 = ArithExpr::<F>::Var(0) * ArithExpr::Var(1);
    let second_product = (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2);
    let expr = x0_times_x1 + second_product - ArithExpr::Var(3);

    let circuit: ArithCircuit<F> = expr.into();
    let (steps, retval) = convert_circuit_steps(&circuit);

    assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");

    // The length is pinned above, so this destructuring cannot fail.
    let [product_step, one_minus_x0_step, accumulate_step, add_x3_step] = steps.as_slice() else {
        unreachable!()
    };

    let product_ok = matches!(
        product_step,
        CircuitStep::Mul(
            CircuitStepArgument::Expr(CircuitNode::Var(0)),
            CircuitStepArgument::Expr(CircuitNode::Var(1))
        )
    );
    assert!(product_ok, "First step should be multiplication x0 * x1");

    let one_minus_x0_ok = matches!(
        one_minus_x0_step,
        CircuitStep::Add(
            CircuitStepArgument::Const(F::ONE),
            CircuitStepArgument::Expr(CircuitNode::Var(0))
        )
    );
    assert!(one_minus_x0_ok, "Second step should be (1 - x0)");

    let accumulate_ok = matches!(
        accumulate_step,
        CircuitStep::AddMul(
            0,
            CircuitStepArgument::Expr(CircuitNode::Slot(1)),
            CircuitStepArgument::Expr(CircuitNode::Var(2))
        )
    );
    assert!(accumulate_ok, "Third step should be (1 - x0) * x2");

    let add_x3_ok = matches!(
        add_x3_step,
        CircuitStep::Add(
            CircuitStepArgument::Expr(CircuitNode::Slot(0)),
            CircuitStepArgument::Expr(CircuitNode::Var(3))
        )
    );
    assert!(add_x3_ok, "Fourth step should be x0 * x1 + (1 - x0) * x2 + x3");

    let retval_ok = matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3)));
    assert!(retval_ok, "Final result should be stored in Slot(3)");
}
1244
/// After `ArithCircuit::optimize`, the sub-expression `x0 * x1`, which appears
/// twice in the source expression, is expected to be computed once (step 0).
/// Both later uses then refer to it through `Slot(0)`, which yields three steps
/// instead of four.
#[test]
fn check_deduplication_in_steps() {
    type F = BinaryField8b;

    let repeated = || ArithExpr::<F>::Var(0) * ArithExpr::Var(1);
    let expr = repeated() + repeated() * ArithExpr::Var(2) - ArithExpr::Var(3);
    let circuit = ArithCircuit::from(&expr).optimize();

    let (steps, retval) = convert_circuit_steps(&circuit);

    assert_eq!(steps.len(), 3, "Expression should generate 3 computation steps");

    // The length is pinned above, so this destructuring cannot fail.
    let [shared_product, scaled_product, final_sum] = steps.as_slice() else {
        unreachable!()
    };

    let shared_ok = matches!(
        shared_product,
        CircuitStep::Mul(
            CircuitStepArgument::Expr(CircuitNode::Var(0)),
            CircuitStepArgument::Expr(CircuitNode::Var(1))
        )
    );
    assert!(shared_ok, "First step should be multiplication x0 * x1");

    let scaled_ok = matches!(
        scaled_product,
        CircuitStep::AddMul(
            0,
            CircuitStepArgument::Expr(CircuitNode::Slot(0)),
            CircuitStepArgument::Expr(CircuitNode::Var(2))
        )
    );
    assert!(scaled_ok, "Second step should be (x0 * x1) * x2");

    let sum_ok = matches!(
        final_sum,
        CircuitStep::Add(
            CircuitStepArgument::Expr(CircuitNode::Slot(0)),
            CircuitStepArgument::Expr(CircuitNode::Var(3))
        )
    );
    assert!(sum_ok, "Third step should be x0 * x1 + (x0 * x1) * x2 + x3");

    let retval_ok = matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2)));
    assert!(retval_ok, "Final result should be stored in Slot(2)");
}
1298}