1use std::{fmt::Debug, mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithCircuit, ArithCircuitStep, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
8use stackalloc::{
9 helpers::{slice_assume_init, slice_assume_init_mut},
10 stackalloc_uninit,
11};
12
13fn convert_circuit_steps<F: Field>(
15 expr: &ArithCircuit<F>,
16) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
17 struct StepsMapping {
19 original_to_converted: Vec<Option<usize>>,
20 converted_to_original: Vec<Option<usize>>,
21 }
22
23 impl StepsMapping {
24 fn new(original_size: usize) -> Self {
25 Self {
26 original_to_converted: vec![None; original_size],
27 converted_to_original: vec![None; original_size],
29 }
30 }
31
32 fn register(&mut self, original_step: usize, converted_step: usize) {
33 self.original_to_converted[original_step] = Some(converted_step);
34
35 if converted_step >= self.converted_to_original.len() {
36 self.converted_to_original.resize(converted_step + 1, None);
37 }
38 self.converted_to_original[converted_step] = Some(original_step);
39 }
40
41 fn clear_step(&mut self, converted_step: usize) {
42 if self.converted_to_original.len() <= converted_step {
43 return;
44 }
45
46 if let Some(original_step) = self.converted_to_original[converted_step].take() {
47 self.original_to_converted[original_step] = None;
48 }
49 }
50
51 fn get_converted_step(&self, original_step: usize) -> Option<usize> {
52 self.original_to_converted[original_step]
53 }
54 }
55
56 fn convert_step<F: Field>(
57 original_step: usize,
58 original_steps: &[ArithCircuitStep<F>],
59 result: &mut Vec<CircuitStep<F>>,
60 node_to_step: &mut StepsMapping,
61 ) -> CircuitStepArgument<F> {
62 if let Some(converted_step) = node_to_step.get_converted_step(original_step) {
63 return CircuitStepArgument::Expr(CircuitNode::Slot(converted_step));
64 }
65
66 match &original_steps[original_step] {
67 ArithCircuitStep::Const(constant) => CircuitStepArgument::Const(*constant),
68 ArithCircuitStep::Var(var) => CircuitStepArgument::Expr(CircuitNode::Var(*var)),
69 ArithCircuitStep::Add(left, right) => {
70 let left = convert_step(*left, original_steps, result, node_to_step);
71
72 if let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left {
73 if let ArithCircuitStep::Mul(mleft, mright) = &original_steps[*right] {
74 let mleft = convert_step(*mleft, original_steps, result, node_to_step);
77 let mright = convert_step(*mright, original_steps, result, node_to_step);
78
79 node_to_step.clear_step(left);
81 node_to_step.register(original_step, left);
82 result.push(CircuitStep::AddMul(left, mleft, mright));
83
84 return CircuitStepArgument::Expr(CircuitNode::Slot(left));
85 }
86 }
87
88 let right = convert_step(*right, original_steps, result, node_to_step);
89
90 node_to_step.register(original_step, result.len());
91 result.push(CircuitStep::Add(left, right));
92 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
93 }
94 ArithCircuitStep::Mul(left, right) => {
95 let left = convert_step(*left, original_steps, result, node_to_step);
96 let right = convert_step(*right, original_steps, result, node_to_step);
97
98 node_to_step.register(original_step, result.len());
99 result.push(CircuitStep::Mul(left, right));
100 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
101 }
102 ArithCircuitStep::Pow(base, exp) => {
103 let mut acc = convert_step(*base, original_steps, result, node_to_step);
104 let base_expr = acc;
105 let highest_bit = exp.ilog2();
106
107 for i in (0..highest_bit).rev() {
108 if i == 0 {
109 node_to_step.register(original_step, result.len());
110 }
111 result.push(CircuitStep::Square(acc));
112 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
113
114 if (exp >> i) & 1 != 0 {
115 result.push(CircuitStep::Mul(acc, base_expr));
116 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
117 }
118 }
119
120 acc
121 }
122 }
123 }
124
125 let mut steps = Vec::new();
126 let mut steps_mapping = StepsMapping::new(expr.steps().len());
127 let ret = convert_step(expr.steps().len() - 1, expr.steps(), &mut steps, &mut steps_mapping);
128 (steps, ret)
129}
130
/// A reference to a value available while evaluating a compiled circuit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum CircuitNode {
    /// The input variable with the given index.
    Var(usize),
    /// The evaluation slot with the given index (the evaluators allocate one slot per step).
    Slot(usize),
}
139
140impl CircuitNode {
141 fn get_sparse_chunk<'a, P: PackedField>(
144 &'a self,
145 inputs: &'a RowsBatchRef<'a, P>,
146 evals: &'a [P],
147 row_len: usize,
148 ) -> &'a [P] {
149 match self {
150 Self::Var(index) => inputs.rows()[*index],
151 Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
152 }
153 }
154}
155
156#[derive(Debug, Clone, Copy, PartialEq, Eq)]
157enum CircuitStepArgument<F> {
158 Expr(CircuitNode),
159 Const(F),
160}
161
162#[derive(Debug)]
167enum CircuitStep<F: Field> {
168 Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
169 Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
170 Square(CircuitStepArgument<F>),
171 AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
172}
173
174#[derive(Debug, Clone)]
181pub struct ArithCircuitPoly<F: Field> {
182 expr: ArithCircuit<F>,
183 steps: Arc<[CircuitStep<F>]>,
184 retval: CircuitStepArgument<F>,
186 degree: usize,
187 n_vars: usize,
188 tower_level: usize,
189}
190
191impl<F: Field> PartialEq for ArithCircuitPoly<F> {
192 fn eq(&self, other: &Self) -> bool {
193 self.n_vars == other.n_vars && self.expr == other.expr
194 }
195}
196
197impl<F: Field> Eq for ArithCircuitPoly<F> {}
198
199impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
200 fn serialize(
201 &self,
202 mut write_buf: impl bytes::BufMut,
203 mode: SerializationMode,
204 ) -> Result<(), SerializationError> {
205 (&self.expr, self.n_vars).serialize(&mut write_buf, mode)
206 }
207}
208
209impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
210 fn deserialize(
211 read_buf: impl bytes::Buf,
212 mode: SerializationMode,
213 ) -> Result<Self, SerializationError>
214 where
215 Self: Sized,
216 {
217 let (expr, n_vars) = <(ArithCircuit<F>, usize)>::deserialize(read_buf, mode)?;
218 Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
219 name: "ArithCircuitPoly",
220 })
221 }
222}
223
224impl<F: TowerField> ArithCircuitPoly<F> {
225 pub fn new(mut expr: ArithCircuit<F>) -> Self {
226 expr.optimize_in_place();
227
228 let degree = expr.degree();
229 let n_vars = expr.n_vars();
230 let tower_level = expr.binary_tower_level();
231 let (exprs, retval) = convert_circuit_steps(&expr);
232
233 Self {
234 expr,
235 steps: exprs.into(),
236 retval,
237 degree,
238 n_vars,
239 tower_level,
240 }
241 }
242 pub fn with_n_vars(n_vars: usize, mut expr: ArithCircuit<F>) -> Result<Self, Error> {
247 expr.optimize_in_place();
248
249 let degree = expr.degree();
250 let tower_level = expr.binary_tower_level();
251 if n_vars < expr.n_vars() {
252 return Err(Error::IncorrectNumberOfVariables {
253 expected: expr.n_vars(),
254 actual: n_vars,
255 });
256 }
257 let (steps, retval) = convert_circuit_steps(&expr);
258
259 Ok(Self {
260 expr,
261 steps: steps.into(),
262 retval,
263 n_vars,
264 degree,
265 tower_level,
266 })
267 }
268}
269
270impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
271 for ArithCircuitPoly<F>
272{
273 fn degree(&self) -> usize {
274 self.degree
275 }
276
277 fn n_vars(&self) -> usize {
278 self.n_vars
279 }
280
281 fn binary_tower_level(&self) -> usize {
282 self.tower_level
283 }
284
285 fn expression(&self) -> ArithCircuit<P::Scalar> {
286 self.expr.convert_field()
287 }
288
289 fn evaluate(&self, query: &[P]) -> Result<P, Error> {
290 if query.len() != self.n_vars {
291 return Err(Error::IncorrectQuerySize {
292 expected: self.n_vars,
293 });
294 }
295
296 fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
297 unsafe {
300 target.get_unchecked_mut(0).write(value);
301 }
302 }
303
304 stackalloc_uninit::<P, _, _>(self.steps.len().max(1), |evals| {
306 let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
307 CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
309 *query.get_unchecked(index)
310 },
311 CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
313 *evals.get_unchecked(slot)
314 },
315 CircuitStepArgument::Const(value) => P::broadcast(value.into()),
316 };
317
318 for (i, expr) in self.steps.iter().enumerate() {
319 let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
321 let before = unsafe { slice_assume_init_mut(before) };
322 match expr {
323 CircuitStep::Add(x, y) => write_result(
324 after,
325 get_argument_value(*x, before) + get_argument_value(*y, before),
326 ),
327 CircuitStep::AddMul(target_slot, x, y) => {
328 let intermediate =
329 get_argument_value(*x, before) * get_argument_value(*y, before);
330 let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
332 *target_slot += intermediate;
333 }
334 CircuitStep::Mul(x, y) => write_result(
335 after,
336 get_argument_value(*x, before) * get_argument_value(*y, before),
337 ),
338 CircuitStep::Square(x) => {
339 write_result(after, get_argument_value(*x, before).square())
340 }
341 };
342 }
343
344 unsafe {
347 let evals = slice_assume_init(evals);
348 Ok(get_argument_value(self.retval, evals))
349 }
350 })
351 }
352
353 fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
354 let row_len = evals.len();
355 if batch_query.row_len() != row_len {
356 bail!(Error::BatchEvaluateSizeMismatch {
357 expected: row_len,
358 actual: batch_query.row_len(),
359 });
360 }
361
362 stackalloc_uninit::<P, (), _>((self.steps.len() * row_len).max(1), |sparse_evals| {
364 for (i, expr) in self.steps.iter().enumerate() {
365 let (before, current) = sparse_evals.split_at_mut(i * row_len);
366
367 let before = unsafe { slice_assume_init_mut(before) };
369 let current = &mut current[..row_len];
370
371 match expr {
372 CircuitStep::Add(left, right) => {
373 apply_binary_op(
374 left,
375 right,
376 batch_query,
377 before,
378 current,
379 |left, right, out| {
380 out.write(left + right);
381 },
382 );
383 }
384 CircuitStep::Mul(left, right) => {
385 apply_binary_op(
386 left,
387 right,
388 batch_query,
389 before,
390 current,
391 |left, right, out| {
392 out.write(left * right);
393 },
394 );
395 }
396 CircuitStep::Square(arg) => {
397 match arg {
398 CircuitStepArgument::Expr(node) => {
399 let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
400 for j in 0..row_len {
401 unsafe {
403 current
404 .get_unchecked_mut(j)
405 .write(id_chunk.get_unchecked(j).square());
406 }
407 }
408 }
409 CircuitStepArgument::Const(value) => {
410 let value: P = P::broadcast((*value).into());
411 let result = value.square();
412 for j in 0..row_len {
413 unsafe {
415 current.get_unchecked_mut(j).write(result);
416 }
417 }
418 }
419 }
420 }
421 CircuitStep::AddMul(target, left, right) => {
422 let target = &before[row_len * target..(target + 1) * row_len];
423 let target: &mut [MaybeUninit<P>] = unsafe {
426 std::slice::from_raw_parts_mut(
427 target.as_ptr() as *mut MaybeUninit<P>,
428 target.len(),
429 )
430 };
431 apply_binary_op(
432 left,
433 right,
434 batch_query,
435 before,
436 target,
437 |left, right, out| unsafe {
440 let out = out.assume_init_mut();
441 *out += left * right;
442 },
443 );
444 }
445 }
446 }
447
448 match self.retval {
449 CircuitStepArgument::Expr(node) => {
450 let sparse_evals = unsafe { slice_assume_init(sparse_evals) };
452 evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
453 }
454 CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
455 }
456 });
457
458 Ok(())
459 }
460}
461
462fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
465 left: &CircuitStepArgument<F>,
466 right: &CircuitStepArgument<F>,
467 batch_query: &RowsBatchRef<P>,
468 evals_before: &[P],
469 current_evals: &mut [MaybeUninit<P>],
470 op: impl Fn(P, P, &mut MaybeUninit<P>),
471) {
472 let row_len = current_evals.len();
473
474 match (left, right) {
475 (CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
476 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
477 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
478 for j in 0..row_len {
479 unsafe {
481 op(
482 *left.get_unchecked(j),
483 *right.get_unchecked(j),
484 current_evals.get_unchecked_mut(j),
485 )
486 }
487 }
488 }
489 (CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
490 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
491 let right = P::broadcast((*right).into());
492 for j in 0..row_len {
493 unsafe {
495 op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
496 }
497 }
498 }
499 (CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
500 let left = P::broadcast((*left).into());
501 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
502 for j in 0..row_len {
503 unsafe {
505 op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
506 }
507 }
508 }
509 (CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
510 let left = P::broadcast((*left).into());
511 let right = P::broadcast((*right).into());
512 let mut result = MaybeUninit::uninit();
513 op(left, right, &mut result);
514 for j in 0..row_len {
515 unsafe {
519 current_evals
520 .get_unchecked_mut(j)
521 .write(result.assume_init());
522 }
523 }
524 }
525 }
526}
527
// Tests for `ArithCircuitPoly` evaluation (single and batch) and for the exact step layout
// produced by `convert_circuit_steps`.
#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField,
    };
    use binius_math::{ArithExpr, CompositionPoly, RowsBatch};
    use binius_utils::felts;

    use super::*;

    // A constant circuit has no variables and evaluates to the broadcast constant.
    #[test]
    fn test_constant() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Const(F::new(123));
        let circuit = ArithCircuitPoly::<F>::new(expr.into());

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
        assert_eq!(typed_circuit.degree(), 0);
        assert_eq!(typed_circuit.n_vars(), 0);

        assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));

        let mut evals = [P::default()];
        typed_circuit
            .batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
            .unwrap();
        assert_eq!(evals, [P::broadcast(F::new(123).into())]);
    }

    // A single variable evaluates to the query value itself.
    #[test]
    fn test_identity() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Var(0);
        let circuit = ArithCircuitPoly::<F>::new(expr.into());

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), 0);
        assert_eq!(typed_circuit.degree(), 1);
        assert_eq!(typed_circuit.n_vars(), 1);

        assert_eq!(
            typed_circuit
                .evaluate(&[P::broadcast(F::new(123).into())])
                .unwrap(),
            P::broadcast(F::new(123).into())
        );

        let mut evals = [P::default()];
        let batch_query = [[P::broadcast(F::new(123).into())]; 1];
        let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
        typed_circuit
            .batch_evaluate(&batch_query.get_ref(), &mut evals)
            .unwrap();
        assert_eq!(evals, [P::broadcast(F::new(123).into())]);
    }

    // Constant plus variable.
    #[test]
    fn test_add() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
        let circuit = ArithCircuitPoly::<F>::new(expr.into());

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), 3);
        assert_eq!(typed_circuit.degree(), 1);
        assert_eq!(typed_circuit.n_vars(), 1);

        assert_eq!(
            CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
            P::broadcast(F::new(123).into())
        );
    }

    // Constant times variable, checked lane by lane against precomputed products.
    #[test]
    fn test_mul() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
        let circuit = ArithCircuitPoly::<F>::new(expr.into());

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), 3);
        assert_eq!(typed_circuit.degree(), 1);
        assert_eq!(typed_circuit.n_vars(), 1);

        assert_eq!(
            CompositionPoly::evaluate(
                &circuit,
                &[P::from_scalars(
                    felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
                )]
            )
            .unwrap(),
            P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
        );
    }

    // x^13 exercises both the square and the multiply branches of the power expansion.
    #[test]
    fn test_pow() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Var(0).pow(13);
        let circuit = ArithCircuitPoly::<F>::new(expr.into());

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), 0);
        assert_eq!(typed_circuit.degree(), 13);
        assert_eq!(typed_circuit.n_vars(), 1);

        assert_eq!(
            CompositionPoly::evaluate(
                &circuit,
                &[P::from_scalars(
                    felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
                )]
            )
            .unwrap(),
            P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
        );
    }

    // x0^2 * (x1 + 123): single and batch evaluation must agree.
    #[test]
    fn test_mixed() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
        let circuit = ArithCircuitPoly::<F>::new(expr.into());

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), 3);
        assert_eq!(typed_circuit.degree(), 3);
        assert_eq!(typed_circuit.n_vars(), 2);

        assert_eq!(
            CompositionPoly::evaluate(
                &circuit,
                &[
                    P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
                    P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
                ]
            )
            .unwrap(),
            P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
        );

        let query1 = &[
            P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
            P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
        ];
        let query2 = &[
            P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
            P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
        ];
        let query3 = &[
            P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
            P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
        ];
        let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
        let expected2 =
            P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
        let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));

        // Rows are variables, columns are the three query points.
        let mut batch_result = vec![P::zero(); 3];
        let batch_query = &[
            &[query1[0], query2[0], query3[0]],
            &[query1[1], query2[1], query3[1]],
        ];
        let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);

        CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
            .unwrap();
        assert_eq!(&batch_result, &[expected1, expected2, expected3]);
    }

    // Constant sub-expressions are folded by optimization, leaving two compiled steps.
    #[test]
    fn test_const_fold() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Var(0)
            * ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
                + (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
            + ArithExpr::Var(1);
        let circuit = ArithCircuitPoly::<F>::new(expr.into());
        assert_eq!(circuit.steps.len(), 2);

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
        assert_eq!(typed_circuit.degree(), 1);
        assert_eq!(typed_circuit.n_vars(), 2);

        assert_eq!(
            CompositionPoly::evaluate(
                &circuit,
                &[
                    P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
                    P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
                ]
            )
            .unwrap(),
            P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
        );

        let query1 = &[
            P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
            P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
        ];
        let query2 = &[
            P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
            P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
        ];
        let query3 = &[
            P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
            P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
        ];
        let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
        let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
        let expected3 =
            P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));

        let mut batch_result = vec![P::zero(); 3];
        let batch_query = &[
            &[query1[0], query2[0], query3[0]],
            &[query1[1], query2[1], query3[1]],
        ];
        let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
        CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
            .unwrap();
        assert_eq!(&batch_result, &[expected1, expected2, expected3]);
    }

    // A power of a constant is folded, leaving a single addition step.
    #[test]
    fn test_pow_const_fold() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
        let circuit = ArithCircuitPoly::<F>::new(expr.into());
        assert_eq!(circuit.steps.len(), 1);

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), 1);
        assert_eq!(typed_circuit.degree(), 1);
        assert_eq!(typed_circuit.n_vars(), 1);

        assert_eq!(
            CompositionPoly::evaluate(
                &circuit,
                &[P::from_scalars(
                    felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
                )]
            )
            .unwrap(),
            P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
        );
    }

    // ((x^2)^3)^4 has degree 24 and is expected to compile to 5 steps.
    #[test]
    fn test_pow_nested() {
        type F = BinaryField8b;
        type P = PackedBinaryField8x16b;

        let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
        let circuit = ArithCircuitPoly::<F>::new(expr.into());
        assert_eq!(circuit.steps.len(), 5);

        let typed_circuit: &dyn CompositionPoly<P> = &circuit;
        assert_eq!(typed_circuit.binary_tower_level(), 0);
        assert_eq!(typed_circuit.degree(), 24);
        assert_eq!(typed_circuit.n_vars(), 1);

        assert_eq!(
            CompositionPoly::evaluate(
                &circuit,
                &[P::from_scalars(
                    felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
                )]
            )
            .unwrap(),
            P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
        );
    }

    // The remaining tests pin the exact step layout of `convert_circuit_steps` on
    // unoptimized circuits.

    #[test]
    fn test_circuit_steps_for_expr_constant() {
        type F = BinaryField8b;

        let expr = ArithExpr::Const(F::new(5));
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert!(steps.is_empty(), "No steps should be generated for a constant");
        assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
    }

    #[test]
    fn test_circuit_steps_for_expr_variable() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(18);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert!(steps.is_empty(), "No steps should be generated for a variable");
        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
    }

    #[test]
    fn test_circuit_steps_for_expr_addition() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 1, "One addition step should be generated");
        assert!(matches!(
            steps[0],
            CircuitStep::Add(
                CircuitStepArgument::Expr(CircuitNode::Var(14)),
                CircuitStepArgument::Expr(CircuitNode::Var(56))
            )
        ));
        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
    }

    #[test]
    fn test_circuit_steps_for_expr_multiplication() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 1, "One multiplication step should be generated");
        assert!(matches!(
            steps[0],
            CircuitStep::Mul(
                CircuitStepArgument::Expr(CircuitNode::Var(36)),
                CircuitStepArgument::Expr(CircuitNode::Var(26))
            )
        ));
        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
    }

    // Powers use square-and-multiply over the bits below the leading one.
    #[test]
    fn test_circuit_steps_for_expr_pow_1() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(12).pow(1);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");

        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
    }

    #[test]
    fn test_circuit_steps_for_expr_pow_2() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(10).pow(2);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
        ));
        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
    }

    #[test]
    fn test_circuit_steps_for_expr_pow_3() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(5).pow(3);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(
            steps.len(),
            2,
            "Pow(3) should generate one squaring and one multiplication step"
        );
        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
        ));
        assert!(matches!(
            steps[1],
            CircuitStep::Mul(
                CircuitStepArgument::Expr(CircuitNode::Slot(0)),
                CircuitStepArgument::Expr(CircuitNode::Var(5))
            )
        ));
        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
    }

    #[test]
    fn test_circuit_steps_for_expr_pow_4() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(7).pow(4);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
        ));

        assert!(matches!(
            steps[1],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
        ));

        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
    }

    #[test]
    fn test_circuit_steps_for_expr_pow_5() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(3).pow(5);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(
            steps.len(),
            3,
            "Pow(5) should generate two squaring steps and one multiplication"
        );
        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
        ));
        assert!(matches!(
            steps[1],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
        ));
        assert!(matches!(
            steps[2],
            CircuitStep::Mul(
                CircuitStepArgument::Expr(CircuitNode::Slot(1)),
                CircuitStepArgument::Expr(CircuitNode::Var(3))
            )
        ));

        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
    }

    #[test]
    fn test_circuit_steps_for_expr_pow_8() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(4).pow(8);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
        ));
        assert!(matches!(
            steps[1],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
        ));
        assert!(matches!(
            steps[2],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
        ));

        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
    }

    #[test]
    fn test_circuit_steps_for_expr_pow_9() {
        type F = BinaryField8b;

        let expr = ArithExpr::<F>::Var(8).pow(9);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(
            steps.len(),
            4,
            "Pow(9) should generate three squaring steps and one multiplication"
        );
        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
        ));
        assert!(matches!(
            steps[1],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
        ));
        assert!(matches!(
            steps[2],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
        ));
        assert!(matches!(
            steps[3],
            CircuitStep::Mul(
                CircuitStepArgument::Expr(CircuitNode::Slot(2)),
                CircuitStepArgument::Expr(CircuitNode::Var(8))
            )
        ));

        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
    }

    // 12 = 0b1100: square, multiply, square, square.
    #[test]
    fn test_circuit_steps_for_expr_pow_12() {
        type F = BinaryField8b;
        let expr = ArithExpr::<F>::Var(6).pow(12);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");

        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
        ));
        assert!(matches!(
            steps[1],
            CircuitStep::Mul(
                CircuitStepArgument::Expr(CircuitNode::Slot(0)),
                CircuitStepArgument::Expr(CircuitNode::Var(6))
            )
        ));
        assert!(matches!(
            steps[2],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
        ));
        assert!(matches!(
            steps[3],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
        ));

        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
    }

    // 13 = 0b1101: square, multiply, square, square, multiply.
    #[test]
    fn test_circuit_steps_for_expr_pow_13() {
        type F = BinaryField8b;
        let expr = ArithExpr::<F>::Var(7).pow(13);
        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
        assert!(matches!(
            steps[0],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
        ));
        assert!(matches!(
            steps[1],
            CircuitStep::Mul(
                CircuitStepArgument::Expr(CircuitNode::Slot(0)),
                CircuitStepArgument::Expr(CircuitNode::Var(7))
            )
        ));
        assert!(matches!(
            steps[2],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
        ));
        assert!(matches!(
            steps[3],
            CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
        ));
        assert!(matches!(
            steps[4],
            CircuitStep::Mul(
                CircuitStepArgument::Expr(CircuitNode::Slot(3)),
                CircuitStepArgument::Expr(CircuitNode::Var(7))
            )
        ));
        assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
    }

    // `a + b * c` with `a` in a slot is fused into an in-place `AddMul`. Subtraction and
    // addition coincide in characteristic 2, hence the `Add` steps.
    #[test]
    fn test_circuit_steps_for_expr_complex() {
        type F = BinaryField8b;

        let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
            + (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
            - ArithExpr::Var(3);

        let (steps, retval) = convert_circuit_steps(&expr.into());

        assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");

        assert!(
            matches!(
                steps[0],
                CircuitStep::Mul(
                    CircuitStepArgument::Expr(CircuitNode::Var(0)),
                    CircuitStepArgument::Expr(CircuitNode::Var(1))
                )
            ),
            "First step should be multiplication x0 * x1"
        );

        assert!(
            matches!(
                steps[1],
                CircuitStep::Add(
                    CircuitStepArgument::Const(F::ONE),
                    CircuitStepArgument::Expr(CircuitNode::Var(0))
                )
            ),
            "Second step should be (1 - x0)"
        );

        assert!(
            matches!(
                steps[2],
                CircuitStep::AddMul(
                    0,
                    CircuitStepArgument::Expr(CircuitNode::Slot(1)),
                    CircuitStepArgument::Expr(CircuitNode::Var(2))
                )
            ),
            "Third step should be (1 - x0) * x2"
        );

        assert!(
            matches!(
                steps[3],
                CircuitStep::Add(
                    CircuitStepArgument::Expr(CircuitNode::Slot(0)),
                    CircuitStepArgument::Expr(CircuitNode::Var(3))
                )
            ),
            "Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
        );

        assert!(
            matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
            "Final result should be stored in Slot(3)"
        );
    }

    // After `optimize`, the repeated `x0 * x1` is computed once and referenced by slot.
    #[test]
    fn check_deduplication_in_steps() {
        type F = BinaryField8b;

        let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
            + (ArithExpr::<F>::Var(0) * ArithExpr::Var(1)) * ArithExpr::Var(2)
            - ArithExpr::Var(3);
        let expr = ArithCircuit::from(&expr);
        let expr = expr.optimize();

        let (steps, retval) = convert_circuit_steps(&expr);

        assert_eq!(steps.len(), 3, "Expression should generate 3 computation steps");

        assert!(
            matches!(
                steps[0],
                CircuitStep::Mul(
                    CircuitStepArgument::Expr(CircuitNode::Var(0)),
                    CircuitStepArgument::Expr(CircuitNode::Var(1))
                )
            ),
            "First step should be multiplication x0 * x1"
        );

        assert!(
            matches!(
                steps[1],
                CircuitStep::AddMul(
                    0,
                    CircuitStepArgument::Expr(CircuitNode::Slot(0)),
                    CircuitStepArgument::Expr(CircuitNode::Var(2))
                )
            ),
            "Second step should be (x0 * x1) * x2"
        );

        assert!(
            matches!(
                steps[2],
                CircuitStep::Add(
                    CircuitStepArgument::Expr(CircuitNode::Slot(0)),
                    CircuitStepArgument::Expr(CircuitNode::Var(3))
                )
            ),
            "Third step should be x0 * x1 + (x0 * x1) * x2 + x3"
        );

        assert!(
            matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))),
            "Final result should be stored in Slot(2)"
        );
    }
}