1use std::{fmt::Debug, mem::MaybeUninit, sync::Arc};
4
5use binius_field::{ExtensionField, Field, PackedField, TowerField};
6use binius_math::{ArithExpr, CompositionPoly, Error, RowsBatchRef};
7use binius_utils::{bail, DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
8use stackalloc::{
9 helpers::{slice_assume_init, slice_assume_init_mut},
10 stackalloc_uninit,
11};
12
13fn circuit_steps_for_expr<F: Field>(
15 expr: &ArithExpr<F>,
16) -> (Vec<CircuitStep<F>>, CircuitStepArgument<F>) {
17 let mut steps = Vec::new();
18
19 fn to_circuit_inner<F: Field>(
20 expr: &ArithExpr<F>,
21 result: &mut Vec<CircuitStep<F>>,
22 ) -> CircuitStepArgument<F> {
23 match expr {
24 ArithExpr::Const(value) => CircuitStepArgument::Const(*value),
25 ArithExpr::Var(index) => CircuitStepArgument::Expr(CircuitNode::Var(*index)),
26 ArithExpr::Add(left, right) => match &**right {
27 ArithExpr::Mul(mleft, mright) if left.is_composite() => {
28 let left = to_circuit_inner(left, result);
31 let CircuitStepArgument::Expr(CircuitNode::Slot(left)) = left else {
32 unreachable!("guaranteed by `is_composite` check above")
33 };
34 let mleft = to_circuit_inner(mleft, result);
35 let mright = to_circuit_inner(mright, result);
36 result.push(CircuitStep::AddMul(left, mleft, mright));
37 CircuitStepArgument::Expr(CircuitNode::Slot(left))
38 }
39 _ => {
40 let left = to_circuit_inner(left, result);
41 let right = to_circuit_inner(right, result);
42 result.push(CircuitStep::Add(left, right));
43 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
44 }
45 },
46 ArithExpr::Mul(left, right) => {
47 let left = to_circuit_inner(left, result);
48 let right = to_circuit_inner(right, result);
49 result.push(CircuitStep::Mul(left, right));
50 CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1))
51 }
52 ArithExpr::Pow(base, exp) => {
53 let mut acc = to_circuit_inner(base, result);
54 let base_expr = acc;
55 let highest_bit = exp.ilog2();
56
57 for i in (0..highest_bit).rev() {
58 result.push(CircuitStep::Square(acc));
59 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
60
61 if (exp >> i) & 1 != 0 {
62 result.push(CircuitStep::Mul(acc, base_expr));
63 acc = CircuitStepArgument::Expr(CircuitNode::Slot(result.len() - 1));
64 }
65 }
66
67 acc
68 }
69 }
70 }
71
72 let ret = to_circuit_inner(&expr.optimize(), &mut steps);
73 (steps, ret)
74}
75
/// A reference to a value available while a circuit is being evaluated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitNode {
	/// An input variable, identified by its index into the query.
	Var(usize),
	/// The result of an earlier circuit step, identified by that step's index.
	Slot(usize),
}
84
85impl CircuitNode {
86 fn get_sparse_chunk<'a, P: PackedField>(
89 &'a self,
90 inputs: &'a RowsBatchRef<'a, P>,
91 evals: &'a [P],
92 row_len: usize,
93 ) -> &'a [P] {
94 match self {
95 Self::Var(index) => inputs.rows()[*index],
96 Self::Slot(slot) => &evals[slot * row_len..(slot + 1) * row_len],
97 }
98 }
99}
100
/// An operand of a circuit step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CircuitStepArgument<F> {
	/// A value read from an input variable or from the slot of an earlier step.
	Expr(CircuitNode),
	/// A field constant; it is broadcast to every lane when evaluated over packed fields.
	Const(F),
}
106
/// One instruction of a compiled arithmetic circuit.
///
/// Steps are executed in order. The step at position `i` of the step list is associated with
/// evaluation slot `i`. `Add`, `Mul` and `Square` store their result in that slot. `AddMul`
/// adds its product into the earlier slot named by its first field.
#[derive(Debug)]
enum CircuitStep<F: Field> {
	/// `x + y`.
	Add(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// `x * y`.
	Mul(CircuitStepArgument<F>, CircuitStepArgument<F>),
	/// `x * x`.
	Square(CircuitStepArgument<F>),
	/// `slot[target] += x * y`, where `target` (the first field) is the index of an earlier step.
	AddMul(usize, CircuitStepArgument<F>, CircuitStepArgument<F>),
}
118
/// A composition polynomial backed by an [`ArithExpr`] compiled into a flat list of circuit steps.
///
/// The expression is kept as given for equality, serialization and `expression()`. The compiled
/// `steps` (built from the optimized expression) are what the evaluation methods execute.
#[derive(Debug, Clone)]
pub struct ArithCircuitPoly<F: Field> {
	/// The expression exactly as passed to the constructor (not the optimized form).
	expr: ArithExpr<F>,
	/// Compiled steps; behind an `Arc` so cloning the polynomial does not copy them.
	steps: Arc<[CircuitStep<F>]>,
	/// Where the final result is found once all steps have run.
	retval: CircuitStepArgument<F>,
	/// Cached `expr.degree()`.
	degree: usize,
	/// Number of variables the polynomial accepts; may exceed the number used by `expr`.
	n_vars: usize,
	/// Cached `expr.binary_tower_level()`.
	tower_level: usize,
}
135
136impl<F: Field> PartialEq for ArithCircuitPoly<F> {
137 fn eq(&self, other: &Self) -> bool {
138 self.n_vars == other.n_vars && self.expr == other.expr
139 }
140}
141
142impl<F: Field> Eq for ArithCircuitPoly<F> {}
143
144impl<F: TowerField> SerializeBytes for ArithCircuitPoly<F> {
145 fn serialize(
146 &self,
147 mut write_buf: impl bytes::BufMut,
148 mode: SerializationMode,
149 ) -> Result<(), SerializationError> {
150 (&self.expr, self.n_vars).serialize(&mut write_buf, mode)
151 }
152}
153
154impl<F: TowerField> DeserializeBytes for ArithCircuitPoly<F> {
155 fn deserialize(
156 read_buf: impl bytes::Buf,
157 mode: SerializationMode,
158 ) -> Result<Self, SerializationError>
159 where
160 Self: Sized,
161 {
162 let (expr, n_vars) = <(ArithExpr<F>, usize)>::deserialize(read_buf, mode)?;
163 Self::with_n_vars(n_vars, expr).map_err(|_| SerializationError::InvalidConstruction {
164 name: "ArithCircuitPoly",
165 })
166 }
167}
168
169impl<F: TowerField> ArithCircuitPoly<F> {
170 pub fn new(expr: ArithExpr<F>) -> Self {
171 let degree = expr.degree();
172 let n_vars = expr.n_vars();
173 let tower_level = expr.binary_tower_level();
174 let (exprs, retval) = circuit_steps_for_expr(&expr);
175
176 Self {
177 expr,
178 steps: exprs.into(),
179 retval,
180 degree,
181 n_vars,
182 tower_level,
183 }
184 }
185
186 pub fn with_n_vars(n_vars: usize, expr: ArithExpr<F>) -> Result<Self, Error> {
191 let degree = expr.degree();
192 let tower_level = expr.binary_tower_level();
193 if n_vars < expr.n_vars() {
194 return Err(Error::IncorrectNumberOfVariables {
195 expected: expr.n_vars(),
196 actual: n_vars,
197 });
198 }
199 let (exprs, retval) = circuit_steps_for_expr(&expr);
200
201 Ok(Self {
202 expr,
203 steps: exprs.into(),
204 retval,
205 n_vars,
206 degree,
207 tower_level,
208 })
209 }
210}
211
212impl<F: TowerField, P: PackedField<Scalar: ExtensionField<F>>> CompositionPoly<P>
213 for ArithCircuitPoly<F>
214{
215 fn degree(&self) -> usize {
216 self.degree
217 }
218
219 fn n_vars(&self) -> usize {
220 self.n_vars
221 }
222
223 fn binary_tower_level(&self) -> usize {
224 self.tower_level
225 }
226
227 fn expression(&self) -> ArithExpr<P::Scalar> {
228 self.expr.convert_field()
229 }
230
231 fn evaluate(&self, query: &[P]) -> Result<P, Error> {
232 if query.len() != self.n_vars {
233 return Err(Error::IncorrectQuerySize {
234 expected: self.n_vars,
235 });
236 }
237
238 fn write_result<T>(target: &mut [MaybeUninit<T>], value: T) {
239 unsafe {
242 target.get_unchecked_mut(0).write(value);
243 }
244 }
245
246 stackalloc_uninit::<P, _, _>(self.steps.len().max(1), |evals| {
248 let get_argument_value = |input: CircuitStepArgument<F>, evals: &[P]| match input {
249 CircuitStepArgument::Expr(CircuitNode::Var(index)) => unsafe {
251 *query.get_unchecked(index)
252 },
253 CircuitStepArgument::Expr(CircuitNode::Slot(slot)) => unsafe {
255 *evals.get_unchecked(slot)
256 },
257 CircuitStepArgument::Const(value) => P::broadcast(value.into()),
258 };
259
260 for (i, expr) in self.steps.iter().enumerate() {
261 let (before, after) = unsafe { evals.split_at_mut_unchecked(i) };
263 let before = unsafe { slice_assume_init_mut(before) };
264 match expr {
265 CircuitStep::Add(x, y) => write_result(
266 after,
267 get_argument_value(*x, before) + get_argument_value(*y, before),
268 ),
269 CircuitStep::AddMul(target_slot, x, y) => {
270 let intermediate =
271 get_argument_value(*x, before) * get_argument_value(*y, before);
272 let target_slot = unsafe { before.get_unchecked_mut(*target_slot) };
274 *target_slot += intermediate;
275 }
276 CircuitStep::Mul(x, y) => write_result(
277 after,
278 get_argument_value(*x, before) * get_argument_value(*y, before),
279 ),
280 CircuitStep::Square(x) => {
281 write_result(after, get_argument_value(*x, before).square())
282 }
283 };
284 }
285
286 unsafe {
289 let evals = slice_assume_init(evals);
290 Ok(get_argument_value(self.retval, evals))
291 }
292 })
293 }
294
295 fn batch_evaluate(&self, batch_query: &RowsBatchRef<P>, evals: &mut [P]) -> Result<(), Error> {
296 let row_len = evals.len();
297 if batch_query.row_len() != row_len {
298 bail!(Error::BatchEvaluateSizeMismatch {
299 expected: row_len,
300 actual: batch_query.row_len(),
301 });
302 }
303
304 stackalloc_uninit::<P, (), _>((self.steps.len() * row_len).max(1), |sparse_evals| {
306 for (i, expr) in self.steps.iter().enumerate() {
307 let (before, current) = sparse_evals.split_at_mut(i * row_len);
308
309 let before = unsafe { slice_assume_init_mut(before) };
311 let current = &mut current[..row_len];
312
313 match expr {
314 CircuitStep::Add(left, right) => {
315 apply_binary_op(
316 left,
317 right,
318 batch_query,
319 before,
320 current,
321 |left, right, out| {
322 out.write(left + right);
323 },
324 );
325 }
326 CircuitStep::Mul(left, right) => {
327 apply_binary_op(
328 left,
329 right,
330 batch_query,
331 before,
332 current,
333 |left, right, out| {
334 out.write(left * right);
335 },
336 );
337 }
338 CircuitStep::Square(arg) => {
339 match arg {
340 CircuitStepArgument::Expr(node) => {
341 let id_chunk = node.get_sparse_chunk(batch_query, before, row_len);
342 for j in 0..row_len {
343 unsafe {
345 current
346 .get_unchecked_mut(j)
347 .write(id_chunk.get_unchecked(j).square());
348 }
349 }
350 }
351 CircuitStepArgument::Const(value) => {
352 let value: P = P::broadcast((*value).into());
353 let result = value.square();
354 for j in 0..row_len {
355 unsafe {
357 current.get_unchecked_mut(j).write(result);
358 }
359 }
360 }
361 }
362 }
363 CircuitStep::AddMul(target, left, right) => {
364 let target = &before[row_len * target..(target + 1) * row_len];
365 let target: &mut [MaybeUninit<P>] = unsafe {
368 std::slice::from_raw_parts_mut(
369 target.as_ptr() as *mut MaybeUninit<P>,
370 target.len(),
371 )
372 };
373 apply_binary_op(
374 left,
375 right,
376 batch_query,
377 before,
378 target,
379 |left, right, out| unsafe {
382 let out = out.assume_init_mut();
383 *out += left * right;
384 },
385 );
386 }
387 }
388 }
389
390 match self.retval {
391 CircuitStepArgument::Expr(node) => {
392 let sparse_evals = unsafe { slice_assume_init(sparse_evals) };
394 evals.copy_from_slice(node.get_sparse_chunk(batch_query, sparse_evals, row_len))
395 }
396 CircuitStepArgument::Const(val) => evals.fill(P::broadcast(val.into())),
397 }
398 });
399
400 Ok(())
401 }
402}
403
404fn apply_binary_op<F: Field, P: PackedField<Scalar: ExtensionField<F>>>(
407 left: &CircuitStepArgument<F>,
408 right: &CircuitStepArgument<F>,
409 batch_query: &RowsBatchRef<P>,
410 evals_before: &[P],
411 current_evals: &mut [MaybeUninit<P>],
412 op: impl Fn(P, P, &mut MaybeUninit<P>),
413) {
414 let row_len = current_evals.len();
415
416 match (left, right) {
417 (CircuitStepArgument::Expr(left), CircuitStepArgument::Expr(right)) => {
418 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
419 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
420 for j in 0..row_len {
421 unsafe {
423 op(
424 *left.get_unchecked(j),
425 *right.get_unchecked(j),
426 current_evals.get_unchecked_mut(j),
427 )
428 }
429 }
430 }
431 (CircuitStepArgument::Expr(left), CircuitStepArgument::Const(right)) => {
432 let left = left.get_sparse_chunk(batch_query, evals_before, row_len);
433 let right = P::broadcast((*right).into());
434 for j in 0..row_len {
435 unsafe {
437 op(*left.get_unchecked(j), right, current_evals.get_unchecked_mut(j));
438 }
439 }
440 }
441 (CircuitStepArgument::Const(left), CircuitStepArgument::Expr(right)) => {
442 let left = P::broadcast((*left).into());
443 let right = right.get_sparse_chunk(batch_query, evals_before, row_len);
444 for j in 0..row_len {
445 unsafe {
447 op(left, *right.get_unchecked(j), current_evals.get_unchecked_mut(j));
448 }
449 }
450 }
451 (CircuitStepArgument::Const(left), CircuitStepArgument::Const(right)) => {
452 let left = P::broadcast((*left).into());
453 let right = P::broadcast((*right).into());
454 let mut result = MaybeUninit::uninit();
455 op(left, right, &mut result);
456 for j in 0..row_len {
457 unsafe {
461 current_evals
462 .get_unchecked_mut(j)
463 .write(result.assume_init());
464 }
465 }
466 }
467 }
468}
469
// Unit tests for `ArithCircuitPoly`: metadata (`degree`, `n_vars`, `binary_tower_level`),
// single-query and batched evaluation over `PackedBinaryField8x16b`, and the exact step
// sequences produced by `circuit_steps_for_expr`.
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField16b, BinaryField8b, PackedBinaryField8x16b, PackedField, TowerField,
	};
	use binius_math::{CompositionPoly, RowsBatch};
	use binius_utils::felts;

	use super::*;

	// A constant has degree 0 and no variables, and evaluates to the broadcast constant both
	// for a single query and for a batch with no input rows.
	#[test]
	fn test_constant() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		let expr = ArithExpr::Const(F::new(123));
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
		assert_eq!(typed_circuit.degree(), 0);
		assert_eq!(typed_circuit.n_vars(), 0);

		assert_eq!(typed_circuit.evaluate(&[]).unwrap(), P::broadcast(F::new(123).into()));

		let mut evals = [P::default()];
		typed_circuit
			.batch_evaluate(&RowsBatchRef::new(&[], 1), &mut evals)
			.unwrap();
		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
	}

	// The identity polynomial returns its input unchanged, singly and batched.
	#[test]
	fn test_identity() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// x0
		let expr = ArithExpr::Var(0);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 0);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			typed_circuit
				.evaluate(&[P::broadcast(F::new(123).into())])
				.unwrap(),
			P::broadcast(F::new(123).into())
		);

		let mut evals = [P::default()];
		let batch_query = [[P::broadcast(F::new(123).into())]; 1];
		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 1);
		typed_circuit
			.batch_evaluate(&batch_query.get_ref(), &mut evals)
			.unwrap();
		assert_eq!(evals, [P::broadcast(F::new(123).into())]);
	}

	// Constant plus variable, evaluated at x0 = 0.
	#[test]
	fn test_add() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// 123 + x0
		let expr = ArithExpr::Const(F::new(123)) + ArithExpr::Var(0);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 3);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(&circuit, &[P::broadcast(F::new(0).into())]).unwrap(),
			P::broadcast(F::new(123).into())
		);
	}

	// Constant times variable, checked lane-by-lane.
	#[test]
	fn test_mul() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// 123 * x0
		let expr = ArithExpr::Const(F::new(123)) * ArithExpr::Var(0);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 3);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 123, 157, 230, 85, 46, 154, 225])),
		);
	}

	// A power with mixed exponent bits, checked lane-by-lane.
	#[test]
	fn test_pow() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// x0^13
		let expr = ArithExpr::Var(0).pow(13);
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 0);
		assert_eq!(typed_circuit.degree(), 13);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 200, 52, 51, 115])),
		);
	}

	// Power, sum and product combined; single-query and batched results must agree.
	#[test]
	fn test_mixed() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// x0^2 * (x1 + 123)
		let expr = ArithExpr::Var(0).pow(2) * (ArithExpr::Var(1) + ArithExpr::Const(F::new(123)));
		let circuit = ArithCircuitPoly::<F>::new(expr);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 3);
		assert_eq!(typed_circuit.degree(), 3);
		assert_eq!(typed_circuit.n_vars(), 2);

		// Single-query evaluation.
		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[
					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
				]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176])),
		);

		// Batched evaluation: row `k` of the batch holds variable `k` of every query, so each
		// query becomes one column.
		let query1 = &[
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
		];
		let query2 = &[
			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
		];
		let query3 = &[
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
		];
		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
		let expected2 =
			P::from_scalars(felts!(BinaryField16b[123, 122, 121, 120, 127, 126, 125, 124]));
		let expected3 = P::from_scalars(felts!(BinaryField16b[0, 30, 59, 36, 151, 140, 170, 176]));

		let mut batch_result = vec![P::zero(); 3];
		let batch_query = &[
			&[query1[0], query2[0], query3[0]],
			&[query1[1], query2[1], query3[1]],
		];
		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);

		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
			.unwrap();
		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
	}

	// Constant-only subexpressions are folded during compilation, so only two steps remain.
	#[test]
	fn test_const_fold() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// x0 * ((122 * 123) + (124 + 125)) + x1
		let expr = ArithExpr::Var(0)
			* ((ArithExpr::Const(F::new(122)) * ArithExpr::Const(F::new(123)))
				+ (ArithExpr::Const(F::new(124)) + ArithExpr::Const(F::new(125))))
			+ ArithExpr::Var(1);
		let circuit = ArithCircuitPoly::<F>::new(expr);
		assert_eq!(circuit.steps.len(), 2);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), F::TOWER_LEVEL);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 2);

		// Single-query evaluation.
		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[
					P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
					P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
				]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78])),
		);

		// Batched evaluation with each query as one column.
		let query1 = &[
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
			P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0])),
		];
		let query2 = &[
			P::from_scalars(felts!(BinaryField16b[1, 1, 1, 1, 1, 1, 1, 1])),
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
		];
		let query3 = &[
			P::from_scalars(felts!(BinaryField16b[0, 1, 2, 3, 4, 5, 6, 7])),
			P::from_scalars(felts!(BinaryField16b[100, 101, 102, 103, 104, 105, 106, 107])),
		];
		let expected1 = P::from_scalars(felts!(BinaryField16b[0, 0, 0, 0, 0, 0, 0, 0]));
		let expected2 = P::from_scalars(felts!(BinaryField16b[84, 85, 86, 87, 80, 81, 82, 83]));
		let expected3 =
			P::from_scalars(felts!(BinaryField16b[100, 49, 206, 155, 177, 228, 27, 78]));

		let mut batch_result = vec![P::zero(); 3];
		let batch_query = &[
			&[query1[0], query2[0], query3[0]],
			&[query1[1], query2[1], query3[1]],
		];
		let batch_query = RowsBatch::new_from_iter(batch_query.iter().map(|x| x.as_slice()), 3);
		CompositionPoly::batch_evaluate(&circuit, &batch_query.get_ref(), &mut batch_result)
			.unwrap();
		assert_eq!(&batch_result, &[expected1, expected2, expected3]);
	}

	// A power of a constant is folded, leaving a single addition step.
	#[test]
	fn test_pow_const_fold() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// x0 + 2^4
		let expr = ArithExpr::Var(0) + ArithExpr::Const(F::from(2)).pow(4);
		let circuit = ArithCircuitPoly::<F>::new(expr);
		assert_eq!(circuit.steps.len(), 1);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 1);
		assert_eq!(typed_circuit.degree(), 1);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[2, 3, 0, 1, 120, 121, 126, 127])),
		);
	}

	// Nested powers: the degree multiplies out to 2 * 3 * 4 = 24 and five steps are emitted.
	#[test]
	fn test_pow_nested() {
		type F = BinaryField8b;
		type P = PackedBinaryField8x16b;

		// ((x0^2)^3)^4
		let expr = ArithExpr::Var(0).pow(2).pow(3).pow(4);
		let circuit = ArithCircuitPoly::<F>::new(expr);
		assert_eq!(circuit.steps.len(), 5);

		let typed_circuit: &dyn CompositionPoly<P> = &circuit;
		assert_eq!(typed_circuit.binary_tower_level(), 0);
		assert_eq!(typed_circuit.degree(), 24);
		assert_eq!(typed_circuit.n_vars(), 1);

		assert_eq!(
			CompositionPoly::evaluate(
				&circuit,
				&[P::from_scalars(
					felts!(BinaryField16b[0, 1, 2, 3, 122, 123, 124, 125]),
				)]
			)
			.unwrap(),
			P::from_scalars(felts!(BinaryField16b[0, 1, 1, 1, 20, 152, 41, 170])),
		);
	}

	// The tests below pin the exact step sequence emitted by `circuit_steps_for_expr`.

	// A constant needs no steps and is returned directly.
	#[test]
	fn test_circuit_steps_for_expr_constant() {
		type F = BinaryField8b;

		let expr = ArithExpr::Const(F::new(5));
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert!(steps.is_empty(), "No steps should be generated for a constant");
		assert_eq!(retval, CircuitStepArgument::Const(F::new(5)));
	}

	// A variable needs no steps and is returned as an input reference.
	#[test]
	fn test_circuit_steps_for_expr_variable() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(18);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert!(steps.is_empty(), "No steps should be generated for a variable");
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(18))));
	}

	// x14 + x56 lowers to one `Add` into slot 0.
	#[test]
	fn test_circuit_steps_for_expr_addition() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(14) + ArithExpr::<F>::Var(56);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 1, "One addition step should be generated");
		assert!(matches!(
			steps[0],
			CircuitStep::Add(
				CircuitStepArgument::Expr(CircuitNode::Var(14)),
				CircuitStepArgument::Expr(CircuitNode::Var(56))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
	}

	// x36 * x26 lowers to one `Mul` into slot 0.
	#[test]
	fn test_circuit_steps_for_expr_multiplication() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(36) * ArithExpr::Var(26);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 1, "One multiplication step should be generated");
		assert!(matches!(
			steps[0],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Var(36)),
				CircuitStepArgument::Expr(CircuitNode::Var(26))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
	}

	// x^1 is the base itself: no steps.
	#[test]
	fn test_circuit_steps_for_expr_pow_1() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(12).pow(1);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 0, "Pow(1) should not generate any computation steps");

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Var(12))));
	}

	// x^2: one squaring.
	#[test]
	fn test_circuit_steps_for_expr_pow_2() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(10).pow(2);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 1, "Pow(2) should generate one squaring step");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(10)))
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(0))));
	}

	// x^3 = x^2 * x.
	#[test]
	fn test_circuit_steps_for_expr_pow_3() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(5).pow(3);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(
			steps.len(),
			2,
			"Pow(3) should generate one squaring and one multiplication step"
		);
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(5)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
				CircuitStepArgument::Expr(CircuitNode::Var(5))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
	}

	// x^4 = (x^2)^2.
	#[test]
	fn test_circuit_steps_for_expr_pow_4() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(7).pow(4);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 2, "Pow(4) should generate two squaring steps");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
		));

		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(1))));
	}

	// x^5 = (x^2)^2 * x.
	#[test]
	fn test_circuit_steps_for_expr_pow_5() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(3).pow(5);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(
			steps.len(),
			3,
			"Pow(5) should generate two squaring steps and one multiplication"
		);
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(3)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(1)),
				CircuitStepArgument::Expr(CircuitNode::Var(3))
			)
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
	}

	// x^8: three squarings.
	#[test]
	fn test_circuit_steps_for_expr_pow_8() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(4).pow(8);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 3, "Pow(8) should generate three squaring steps");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(4)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(2))));
	}

	// x^9 = ((x^2)^2)^2 * x.
	#[test]
	fn test_circuit_steps_for_expr_pow_9() {
		type F = BinaryField8b;

		let expr = ArithExpr::<F>::Var(8).pow(9);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(
			steps.len(),
			4,
			"Pow(9) should generate three squaring steps and one multiplication"
		);
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(8)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(0)))
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));
		assert!(matches!(
			steps[3],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(2)),
				CircuitStepArgument::Expr(CircuitNode::Var(8))
			)
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
	}

	// x^12 = ((x^2 * x)^2)^2: the multiply happens right after the first squaring.
	#[test]
	fn test_circuit_steps_for_expr_pow_12() {
		type F = BinaryField8b;
		let expr = ArithExpr::<F>::Var(6).pow(12);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 4, "Pow(12) should use 4 steps.");

		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(6)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
				CircuitStepArgument::Expr(CircuitNode::Var(6))
			)
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));
		assert!(matches!(
			steps[3],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
		));

		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))));
	}

	// x^13 = ((x^2 * x)^2)^2 * x.
	#[test]
	fn test_circuit_steps_for_expr_pow_13() {
		type F = BinaryField8b;
		let expr = ArithExpr::<F>::Var(7).pow(13);
		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 5, "Pow(13) should use 5 steps.");
		assert!(matches!(
			steps[0],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Var(7)))
		));
		assert!(matches!(
			steps[1],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(0)),
				CircuitStepArgument::Expr(CircuitNode::Var(7))
			)
		));
		assert!(matches!(
			steps[2],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(1)))
		));
		assert!(matches!(
			steps[3],
			CircuitStep::Square(CircuitStepArgument::Expr(CircuitNode::Slot(2)))
		));
		assert!(matches!(
			steps[4],
			CircuitStep::Mul(
				CircuitStepArgument::Expr(CircuitNode::Slot(3)),
				CircuitStepArgument::Expr(CircuitNode::Var(7))
			)
		));
		assert!(matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(4))));
	}

	// x0 * x1 + (1 - x0) * x2 - x3. The assertions expect `Add` for both `-` operations, and
	// the second product is fused into slot 0 with `AddMul`. Slot 2 (the `AddMul` step's own
	// index) is never referenced.
	#[test]
	fn test_circuit_steps_for_expr_complex() {
		type F = BinaryField8b;

		let expr = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1))
			+ (ArithExpr::Const(F::ONE) - ArithExpr::Var(0)) * ArithExpr::Var(2)
			- ArithExpr::Var(3);

		let (steps, retval) = circuit_steps_for_expr(&expr);

		assert_eq!(steps.len(), 4, "Expression should generate 4 computation steps");

		assert!(
			matches!(
				steps[0],
				CircuitStep::Mul(
					CircuitStepArgument::Expr(CircuitNode::Var(0)),
					CircuitStepArgument::Expr(CircuitNode::Var(1))
				)
			),
			"First step should be multiplication x0 * x1"
		);

		assert!(
			matches!(
				steps[1],
				CircuitStep::Add(
					CircuitStepArgument::Const(F::ONE),
					CircuitStepArgument::Expr(CircuitNode::Var(0))
				)
			),
			"Second step should be (1 - x0)"
		);

		assert!(
			matches!(
				steps[2],
				CircuitStep::AddMul(
					0,
					CircuitStepArgument::Expr(CircuitNode::Slot(1)),
					CircuitStepArgument::Expr(CircuitNode::Var(2))
				)
			),
			"Third step should be (1 - x0) * x2"
		);

		assert!(
			matches!(
				steps[3],
				CircuitStep::Add(
					CircuitStepArgument::Expr(CircuitNode::Slot(0)),
					CircuitStepArgument::Expr(CircuitNode::Var(3))
				)
			),
			"Fourth step should be x0 * x1 + (1 - x0) * x2 + x3"
		);

		assert!(
			matches!(retval, CircuitStepArgument::Expr(CircuitNode::Slot(3))),
			"Final result should be stored in Slot(3)"
		);
	}
}