1use std::{
4 cmp::Ordering,
5 collections::{hash_map::Entry, HashMap},
6 fmt::{self, Display},
7 hash::{Hash, Hasher},
8 iter::{Product, Sum},
9 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
10 sync::Arc,
11};
12
13use binius_field::{Field, PackedField, TowerField};
14use binius_macros::{DeserializeBytes, SerializeBytes};
15
16use super::error::Error;
17
/// Arithmetic expression tree over the field `F`.
///
/// Subexpressions are held behind `Arc`, so one node may be shared by several parents.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithExpr<F: Field> {
    /// A constant field element.
    Const(F),
    /// The variable with the given index (displayed as `x{index}`).
    Var(usize),
    /// Sum of two subexpressions.
    Add(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
    /// Product of two subexpressions.
    Mul(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
    /// A subexpression raised to a constant power.
    Pow(Arc<ArithExpr<F>>, u64),
}
31
32impl<F: Field + Display> Display for ArithExpr<F> {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Self::Const(v) => write!(f, "{v}"),
36 Self::Var(i) => write!(f, "x{i}"),
37 Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
38 Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
39 Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
40 }
41 }
42}
43
44impl<F: Field> ArithExpr<F> {
45 pub fn pow(self, exp: u64) -> Self {
46 Self::Pow(Arc::new(self), exp)
47 }
48
49 pub const fn zero() -> Self {
50 Self::Const(F::ZERO)
51 }
52
53 pub const fn one() -> Self {
54 Self::Const(F::ONE)
55 }
56}
57
58impl<F> Default for ArithExpr<F>
59where
60 F: Field,
61{
62 fn default() -> Self {
63 Self::zero()
64 }
65}
66
67impl<F> Add for ArithExpr<F>
68where
69 F: Field,
70{
71 type Output = Self;
72
73 fn add(self, rhs: Self) -> Self {
74 Self::Add(Arc::new(self), Arc::new(rhs))
75 }
76}
77
78impl<F> Add<Arc<Self>> for ArithExpr<F>
79where
80 F: Field,
81{
82 type Output = Self;
83
84 fn add(self, rhs: Arc<Self>) -> Self {
85 Self::Add(Arc::new(self), rhs)
86 }
87}
88
89impl<F> AddAssign for ArithExpr<F>
90where
91 F: Field,
92{
93 fn add_assign(&mut self, rhs: Self) {
94 *self = std::mem::take(self) + rhs;
95 }
96}
97
98impl<F> AddAssign<Arc<Self>> for ArithExpr<F>
99where
100 F: Field,
101{
102 fn add_assign(&mut self, rhs: Arc<Self>) {
103 *self = std::mem::take(self) + rhs;
104 }
105}
106
107impl<F> Sub for ArithExpr<F>
108where
109 F: Field,
110{
111 type Output = Self;
112
113 fn sub(self, rhs: Self) -> Self {
114 Self::Add(Arc::new(self), Arc::new(rhs))
115 }
116}
117
118impl<F> Sub<Arc<Self>> for ArithExpr<F>
119where
120 F: Field,
121{
122 type Output = Self;
123
124 fn sub(self, rhs: Arc<Self>) -> Self {
125 Self::Add(Arc::new(self), rhs)
126 }
127}
128
129impl<F> SubAssign for ArithExpr<F>
130where
131 F: Field,
132{
133 fn sub_assign(&mut self, rhs: Self) {
134 *self = std::mem::take(self) - rhs;
135 }
136}
137
138impl<F> SubAssign<Arc<Self>> for ArithExpr<F>
139where
140 F: Field,
141{
142 fn sub_assign(&mut self, rhs: Arc<Self>) {
143 *self = std::mem::take(self) - rhs;
144 }
145}
146
147impl<F> Mul for ArithExpr<F>
148where
149 F: Field,
150{
151 type Output = Self;
152
153 fn mul(self, rhs: Self) -> Self {
154 Self::Mul(Arc::new(self), Arc::new(rhs))
155 }
156}
157
158impl<F> Mul<Arc<Self>> for ArithExpr<F>
159where
160 F: Field,
161{
162 type Output = Self;
163
164 fn mul(self, rhs: Arc<Self>) -> Self {
165 Self::Mul(Arc::new(self), rhs)
166 }
167}
168
169impl<F> MulAssign for ArithExpr<F>
170where
171 F: Field,
172{
173 fn mul_assign(&mut self, rhs: Self) {
174 *self = std::mem::take(self) * rhs;
175 }
176}
177
178impl<F> MulAssign<Arc<Self>> for ArithExpr<F>
179where
180 F: Field,
181{
182 fn mul_assign(&mut self, rhs: Arc<Self>) {
183 *self = std::mem::take(self) * rhs;
184 }
185}
186
187impl<F: Field> Sum for ArithExpr<F> {
188 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
189 iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
190 }
191}
192
193impl<F: Field> Product for ArithExpr<F> {
194 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
195 iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
196 }
197}
198
/// One step of an [`ArithCircuit`].
///
/// Operand `usize` values are indices of other steps in the same circuit's step list.
/// Evaluation assumes that they refer to earlier steps.
#[derive(Clone, Copy, Debug, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
pub enum ArithCircuitStep<F: Field> {
    /// Sum of the two referenced steps.
    Add(usize, usize),
    /// Product of the two referenced steps.
    Mul(usize, usize),
    /// The referenced step raised to a constant power.
    Pow(usize, u64),
    /// A constant field element.
    Const(F),
    /// The variable with the given index.
    Var(usize),
}
207
208impl<F: Field> Default for ArithCircuitStep<F> {
209 fn default() -> Self {
210 Self::Const(F::ZERO)
211 }
212}
213
/// Arithmetic circuit stored as a flat list of steps; the last step is the circuit's output.
///
/// `PartialEq` is implemented by hand and compares the expression trees rooted at the two
/// circuits' last steps. Step numbering and unreachable steps therefore do not affect equality.
#[derive(Clone, Debug, SerializeBytes, DeserializeBytes, Eq)]
pub struct ArithCircuit<F: Field> {
    /// Steps in evaluation order; operand indices are expected to refer to earlier steps.
    steps: Vec<ArithCircuitStep<F>>,
}
227
228impl<F: Field> ArithCircuit<F> {
229 pub fn var(index: usize) -> Self {
230 Self {
231 steps: vec![ArithCircuitStep::Var(index)],
232 }
233 }
234
235 pub fn constant(value: F) -> Self {
236 Self {
237 steps: vec![ArithCircuitStep::Const(value)],
238 }
239 }
240
241 pub fn zero() -> Self {
242 Self::constant(F::ZERO)
243 }
244
245 pub fn one() -> Self {
246 Self::constant(F::ONE)
247 }
248
249 pub fn pow(mut self, exp: u64) -> Self {
250 self.steps
251 .push(ArithCircuitStep::Pow(self.steps.len() - 1, exp));
252
253 self
254 }
255
256 pub fn steps(&self) -> &[ArithCircuitStep<F>] {
258 &self.steps
259 }
260
261 pub fn degree(&self) -> usize {
263 fn step_degree<F: Field>(step: usize, steps: &[ArithCircuitStep<F>]) -> usize {
264 match steps[step] {
265 ArithCircuitStep::Const(_) => 0,
266 ArithCircuitStep::Var(_) => 1,
267 ArithCircuitStep::Add(left, right) => {
268 step_degree(left, steps).max(step_degree(right, steps))
269 }
270 ArithCircuitStep::Mul(left, right) => {
271 step_degree(left, steps) + step_degree(right, steps)
272 }
273 ArithCircuitStep::Pow(base, exp) => step_degree(base, steps) * (exp as usize),
274 }
275 }
276
277 step_degree(self.steps.len() - 1, &self.steps)
278 }
279
280 pub fn n_vars(&self) -> usize {
282 self.steps
283 .iter()
284 .map(|step| {
285 if let ArithCircuitStep::Var(index) = step {
286 *index + 1
287 } else {
288 0
289 }
290 })
291 .max()
292 .unwrap_or(0)
293 }
294
295 pub fn binary_tower_level(&self) -> usize
297 where
298 F: TowerField,
299 {
300 self.steps
301 .iter()
302 .map(|step| {
303 if let ArithCircuitStep::Const(value) = step {
304 value.min_tower_level()
305 } else {
306 0
307 }
308 })
309 .max()
310 .unwrap_or(0)
311 }
312
313 pub fn eval_cost(&self) -> EvalCost {
315 self.steps
316 .iter()
317 .fold(EvalCost::default(), |acc, step| match *step {
318 ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => acc,
319 ArithCircuitStep::Add(_, _) => acc + EvalCost::add(1),
320 ArithCircuitStep::Mul(_, _) => acc + EvalCost::mul(1),
321 ArithCircuitStep::Pow(_, exp) => {
327 let n_squares = exp.ilog2() as usize;
328 let n_muls = exp.count_ones().saturating_sub(1) as usize;
329 acc + EvalCost::mul(n_muls) + EvalCost::square(n_squares)
330 }
331 })
332 }
333
334 pub fn leading_term(&self) -> Self {
337 let (_, expr) = self.leading_term_with_degree(self.steps.len() - 1);
338 expr
339 }
340
341 fn leading_term_with_degree(&self, step: usize) -> (usize, Self) {
343 match &self.steps[step] {
344 ArithCircuitStep::Const(value) => (0, Self::constant(*value)),
345 ArithCircuitStep::Var(index) => (1, Self::var(*index)),
346 ArithCircuitStep::Add(left, right) => {
347 let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
348 let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
349 match lhs_degree.cmp(&rhs_degree) {
350 Ordering::Less => (rhs_degree, rhs),
351 Ordering::Equal => (lhs_degree, lhs + rhs),
352 Ordering::Greater => (lhs_degree, lhs),
353 }
354 }
355 ArithCircuitStep::Mul(left, right) => {
356 let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
357 let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
358 (lhs_degree + rhs_degree, lhs * rhs)
359 }
360 ArithCircuitStep::Pow(base, exp) => {
361 let (base_degree, base) = self.leading_term_with_degree(*base);
362 (base_degree * (*exp as usize), base.pow(*exp))
363 }
364 }
365 }
366
367 pub fn evaluate(&self, query: &[F]) -> Result<F, Error> {
368 let mut step_evals = Vec::<F>::with_capacity(self.steps.len());
369 for step in &self.steps {
370 match step {
371 ArithCircuitStep::Add(left, right) => {
372 step_evals.push(step_evals[*left] + step_evals[*right])
373 }
374 ArithCircuitStep::Mul(left, right) => {
375 step_evals.push(step_evals[*left] * step_evals[*right])
376 }
377 ArithCircuitStep::Pow(base, exp) => step_evals.push(step_evals[*base].pow(*exp)),
378 ArithCircuitStep::Const(value) => step_evals.push(*value),
379 ArithCircuitStep::Var(index) => step_evals.push(query[*index]),
380 }
381 }
382 Ok(step_evals.pop().unwrap_or_default())
383 }
384
385 pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithCircuit<FTgt> {
386 ArithCircuit {
387 steps: self
388 .steps
389 .iter()
390 .map(|step| match step {
391 ArithCircuitStep::Const(value) => ArithCircuitStep::Const((*value).into()),
392 ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
393 ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
394 ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
395 ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
396 })
397 .collect(),
398 }
399 }
400
401 pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
402 &self,
403 ) -> Result<ArithCircuit<FTgt>, <FTgt as TryFrom<F>>::Error> {
404 let steps = self
405 .steps
406 .iter()
407 .map(|step| -> Result<ArithCircuitStep<FTgt>, <FTgt as TryFrom<F>>::Error> {
408 let result = match step {
409 ArithCircuitStep::Const(value) => {
410 ArithCircuitStep::Const(FTgt::try_from(*value)?)
411 }
412 ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
413 ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
414 ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
415 ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
416 };
417 Ok(result)
418 })
419 .collect::<Result<Vec<_>, _>>()?;
420
421 Ok(ArithCircuit { steps })
422 }
423
424 pub fn remap_vars(&self, indices: &[usize]) -> Result<Self, Error> {
434 let steps = self
435 .steps
436 .iter()
437 .map(|step| -> Result<ArithCircuitStep<F>, Error> {
438 if let ArithCircuitStep::Var(index) = step {
439 let new_index = indices.get(*index).copied().ok_or_else(|| {
440 Error::IncorrectArgumentLength {
441 arg: "indices".to_string(),
442 expected: *index,
443 }
444 })?;
445 Ok(ArithCircuitStep::Var(new_index))
446 } else {
447 Ok(*step)
448 }
449 })
450 .collect::<Result<Vec<_>, _>>()?;
451 Ok(Self { steps })
452 }
453
454 pub fn const_subst(self, var: usize, value: F) -> Self {
456 let steps = self
457 .steps
458 .iter()
459 .map(|step| match step {
460 ArithCircuitStep::Var(index) if *index == var => ArithCircuitStep::Const(value),
461 _ => *step,
462 })
463 .collect();
464 Self { steps }
465 }
466
467 pub fn get_constant(&self) -> Option<F> {
469 if let ArithCircuitStep::Const(value) =
470 self.steps.last().expect("steps should not be empty")
471 {
472 Some(*value)
473 } else {
474 None
475 }
476 }
477
478 pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
484 self.sparse_linear_normal_form().map(Into::into)
485 }
486
487 fn sparse_linear_normal_form(&self) -> Result<SparseLinearNormalForm<F>, Error> {
488 fn sparse_linear_normal_form<F: Field>(
489 step: usize,
490 steps: &[ArithCircuitStep<F>],
491 ) -> Result<SparseLinearNormalForm<F>, Error> {
492 match &steps[step] {
493 ArithCircuitStep::Const(val) => Ok((*val).into()),
494 ArithCircuitStep::Var(index) => Ok(SparseLinearNormalForm {
495 constant: F::ZERO,
496 dense_linear_form_len: index + 1,
497 var_coeffs: [(*index, F::ONE)].into(),
498 }),
499 ArithCircuitStep::Add(left, right) => {
500 let left = sparse_linear_normal_form(*left, steps)?;
501 let right = sparse_linear_normal_form(*right, steps)?;
502 Ok(left + right)
503 }
504 ArithCircuitStep::Mul(left, right) => {
505 let left = sparse_linear_normal_form(*left, steps)?;
506 let right = sparse_linear_normal_form(*right, steps)?;
507 left * right
508 }
509 ArithCircuitStep::Pow(_, 0) => Ok(F::ONE.into()),
510 ArithCircuitStep::Pow(expr, 1) => sparse_linear_normal_form(*expr, steps),
511 ArithCircuitStep::Pow(expr, pow) => {
512 let linear_form = sparse_linear_normal_form(*expr, steps)?;
513 if linear_form.dense_linear_form_len != 0 {
514 return Err(Error::NonLinearExpression);
515 }
516 Ok(linear_form.constant.pow(*pow).into())
517 }
518 }
519 }
520
521 sparse_linear_normal_form(self.steps.len() - 1, &self.steps)
522 }
523
524 pub fn vars_usage(&self) -> Vec<bool> {
529 let mut usage = vec![false; self.n_vars()];
530
531 for step in &self.steps {
532 if let ArithCircuitStep::Var(index) = step {
533 usage[*index] = true;
534 }
535 }
536
537 usage
538 }
539
540 fn optimize_constants(&mut self) {
542 for step_index in 0..self.steps.len() {
543 let (prev_steps, curr_steps) = self.steps.split_at_mut(step_index);
544 let curr_step = &mut curr_steps[0];
545 match curr_step {
546 ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => {}
547 ArithCircuitStep::Add(left, right) => {
548 match (&prev_steps[*left], &prev_steps[*right]) {
549 (ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
550 *curr_step = ArithCircuitStep::Const(*left + *right);
551 }
552 (ArithCircuitStep::Const(left), right) if *left == F::ZERO => {
553 *curr_step = *right;
554 }
555 (left, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
556 *curr_step = *left;
557 }
558 (left, right) if left == right && F::CHARACTERISTIC == 2 => {
559 *curr_step = ArithCircuitStep::Const(F::ZERO);
560 }
561 _ => {}
562 }
563 }
564 ArithCircuitStep::Mul(left, right) => {
565 match (&prev_steps[*left], &prev_steps[*right]) {
566 (ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
567 *curr_step = ArithCircuitStep::Const(*left * *right);
568 }
569 (ArithCircuitStep::Const(left), _) if *left == F::ZERO => {
570 *curr_step = ArithCircuitStep::Const(F::ZERO);
571 }
572 (_, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
573 *curr_step = ArithCircuitStep::Const(F::ZERO);
574 }
575 (ArithCircuitStep::Const(left), right) if *left == F::ONE => {
576 *curr_step = *right;
577 }
578 (left, ArithCircuitStep::Const(right)) if *right == F::ONE => {
579 *curr_step = *left;
580 }
581 _ => {}
582 }
583 }
584 ArithCircuitStep::Pow(base, exp) => match prev_steps[*base] {
585 ArithCircuitStep::Const(value) => {
586 *curr_step = ArithCircuitStep::Const(PackedField::pow(value, *exp));
587 }
588 ArithCircuitStep::Pow(base_inner, exp_inner) => {
589 *curr_step = ArithCircuitStep::Pow(base_inner, *exp * exp_inner);
590 }
591 _ => {}
592 },
593 }
594 }
595 }
596
597 fn deduplicate_steps(&mut self) {
599 let mut step_map = HashMap::new();
600 let mut step_indices = Vec::with_capacity(self.steps.len());
601 for step in 0..self.steps.len() {
602 let node = StepNode {
603 index: step,
604 steps: &self.steps,
605 };
606 match step_map.entry(node) {
607 Entry::Occupied(entry) => {
608 step_indices.push(*entry.get());
609 }
610 Entry::Vacant(entry) => {
611 entry.insert(step);
612 step_indices.push(step);
613 }
614 }
615 }
616
617 for step in &mut self.steps {
618 match step {
619 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
620 *left = step_indices[*left];
621 *right = step_indices[*right];
622 }
623 ArithCircuitStep::Pow(base, _) => *base = step_indices[*base],
624 _ => (),
625 }
626 }
627 }
628
629 fn compress_unused_steps(&mut self) {
631 fn mark_used<F: Field>(step: usize, steps: &[ArithCircuitStep<F>], used: &mut [bool]) {
632 if used[step] {
633 return;
634 }
635 used[step] = true;
636 match steps[step] {
637 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
638 mark_used(left, steps, used);
639 mark_used(right, steps, used);
640 }
641 ArithCircuitStep::Pow(base, _) => mark_used(base, steps, used),
642 _ => (),
643 }
644 }
645
646 let mut used = vec![false; self.steps.len()];
647 mark_used(self.steps.len() - 1, &self.steps, &mut used);
648
649 let mut steps_map = (0..self.steps.len()).collect::<Vec<_>>();
650 let mut target_index = 0;
651 for source_index in 0..self.steps.len() {
652 if used[source_index] {
653 if target_index != source_index {
654 match &mut self.steps[source_index] {
655 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
656 *left = steps_map[*left];
657 *right = steps_map[*right];
658 }
659 ArithCircuitStep::Pow(base, _) => *base = steps_map[*base],
660 _ => (),
661 }
662
663 steps_map[source_index] = target_index;
664 self.steps[target_index] = self.steps[source_index];
665 }
666
667 target_index += 1;
668 }
669 }
670
671 self.steps.truncate(target_index);
672 }
673
674 pub fn optimize_in_place(&mut self) {
679 self.optimize_constants();
680 self.deduplicate_steps();
681 self.compress_unused_steps();
682 }
683
684 pub fn optimize(mut self) -> Self {
686 self.optimize_constants();
687 self.deduplicate_steps();
688 self.compress_unused_steps();
689
690 self
691 }
692}
693
694impl<F: Field> From<&ArithExpr<F>> for ArithCircuit<F> {
695 fn from(expr: &ArithExpr<F>) -> Self {
696 fn visit_node<F: Field>(
697 node: &Arc<ArithExpr<F>>,
698 node_to_index: &mut HashMap<*const ArithExpr<F>, usize>,
699 steps: &mut Vec<ArithCircuitStep<F>>,
700 ) -> usize {
701 if let Some(index) = node_to_index.get(&Arc::as_ptr(node)) {
702 return *index;
703 }
704
705 let step = match &**node {
706 ArithExpr::Const(value) => ArithCircuitStep::Const(*value),
707 ArithExpr::Var(index) => ArithCircuitStep::Var(*index),
708 ArithExpr::Add(left, right) => {
709 let left = visit_node(left, node_to_index, steps);
710 let right = visit_node(right, node_to_index, steps);
711 ArithCircuitStep::Add(left, right)
712 }
713 ArithExpr::Mul(left, right) => {
714 let left = visit_node(left, node_to_index, steps);
715 let right = visit_node(right, node_to_index, steps);
716 ArithCircuitStep::Mul(left, right)
717 }
718 ArithExpr::Pow(base, exp) => {
719 let base = visit_node(base, node_to_index, steps);
720 ArithCircuitStep::Pow(base, *exp)
721 }
722 };
723
724 steps.push(step);
725 node_to_index.insert(Arc::as_ptr(node), steps.len() - 1);
726 steps.len() - 1
727 }
728
729 let mut steps = Vec::new();
730 let mut node_to_index = HashMap::new();
731 match expr {
732 ArithExpr::Const(c) => {
733 steps.push(ArithCircuitStep::Const(*c));
734 }
735 ArithExpr::Var(var) => {
736 steps.push(ArithCircuitStep::Var(*var));
737 }
738 ArithExpr::Add(left, right) => {
739 let left = visit_node(left, &mut node_to_index, &mut steps);
740 let right = visit_node(right, &mut node_to_index, &mut steps);
741 steps.push(ArithCircuitStep::Add(left, right));
742 }
743 ArithExpr::Mul(left, right) => {
744 let left = visit_node(left, &mut node_to_index, &mut steps);
745 let right = visit_node(right, &mut node_to_index, &mut steps);
746 steps.push(ArithCircuitStep::Mul(left, right));
747 }
748 ArithExpr::Pow(base, exp) => {
749 let base = visit_node(base, &mut node_to_index, &mut steps);
750 steps.push(ArithCircuitStep::Pow(base, *exp));
751 }
752 }
753
754 Self { steps }
755 }
756}
757
758impl<F: Field> From<ArithExpr<F>> for ArithCircuit<F> {
759 fn from(expr: ArithExpr<F>) -> Self {
760 Self::from(&expr)
761 }
762}
763
764impl<F: Field> From<&ArithCircuit<F>> for ArithExpr<F> {
765 fn from(circuit: &ArithCircuit<F>) -> Self {
766 let mut step_to_node = vec![Option::<Arc<Self>>::None; circuit.steps.len()];
767
768 for step in 0..circuit.steps.len() {
769 let node = match &circuit.steps[step] {
770 ArithCircuitStep::Const(value) => Self::Const(*value),
771 ArithCircuitStep::Var(index) => Self::Var(*index),
772 ArithCircuitStep::Add(left, right) => {
773 let left = step_to_node[*left].clone().expect("step must be present");
774 let right = step_to_node[*right].clone().expect("step must be present");
775 Self::Add(left, right)
776 }
777 ArithCircuitStep::Mul(left, right) => {
778 let left = step_to_node[*left].clone().expect("step must be present");
779 let right = step_to_node[*right].clone().expect("step must be present");
780 Self::Mul(left, right)
781 }
782 ArithCircuitStep::Pow(base, exp) => {
783 let base = step_to_node[*base].clone().expect("step must be present");
784 Self::Pow(base, *exp)
785 }
786 };
787 step_to_node[step] = Some(Arc::new(node));
788 }
789
790 Arc::into_inner(
791 step_to_node
792 .pop()
793 .expect("steps should not be empty")
794 .expect("last step should be initialized"),
795 )
796 .expect("last step must have a single instance")
797 }
798}
799
800impl<F: Field> From<ArithCircuit<F>> for ArithExpr<F> {
801 fn from(circuit: ArithCircuit<F>) -> Self {
802 Self::from(&circuit)
803 }
804}
805
806impl<F: Field> Display for ArithCircuit<F> {
807 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
808 fn display_step<F: Field>(
809 step: usize,
810 steps: &[ArithCircuitStep<F>],
811 f: &mut fmt::Formatter<'_>,
812 ) -> Result<(), fmt::Error> {
813 match &steps[step] {
814 ArithCircuitStep::Const(value) => write!(f, "{value}"),
815 ArithCircuitStep::Var(index) => write!(f, "x{index}"),
816 ArithCircuitStep::Add(left, right) => {
817 write!(f, "(")?;
818 display_step(*left, steps, f)?;
819 write!(f, " + ")?;
820 display_step(*right, steps, f)?;
821 write!(f, ")")
822 }
823 ArithCircuitStep::Mul(left, right) => {
824 write!(f, "(")?;
825 display_step(*left, steps, f)?;
826 write!(f, " * ")?;
827 display_step(*right, steps, f)?;
828 write!(f, ")")
829 }
830 ArithCircuitStep::Pow(base, exp) => {
831 write!(f, "(")?;
832 display_step(*base, steps, f)?;
833 write!(f, ")^{exp}")
834 }
835 }
836 }
837
838 display_step(self.steps.len() - 1, &self.steps, f)
839 }
840}
841
842impl<F: Field> PartialEq for ArithCircuit<F> {
843 fn eq(&self, other: &Self) -> bool {
844 StepNode {
845 index: self.steps.len() - 1,
846 steps: &self.steps,
847 } == StepNode {
848 index: other.steps.len() - 1,
849 steps: &other.steps,
850 }
851 }
852}
853
854impl<F: Field> Add for ArithCircuit<F> {
855 type Output = Self;
856
857 fn add(mut self, rhs: Self) -> Self {
858 self += rhs;
859 self
860 }
861}
862
863impl<F: Field> AddAssign for ArithCircuit<F> {
864 fn add_assign(&mut self, mut rhs: Self) {
865 let old_len = self.steps.len();
866 add_offset(&mut rhs.steps, old_len);
867 self.steps.extend(rhs.steps);
868 self.steps
869 .push(ArithCircuitStep::Add(old_len - 1, self.steps.len() - 1));
870 }
871}
872
873impl<F: Field> Sub for ArithCircuit<F> {
874 type Output = Self;
875
876 fn sub(mut self, rhs: Self) -> Self {
877 self -= rhs;
878 self
879 }
880}
881
882impl<F: Field> SubAssign for ArithCircuit<F> {
883 #[allow(clippy::suspicious_op_assign_impl)]
884 fn sub_assign(&mut self, rhs: Self) {
885 *self += rhs;
886 }
887}
888
889impl<F: Field> Mul for ArithCircuit<F> {
890 type Output = Self;
891
892 fn mul(mut self, rhs: Self) -> Self {
893 self *= rhs;
894 self
895 }
896}
897
898impl<F: Field> MulAssign for ArithCircuit<F> {
899 fn mul_assign(&mut self, mut rhs: Self) {
900 let old_len = self.steps.len();
901 add_offset(&mut rhs.steps, old_len);
902 self.steps.extend(rhs.steps);
903 self.steps
904 .push(ArithCircuitStep::Mul(old_len - 1, self.steps.len() - 1));
905 }
906}
907
908impl<F: Field> Sum for ArithCircuit<F> {
909 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
910 iter.fold(Self::zero(), |sum, item| sum + item)
911 }
912}
913
914impl<F: Field> Product for ArithCircuit<F> {
915 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
916 iter.fold(Self::one(), |product, item| product * item)
917 }
918}
919
920fn add_offset<F: Field>(steps: &mut [ArithCircuitStep<F>], offset: usize) {
921 for step in steps.iter_mut() {
922 match step {
923 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
924 *left += offset;
925 *right += offset;
926 }
927 ArithCircuitStep::Pow(base, _) => {
928 *base += offset;
929 }
930 _ => (),
931 }
932 }
933}
/// A linear expression `constant + Σ var_coeffs[i] * x_i`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LinearNormalForm<F: Field> {
    /// The constant term.
    pub constant: F,
    /// Dense coefficients; entry `i` multiplies variable `x_i` (zero if `x_i` does not appear).
    pub var_coeffs: Vec<F>,
}
942
/// Sparse working form of [`LinearNormalForm`] used while walking a circuit.
struct SparseLinearNormalForm<F: Field> {
    /// The constant term.
    pub constant: F,
    /// Coefficients keyed by variable index. Entries are never removed, even when a
    /// coefficient becomes zero.
    pub var_coeffs: HashMap<usize, F>,
    /// Length of the dense coefficient vector this form expands to. It is one more than the
    /// largest variable index that has contributed a term, or 0 if none has.
    pub dense_linear_form_len: usize,
}
952
953impl<F: Field> From<F> for SparseLinearNormalForm<F> {
954 fn from(value: F) -> Self {
955 Self {
956 constant: value,
957 dense_linear_form_len: 0,
958 var_coeffs: HashMap::new(),
959 }
960 }
961}
962
963impl<F: Field> Add for SparseLinearNormalForm<F> {
964 type Output = Self;
965 fn add(self, rhs: Self) -> Self::Output {
966 let (mut result, consumable) = if self.var_coeffs.len() < rhs.var_coeffs.len() {
967 (rhs, self)
968 } else {
969 (self, rhs)
970 };
971 result.constant += consumable.constant;
972 if consumable.dense_linear_form_len > result.dense_linear_form_len {
973 result.dense_linear_form_len = consumable.dense_linear_form_len;
974 }
975
976 for (index, coeff) in consumable.var_coeffs {
977 result
978 .var_coeffs
979 .entry(index)
980 .and_modify(|res_coeff| {
981 *res_coeff += coeff;
982 })
983 .or_insert(coeff);
984 }
985 result
986 }
987}
988
989impl<F: Field> Mul for SparseLinearNormalForm<F> {
990 type Output = Result<Self, Error>;
991 fn mul(self, rhs: Self) -> Result<Self, Error> {
992 if !self.var_coeffs.is_empty() && !rhs.var_coeffs.is_empty() {
993 return Err(Error::NonLinearExpression);
994 }
995 let (mut result, consumable) = if self.var_coeffs.is_empty() {
996 (rhs, self)
997 } else {
998 (self, rhs)
999 };
1000 result.constant *= consumable.constant;
1001 for coeff in result.var_coeffs.values_mut() {
1002 *coeff *= consumable.constant;
1003 }
1004 Ok(result)
1005 }
1006}
1007
1008impl<F: Field> From<SparseLinearNormalForm<F>> for LinearNormalForm<F> {
1009 fn from(value: SparseLinearNormalForm<F>) -> Self {
1010 let mut var_coeffs = vec![F::ZERO; value.dense_linear_form_len];
1011 for (i, coeff) in value.var_coeffs {
1012 var_coeffs[i] = coeff;
1013 }
1014 Self {
1015 constant: value.constant,
1016 var_coeffs,
1017 }
1018 }
1019}
1020
/// A view of the subexpression rooted at step `index` of `steps`.
///
/// Its `PartialEq` and `Hash` impls compare and hash the whole subtree recursively, ignoring
/// step numbering.
#[derive(Eq)]
struct StepNode<'a, F: Field> {
    /// Index of the root step within `steps`.
    index: usize,
    /// The circuit's full step list.
    steps: &'a [ArithCircuitStep<F>],
}
1028
1029impl<F: Field> StepNode<'_, F> {
1030 fn prev_step(&self, step: usize) -> Self {
1031 StepNode {
1032 index: step,
1033 steps: self.steps,
1034 }
1035 }
1036}
1037
1038impl<F: Field> PartialEq for StepNode<'_, F> {
1039 #[allow(clippy::suspicious_operation_groupings)] fn eq(&self, other: &Self) -> bool {
1041 match (&self.steps[self.index], &other.steps[other.index]) {
1042 (ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => left == right,
1043 (ArithCircuitStep::Var(left), ArithCircuitStep::Var(right)) => left == right,
1044 (
1045 ArithCircuitStep::Add(left, right),
1046 ArithCircuitStep::Add(other_left, other_right),
1047 )
1048 | (
1049 ArithCircuitStep::Mul(left, right),
1050 ArithCircuitStep::Mul(other_left, other_right),
1051 ) => {
1052 self.prev_step(*left) == other.prev_step(*other_left)
1053 && self.prev_step(*right) == other.prev_step(*other_right)
1054 }
1055 (ArithCircuitStep::Pow(base, exp), ArithCircuitStep::Pow(other_base, other_exp)) => {
1056 self.prev_step(*base) == other.prev_step(*other_base) && exp == other_exp
1057 }
1058 _ => false,
1059 }
1060 }
1061}
1062
impl<F: Field> Hash for StepNode<'_, F> {
    /// Hashes the whole subtree rooted at `index`, consistently with the structural
    /// `PartialEq`.
    ///
    /// For each step it writes a per-variant tag byte, then the payload, recursing into
    /// operand steps.
    ///
    /// NOTE(review): operands are re-hashed for every reference. For circuits with shared
    /// steps, the cost therefore grows with the fully expanded tree, not with the number of
    /// steps.
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self.steps[self.index] {
            ArithCircuitStep::Const(value) => {
                0u8.hash(state);
                value.hash(state);
            }
            ArithCircuitStep::Var(index) => {
                1u8.hash(state);
                index.hash(state);
            }
            ArithCircuitStep::Add(left, right) => {
                2u8.hash(state);
                self.prev_step(left).hash(state);
                self.prev_step(right).hash(state);
            }
            ArithCircuitStep::Mul(left, right) => {
                3u8.hash(state);
                self.prev_step(left).hash(state);
                self.prev_step(right).hash(state);
            }
            ArithCircuitStep::Pow(base, exp) => {
                4u8.hash(state);
                self.prev_step(base).hash(state);
                exp.hash(state);
            }
        }
    }
}
1092
/// Operation counts used to estimate the cost of evaluating a circuit.
///
/// Derives `Debug` (a public type should be printable), and `Clone`/`Copy`/`PartialEq`/`Eq`
/// (it is three plain counters), in addition to `Default`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct EvalCost {
    /// Number of field additions.
    pub n_adds: usize,
    /// Number of general field multiplications; squarings are counted in `n_squares`.
    pub n_muls: usize,
    /// Number of field squarings.
    pub n_squares: usize,
}
1110
1111impl EvalCost {
1112 pub fn add(n_adds: usize) -> Self {
1113 Self {
1114 n_adds,
1115 ..Self::default()
1116 }
1117 }
1118
1119 pub fn mul(n_muls: usize) -> Self {
1120 Self {
1121 n_muls,
1122 ..Self::default()
1123 }
1124 }
1125
1126 pub fn square(n_squares: usize) -> Self {
1127 Self {
1128 n_squares,
1129 ..Self::default()
1130 }
1131 }
1132
1133 pub fn mult_cost_approx(&self) -> usize {
1139 self.n_muls + self.n_squares.div_ceil(5)
1140 }
1141}
1142
1143impl Add for EvalCost {
1144 type Output = Self;
1145 fn add(self, other: Self) -> Self {
1146 Self {
1147 n_adds: self.n_adds + other.n_adds,
1148 n_muls: self.n_muls + other.n_muls,
1149 n_squares: self.n_squares + other.n_squares,
1150 }
1151 }
1152}
1153
1154#[cfg(test)]
1155mod tests {
1156 use std::collections::HashSet;
1157
1158 use assert_matches::assert_matches;
1159 use binius_field::{BinaryField, BinaryField128b, BinaryField1b, BinaryField8b};
1160 use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes};
1161
1162 use super::*;
1163
1164 #[test]
1165 fn test_degree_with_pow() {
1166 let expr = ArithCircuit::constant(BinaryField8b::new(6)).pow(7);
1167 assert_eq!(expr.degree(), 0);
1168
1169 let expr: ArithCircuit<BinaryField8b> = ArithCircuit::var(0).pow(7);
1170 assert_eq!(expr.degree(), 7);
1171
1172 let expr: ArithCircuit<BinaryField8b> =
1173 (ArithCircuit::var(0) * ArithCircuit::var(1)).pow(7);
1174 assert_eq!(expr.degree(), 14);
1175 }
1176
1177 #[test]
1178 fn test_n_vars() {
1179 type F = BinaryField8b;
1180 let expr = ArithCircuit::<F>::var(0) * ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR)
1181 + ArithCircuit::var(2).pow(2);
1182 assert_eq!(expr.n_vars(), 3);
1183 }
1184
1185 #[test]
1186 fn test_leading_term_with_degree() {
1187 let expr = ArithCircuit::var(0)
1188 * (ArithCircuit::var(1)
1189 * ArithCircuit::var(2)
1190 * ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR)
1191 + ArithCircuit::var(4))
1192 + ArithCircuit::var(5).pow(3)
1193 + ArithCircuit::constant(BinaryField8b::ONE);
1194
1195 let expected_expr = ArithCircuit::var(0)
1196 * ((ArithCircuit::var(1) * ArithCircuit::var(2))
1197 * ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR))
1198 + ArithCircuit::var(5).pow(3);
1199
1200 assert_eq!(expr.leading_term_with_degree(expr.steps().len() - 1), (3, expected_expr));
1201 }
1202
1203 #[test]
1204 fn test_remap_vars_with_too_few_vars() {
1205 type F = BinaryField8b;
1206 let expr =
1207 ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
1208 assert_matches!(expr.remap_vars(&[5]), Err(Error::IncorrectArgumentLength { .. }));
1209 }
1210
1211 #[test]
1212 fn test_remap_vars_works() {
1213 type F = BinaryField8b;
1214 let expr =
1215 ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
1216 let new_expr = expr.remap_vars(&[5, 3]);
1217
1218 let expected =
1219 ((ArithCircuit::var(5) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(3)).pow(3);
1220 assert_eq!(new_expr.unwrap(), expected);
1221 }
1222
1223 #[test]
1224 fn test_optimize_identity_handling() {
1225 type F = BinaryField8b;
1226 let zero = ArithCircuit::<F>::zero();
1227 let one = ArithCircuit::<F>::one();
1228
1229 assert_eq!((zero.clone() * ArithCircuit::<F>::var(0)).optimize(), zero);
1230 assert_eq!((ArithCircuit::<F>::var(0) * zero.clone()).optimize(), zero);
1231
1232 assert_eq!((ArithCircuit::<F>::var(0) * one.clone()).optimize(), ArithCircuit::var(0));
1233 assert_eq!((one * ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));
1234
1235 assert_eq!((ArithCircuit::<F>::var(0) + zero.clone()).optimize(), ArithCircuit::var(0));
1236 assert_eq!((zero.clone() + ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));
1237
1238 assert_eq!((ArithCircuit::<F>::var(0) + ArithCircuit::var(0)).optimize(), zero);
1239 }
1240
1241 #[test]
1242 fn test_const_subst_and_optimize() {
1243 type F = BinaryField8b;
1245 let expr = ArithCircuit::var(0) * ArithCircuit::var(1) + ArithCircuit::one()
1246 - ArithCircuit::var(1);
1247 assert_eq!(expr.const_subst(1, F::ZERO).optimize().get_constant(), Some(F::ONE));
1248 }
1249
1250 #[test]
1251 fn test_expression_upcast() {
1252 type F8 = BinaryField8b;
1253 type F = BinaryField128b;
1254
1255 let expr = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
1256 * ArithCircuit::constant(F8::new(222)))
1257 .pow(3);
1258
1259 let expected = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
1260 * ArithCircuit::constant(F::new(222)))
1261 .pow(3);
1262 assert_eq!(expr.convert_field::<F>(), expected);
1263 }
1264
1265 #[test]
1266 fn test_expression_downcast() {
1267 type F8 = BinaryField8b;
1268 type F = BinaryField128b;
1269
1270 let expr = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
1271 * ArithCircuit::constant(F::new(222)))
1272 .pow(3);
1273
1274 assert!(expr.try_convert_field::<BinaryField1b>().is_err());
1275
1276 let expected = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
1277 * ArithCircuit::constant(F8::new(222)))
1278 .pow(3);
1279 assert_eq!(expr.try_convert_field::<BinaryField8b>().unwrap(), expected);
1280 }
1281
1282 #[test]
1283 fn test_linear_normal_form() {
1284 type F = BinaryField128b;
1285 struct Case {
1286 expr: ArithCircuit<F>,
1287 expected: LinearNormalForm<F>,
1288 }
1289 let cases = vec![
1290 Case {
1291 expr: ArithCircuit::constant(F::ONE),
1292 expected: LinearNormalForm {
1293 constant: F::ONE,
1294 var_coeffs: vec![],
1295 },
1296 },
1297 Case {
1298 expr: (ArithCircuit::constant(F::new(2)) * ArithCircuit::constant(F::new(3)))
1299 .pow(2) + ArithCircuit::constant(F::new(3))
1300 * (ArithCircuit::constant(F::new(4)) + ArithCircuit::var(0)),
1301 expected: LinearNormalForm {
1302 constant: (F::new(2) * F::new(3)).pow(2) + F::new(3) * F::new(4),
1303 var_coeffs: vec![F::new(3)],
1304 },
1305 },
1306 Case {
1307 expr: ArithCircuit::constant(F::new(133))
1308 + ArithCircuit::constant(F::new(42)) * ArithCircuit::var(0)
1309 + ArithCircuit::var(2)
1310 + ArithCircuit::constant(F::new(11))
1311 * ArithCircuit::constant(F::new(37))
1312 * ArithCircuit::var(3),
1313 expected: LinearNormalForm {
1314 constant: F::new(133),
1315 var_coeffs: vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)],
1316 },
1317 },
1318 ];
1319 for Case { expr, expected } in cases {
1320 let normal_form = expr.linear_normal_form().unwrap();
1321 assert_eq!(normal_form.constant, expected.constant);
1322 assert_eq!(normal_form.var_coeffs, expected.var_coeffs);
1323 }
1324 }
1325
1326 fn unique_nodes_count<F: Field>(expr: &ArithCircuit<F>) -> usize {
1327 let mut unique_nodes = HashSet::new();
1328
1329 for step in 0..expr.steps.len() {
1330 unique_nodes.insert(StepNode {
1331 index: step,
1332 steps: &expr.steps,
1333 });
1334 }
1335
1336 unique_nodes.len()
1337 }
1338
1339 fn check_serialize_bytes_roundtrip<F: Field>(expr: ArithCircuit<F>) {
1340 let mut buf = Vec::new();
1341
1342 expr.serialize(&mut buf, SerializationMode::CanonicalTower)
1343 .unwrap();
1344 let deserialized =
1345 ArithCircuit::<F>::deserialize(&buf[..], SerializationMode::CanonicalTower).unwrap();
1346 assert_eq!(expr, deserialized);
1347 assert_eq!(unique_nodes_count(&expr), unique_nodes_count(&deserialized));
1348 }
1349
1350 #[test]
1351 fn test_serialize_bytes_roundtrip() {
1352 type F = BinaryField128b;
1353 let expr = ArithCircuit::var(0)
1354 * (ArithCircuit::var(1)
1355 * ArithCircuit::var(2)
1356 * ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR)
1357 + ArithCircuit::var(4))
1358 + ArithCircuit::var(5).pow(3)
1359 + ArithCircuit::constant(F::ONE);
1360
1361 check_serialize_bytes_roundtrip(expr);
1362 }
1363
1364 #[test]
1365 fn test_serialize_bytes_rountrip_with_duplicates() {
1366 type F = BinaryField128b;
1367 let expr = (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
1368 * (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
1369 + (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
1370 + ArithCircuit::var(1);
1371
1372 check_serialize_bytes_roundtrip(expr);
1373 }
1374
1375 #[test]
1376 fn test_binary_tower_level() {
1377 type F = BinaryField128b;
1378 let expr =
1379 ArithCircuit::constant(F::ONE) + ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR);
1380 assert_eq!(expr.binary_tower_level(), F::MULTIPLICATIVE_GENERATOR.min_tower_level());
1381 }
1382
1383 #[test]
1384 fn test_arith_circuit_steps() {
1385 type F = BinaryField8b;
1386 let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1)) * ArithCircuit::var(2);
1387 let steps = expr.steps();
1388 assert_eq!(steps.len(), 5); assert!(matches!(steps[0], ArithCircuitStep::Var(0)));
1390 assert!(matches!(steps[1], ArithCircuitStep::Var(1)));
1391 assert!(matches!(steps[2], ArithCircuitStep::Add(_, _)));
1392 assert!(matches!(steps[3], ArithCircuitStep::Var(2)));
1393 assert!(matches!(steps[4], ArithCircuitStep::Mul(_, _)));
1394 }
1395
1396 #[test]
1397 fn test_optimize_constants() {
1398 type F = BinaryField8b;
1399 let mut circuit = (ArithCircuit::<F>::var(0) + ArithCircuit::constant(F::ZERO))
1400 * ArithCircuit::var(1)
1401 + ArithCircuit::constant(F::ONE) * ArithCircuit::var(2)
1402 + ArithCircuit::constant(F::ONE).pow(4).pow(5)
1403 + (ArithCircuit::var(5) + ArithCircuit::var(5));
1404 circuit.optimize_constants();
1405
1406 let expected_ciruit = ArithCircuit::var(0) * ArithCircuit::var(1)
1407 + ArithCircuit::var(2)
1408 + ArithCircuit::constant(F::ONE);
1409
1410 assert_eq!(circuit, expected_ciruit);
1411 }
1412
1413 #[test]
1414 fn test_deduplicate_steps() {
1415 type F = BinaryField8b;
1416 let mut circuit = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1))
1417 * (ArithCircuit::var(0) + ArithCircuit::var(1))
1418 + (ArithCircuit::var(0) + ArithCircuit::var(1));
1419 circuit.deduplicate_steps();
1420
1421 let expected_circuit = ArithCircuit::<F> {
1422 steps: vec![
1423 ArithCircuitStep::Var(0),
1424 ArithCircuitStep::Var(1),
1425 ArithCircuitStep::Add(0, 1),
1426 ArithCircuitStep::Mul(2, 2),
1427 ArithCircuitStep::Add(3, 2),
1428 ],
1429 };
1430 assert_eq!(circuit, expected_circuit);
1431 }
1432
1433 #[test]
1434 fn test_compress_unused_steps() {
1435 type F = BinaryField8b;
1436 let mut circuit = ArithCircuit::<F> {
1437 steps: vec![
1438 ArithCircuitStep::Var(0),
1439 ArithCircuitStep::Var(1),
1440 ArithCircuitStep::Var(2),
1441 ArithCircuitStep::Add(0, 1),
1442 ArithCircuitStep::Var(3),
1443 ArithCircuitStep::Const(F::ZERO),
1444 ArithCircuitStep::Var(2),
1445 ArithCircuitStep::Mul(3, 3),
1446 ],
1447 };
1448 circuit.compress_unused_steps();
1449
1450 let expected_circuit = ArithCircuit::<F> {
1451 steps: vec![
1452 ArithCircuitStep::Var(0),
1453 ArithCircuitStep::Var(1),
1454 ArithCircuitStep::Add(0, 1),
1455 ArithCircuitStep::Mul(2, 2),
1456 ],
1457 };
1458 assert_eq!(circuit.steps, expected_circuit.steps);
1459 }
1460
1461 #[test]
1462 fn test_conversion_from_expr_node_doesnt_create_duplicated_steps() {
1463 type F = BinaryField8b;
1464 let sub_expr = Arc::new(ArithExpr::<F>::Var(0) + ArithExpr::<F>::Var(1));
1465 let expr = ArithExpr::Mul(sub_expr.clone(), sub_expr.clone()) + sub_expr;
1466 let circuit = ArithCircuit::<F>::from(&expr);
1467 assert_eq!(circuit.steps.len(), 5);
1468 assert_eq!(unique_nodes_count(&circuit), 5);
1469 }
1470
1471 fn unique_nodes_count_expr<F: Field>(expr: &ArithExpr<F>) -> usize {
1472 fn add_nodes<F: Field>(
1473 expr: &Arc<ArithExpr<F>>,
1474 unique_nodes: &mut HashSet<*const ArithExpr<F>>,
1475 ) {
1476 if !unique_nodes.insert(Arc::as_ptr(expr)) {
1477 return;
1478 }
1479
1480 match expr.as_ref() {
1481 ArithExpr::Const(_) | ArithExpr::Var(_) => {}
1482 ArithExpr::Add(left, right) | ArithExpr::Mul(left, right) => {
1483 add_nodes(left, unique_nodes);
1484 add_nodes(right, unique_nodes);
1485 }
1486 ArithExpr::Pow(base, _) => {
1487 add_nodes(base, unique_nodes);
1488 }
1489 }
1490 }
1491
1492 let mut unique_nodes = HashSet::new();
1493 add_nodes(&Arc::new(expr.clone()), &mut unique_nodes);
1494
1495 unique_nodes.len()
1496 }
1497
1498 #[test]
1499 fn test_conversion_from_circuit_to_expr_node() {
1500 type F = BinaryField8b;
1501
1502 let arith_circuit = ArithCircuit::<F> {
1503 steps: vec![
1504 ArithCircuitStep::Var(0),
1505 ArithCircuitStep::Var(1),
1506 ArithCircuitStep::Add(0, 1),
1507 ArithCircuitStep::Mul(2, 2),
1508 ArithCircuitStep::Add(3, 2),
1509 ],
1510 };
1511 let expr = ArithExpr::from(&arith_circuit);
1512 let expected_expr = (ArithExpr::Var(0) + ArithExpr::Var(1))
1513 * (ArithExpr::Var(0) + ArithExpr::Var(1))
1514 + (ArithExpr::Var(0) + ArithExpr::Var(1));
1515 assert_eq!(expr, expected_expr);
1516 assert_eq!(unique_nodes_count_expr(&expr), 5);
1517 }
1518
1519 #[test]
1520 fn test_evaluate() {
1521 type F = BinaryField8b;
1522 let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1))
1523 * (ArithCircuit::var(2) + ArithCircuit::var(3)).pow(5);
1524 let result = expr
1525 .evaluate(&[F::new(2), F::new(3), F::new(4), F::new(5)])
1526 .unwrap();
1527 assert_eq!(result, F::new(2) + F::new(3) * (F::new(4) + F::new(5)).pow(5));
1528 }
1529}