1use std::{
4 cmp::Ordering,
5 collections::{HashMap, hash_map::Entry},
6 fmt::{self, Display},
7 hash::{Hash, Hasher},
8 iter::{Product, Sum},
9 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
10 sync::Arc,
11};
12
13use binius_field::{Field, PackedField, TowerField};
14use binius_macros::{DeserializeBytes, SerializeBytes};
15
16use super::error::Error;
17
/// Arithmetic expression tree over field constants and indexed variables.
///
/// Child expressions are held in [`Arc`]s, so one node may be shared by several parents.
/// Converting to an [`ArithCircuit`] emits each shared `Arc` (identified by pointer) only once.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithExpr<F: Field> {
    /// A constant field element.
    Const(F),
    /// The variable with the given index (displayed as `x{index}`).
    Var(usize),
    /// Sum of two sub-expressions.
    Add(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
    /// Product of two sub-expressions.
    Mul(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
    /// A sub-expression raised to a `u64` power.
    Pow(Arc<ArithExpr<F>>, u64),
}
31
32impl<F: Field + Display> Display for ArithExpr<F> {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Self::Const(v) => write!(f, "{v}"),
36 Self::Var(i) => write!(f, "x{i}"),
37 Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
38 Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
39 Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
40 }
41 }
42}
43
44impl<F: Field> ArithExpr<F> {
45 pub fn pow(self, exp: u64) -> Self {
46 Self::Pow(Arc::new(self), exp)
47 }
48
49 pub const fn zero() -> Self {
50 Self::Const(F::ZERO)
51 }
52
53 pub const fn one() -> Self {
54 Self::Const(F::ONE)
55 }
56}
57
58impl<F> Default for ArithExpr<F>
59where
60 F: Field,
61{
62 fn default() -> Self {
63 Self::zero()
64 }
65}
66
67impl<F> Add for ArithExpr<F>
68where
69 F: Field,
70{
71 type Output = Self;
72
73 fn add(self, rhs: Self) -> Self {
74 Self::Add(Arc::new(self), Arc::new(rhs))
75 }
76}
77
78impl<F> Add<Arc<Self>> for ArithExpr<F>
79where
80 F: Field,
81{
82 type Output = Self;
83
84 fn add(self, rhs: Arc<Self>) -> Self {
85 Self::Add(Arc::new(self), rhs)
86 }
87}
88
89impl<F> AddAssign for ArithExpr<F>
90where
91 F: Field,
92{
93 fn add_assign(&mut self, rhs: Self) {
94 *self = std::mem::take(self) + rhs;
95 }
96}
97
98impl<F> AddAssign<Arc<Self>> for ArithExpr<F>
99where
100 F: Field,
101{
102 fn add_assign(&mut self, rhs: Arc<Self>) {
103 *self = std::mem::take(self) + rhs;
104 }
105}
106
107impl<F> Sub for ArithExpr<F>
108where
109 F: Field,
110{
111 type Output = Self;
112
113 fn sub(self, rhs: Self) -> Self {
114 Self::Add(Arc::new(self), Arc::new(rhs))
115 }
116}
117
118impl<F> Sub<Arc<Self>> for ArithExpr<F>
119where
120 F: Field,
121{
122 type Output = Self;
123
124 fn sub(self, rhs: Arc<Self>) -> Self {
125 Self::Add(Arc::new(self), rhs)
126 }
127}
128
129impl<F> SubAssign for ArithExpr<F>
130where
131 F: Field,
132{
133 fn sub_assign(&mut self, rhs: Self) {
134 *self = std::mem::take(self) - rhs;
135 }
136}
137
138impl<F> SubAssign<Arc<Self>> for ArithExpr<F>
139where
140 F: Field,
141{
142 fn sub_assign(&mut self, rhs: Arc<Self>) {
143 *self = std::mem::take(self) - rhs;
144 }
145}
146
147impl<F> Mul for ArithExpr<F>
148where
149 F: Field,
150{
151 type Output = Self;
152
153 fn mul(self, rhs: Self) -> Self {
154 Self::Mul(Arc::new(self), Arc::new(rhs))
155 }
156}
157
158impl<F> Mul<Arc<Self>> for ArithExpr<F>
159where
160 F: Field,
161{
162 type Output = Self;
163
164 fn mul(self, rhs: Arc<Self>) -> Self {
165 Self::Mul(Arc::new(self), rhs)
166 }
167}
168
169impl<F> MulAssign for ArithExpr<F>
170where
171 F: Field,
172{
173 fn mul_assign(&mut self, rhs: Self) {
174 *self = std::mem::take(self) * rhs;
175 }
176}
177
178impl<F> MulAssign<Arc<Self>> for ArithExpr<F>
179where
180 F: Field,
181{
182 fn mul_assign(&mut self, rhs: Arc<Self>) {
183 *self = std::mem::take(self) * rhs;
184 }
185}
186
187impl<F: Field> Sum for ArithExpr<F> {
188 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
189 iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
190 }
191}
192
193impl<F: Field> Product for ArithExpr<F> {
194 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
195 iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
196 }
197}
198
/// One step of an [`ArithCircuit`].
///
/// Operand `usize`s in `Add`, `Mul` and `Pow` are indices of earlier steps of the same circuit
/// (evaluation reads them from the already-computed step results). In `Var` the `usize` is a
/// variable index into the evaluation query instead.
#[derive(Clone, Copy, Debug, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
pub enum ArithCircuitStep<F: Field> {
    /// Sum of two earlier steps.
    Add(usize, usize),
    /// Product of two earlier steps.
    Mul(usize, usize),
    /// An earlier step raised to a `u64` power.
    Pow(usize, u64),
    /// A constant field element.
    Const(F),
    /// A variable, by index.
    Var(usize),
}
207
208impl<F: Field> Default for ArithCircuitStep<F> {
209 fn default() -> Self {
210 Self::Const(F::ZERO)
211 }
212}
213
/// Arithmetic circuit stored as a flat list of steps in evaluation order.
///
/// Each step may refer to the results of earlier steps; the last step is the circuit's output.
/// Equality (the manual [`PartialEq`] impl) is structural on the sub-circuit reachable from the
/// output, so unused steps and different step numberings do not affect it.
#[derive(Clone, Debug, SerializeBytes, DeserializeBytes, Eq)]
pub struct ArithCircuit<F: Field> {
    steps: Vec<ArithCircuitStep<F>>,
}
227
228impl<F: Field> ArithCircuit<F> {
229 pub fn var(index: usize) -> Self {
230 Self {
231 steps: vec![ArithCircuitStep::Var(index)],
232 }
233 }
234
235 pub fn constant(value: F) -> Self {
236 Self {
237 steps: vec![ArithCircuitStep::Const(value)],
238 }
239 }
240
241 pub fn zero() -> Self {
242 Self::constant(F::ZERO)
243 }
244
245 pub fn one() -> Self {
246 Self::constant(F::ONE)
247 }
248
249 pub fn pow(mut self, exp: u64) -> Self {
250 self.steps
251 .push(ArithCircuitStep::Pow(self.steps.len() - 1, exp));
252
253 self
254 }
255
256 pub fn steps(&self) -> &[ArithCircuitStep<F>] {
258 &self.steps
259 }
260
261 pub fn degree(&self) -> usize {
263 fn step_degree<F: Field>(step: usize, steps: &[ArithCircuitStep<F>]) -> usize {
264 match steps[step] {
265 ArithCircuitStep::Const(_) => 0,
266 ArithCircuitStep::Var(_) => 1,
267 ArithCircuitStep::Add(left, right) => {
268 step_degree(left, steps).max(step_degree(right, steps))
269 }
270 ArithCircuitStep::Mul(left, right) => {
271 step_degree(left, steps) + step_degree(right, steps)
272 }
273 ArithCircuitStep::Pow(base, exp) => step_degree(base, steps) * (exp as usize),
274 }
275 }
276
277 step_degree(self.steps.len() - 1, &self.steps)
278 }
279
280 pub fn n_vars(&self) -> usize {
282 self.steps
283 .iter()
284 .map(|step| {
285 if let ArithCircuitStep::Var(index) = step {
286 *index + 1
287 } else {
288 0
289 }
290 })
291 .max()
292 .unwrap_or(0)
293 }
294
295 pub fn binary_tower_level(&self) -> usize
297 where
298 F: TowerField,
299 {
300 self.steps
301 .iter()
302 .map(|step| {
303 if let ArithCircuitStep::Const(value) = step {
304 value.min_tower_level()
305 } else {
306 0
307 }
308 })
309 .max()
310 .unwrap_or(0)
311 }
312
313 pub fn eval_cost(&self) -> EvalCost {
315 self.steps
316 .iter()
317 .fold(EvalCost::default(), |acc, step| match *step {
318 ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => acc,
319 ArithCircuitStep::Add(_, _) => acc + EvalCost::add(1),
320 ArithCircuitStep::Mul(_, _) => acc + EvalCost::mul(1),
321 ArithCircuitStep::Pow(_, exp) => {
327 let n_squares = exp.ilog2() as usize;
328 let n_muls = exp.count_ones().saturating_sub(1) as usize;
329 acc + EvalCost::mul(n_muls) + EvalCost::square(n_squares)
330 }
331 })
332 }
333
334 pub fn leading_term(&self) -> Self {
337 let (_, expr) = self.leading_term_with_degree(self.steps.len() - 1);
338 expr
339 }
340
341 fn leading_term_with_degree(&self, step: usize) -> (usize, Self) {
343 match &self.steps[step] {
344 ArithCircuitStep::Const(value) => (0, Self::constant(*value)),
345 ArithCircuitStep::Var(index) => (1, Self::var(*index)),
346 ArithCircuitStep::Add(left, right) => {
347 let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
348 let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
349 match lhs_degree.cmp(&rhs_degree) {
350 Ordering::Less => (rhs_degree, rhs),
351 Ordering::Equal => (lhs_degree, lhs + rhs),
352 Ordering::Greater => (lhs_degree, lhs),
353 }
354 }
355 ArithCircuitStep::Mul(left, right) => {
356 let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
357 let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
358 (lhs_degree + rhs_degree, lhs * rhs)
359 }
360 ArithCircuitStep::Pow(base, exp) => {
361 let (base_degree, base) = self.leading_term_with_degree(*base);
362 (base_degree * (*exp as usize), base.pow(*exp))
363 }
364 }
365 }
366
367 pub fn evaluate(&self, query: &[F]) -> Result<F, Error> {
368 let mut step_evals = Vec::<F>::with_capacity(self.steps.len());
369 for step in &self.steps {
370 match step {
371 ArithCircuitStep::Add(left, right) => {
372 step_evals.push(step_evals[*left] + step_evals[*right])
373 }
374 ArithCircuitStep::Mul(left, right) => {
375 step_evals.push(step_evals[*left] * step_evals[*right])
376 }
377 ArithCircuitStep::Pow(base, exp) => step_evals.push(step_evals[*base].pow(*exp)),
378 ArithCircuitStep::Const(value) => step_evals.push(*value),
379 ArithCircuitStep::Var(index) => step_evals.push(query[*index]),
380 }
381 }
382 Ok(step_evals.pop().unwrap_or_default())
383 }
384
385 pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithCircuit<FTgt> {
386 ArithCircuit {
387 steps: self
388 .steps
389 .iter()
390 .map(|step| match step {
391 ArithCircuitStep::Const(value) => ArithCircuitStep::Const((*value).into()),
392 ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
393 ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
394 ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
395 ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
396 })
397 .collect(),
398 }
399 }
400
401 pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
402 &self,
403 ) -> Result<ArithCircuit<FTgt>, <FTgt as TryFrom<F>>::Error> {
404 let steps = self
405 .steps
406 .iter()
407 .map(|step| -> Result<ArithCircuitStep<FTgt>, <FTgt as TryFrom<F>>::Error> {
408 let result = match step {
409 ArithCircuitStep::Const(value) => {
410 ArithCircuitStep::Const(FTgt::try_from(*value)?)
411 }
412 ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
413 ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
414 ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
415 ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
416 };
417 Ok(result)
418 })
419 .collect::<Result<Vec<_>, _>>()?;
420
421 Ok(ArithCircuit { steps })
422 }
423
424 pub fn remap_vars(&self, indices: &[usize]) -> Result<Self, Error> {
434 let steps = self
435 .steps
436 .iter()
437 .map(|step| -> Result<ArithCircuitStep<F>, Error> {
438 if let ArithCircuitStep::Var(index) = step {
439 let new_index = indices.get(*index).copied().ok_or_else(|| {
440 Error::IncorrectArgumentLength {
441 arg: "indices".to_string(),
442 expected: *index,
443 }
444 })?;
445 Ok(ArithCircuitStep::Var(new_index))
446 } else {
447 Ok(*step)
448 }
449 })
450 .collect::<Result<Vec<_>, _>>()?;
451 Ok(Self { steps })
452 }
453
454 pub fn const_subst(self, var: usize, value: F) -> Self {
456 let steps = self
457 .steps
458 .iter()
459 .map(|step| match step {
460 ArithCircuitStep::Var(index) if *index == var => ArithCircuitStep::Const(value),
461 _ => *step,
462 })
463 .collect();
464 Self { steps }
465 }
466
467 pub fn get_constant(&self) -> Option<F> {
469 if let ArithCircuitStep::Const(value) =
470 self.steps.last().expect("steps should not be empty")
471 {
472 Some(*value)
473 } else {
474 None
475 }
476 }
477
478 pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
484 self.sparse_linear_normal_form().map(Into::into)
485 }
486
487 fn sparse_linear_normal_form(&self) -> Result<SparseLinearNormalForm<F>, Error> {
488 fn sparse_linear_normal_form<F: Field>(
489 step: usize,
490 steps: &[ArithCircuitStep<F>],
491 ) -> Result<SparseLinearNormalForm<F>, Error> {
492 match &steps[step] {
493 ArithCircuitStep::Const(val) => Ok((*val).into()),
494 ArithCircuitStep::Var(index) => Ok(SparseLinearNormalForm {
495 constant: F::ZERO,
496 dense_linear_form_len: index + 1,
497 var_coeffs: [(*index, F::ONE)].into(),
498 }),
499 ArithCircuitStep::Add(left, right) => {
500 let left = sparse_linear_normal_form(*left, steps)?;
501 let right = sparse_linear_normal_form(*right, steps)?;
502 Ok(left + right)
503 }
504 ArithCircuitStep::Mul(left, right) => {
505 let left = sparse_linear_normal_form(*left, steps)?;
506 let right = sparse_linear_normal_form(*right, steps)?;
507 left * right
508 }
509 ArithCircuitStep::Pow(_, 0) => Ok(F::ONE.into()),
510 ArithCircuitStep::Pow(expr, 1) => sparse_linear_normal_form(*expr, steps),
511 ArithCircuitStep::Pow(expr, pow) => {
512 let linear_form = sparse_linear_normal_form(*expr, steps)?;
513 if linear_form.dense_linear_form_len != 0 {
514 return Err(Error::NonLinearExpression);
515 }
516 Ok(linear_form.constant.pow(*pow).into())
517 }
518 }
519 }
520
521 sparse_linear_normal_form(self.steps.len() - 1, &self.steps)
522 }
523
524 pub fn vars_usage(&self) -> Vec<bool> {
529 let mut usage = vec![false; self.n_vars()];
530
531 for step in &self.steps {
532 if let ArithCircuitStep::Var(index) = step {
533 usage[*index] = true;
534 }
535 }
536
537 usage
538 }
539
540 pub fn optimize_constants_in_place(&mut self) {
542 for step_index in 0..self.steps.len() {
543 let (prev_steps, curr_steps) = self.steps.split_at_mut(step_index);
544 let curr_step = &mut curr_steps[0];
545 match curr_step {
546 ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => {}
547 ArithCircuitStep::Add(left, right) => {
548 match (&prev_steps[*left], &prev_steps[*right]) {
549 (ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
550 *curr_step = ArithCircuitStep::Const(*left + *right);
551 }
552 (ArithCircuitStep::Const(left), right) if *left == F::ZERO => {
553 *curr_step = *right;
554 }
555 (left, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
556 *curr_step = *left;
557 }
558 (left, right) if left == right && F::CHARACTERISTIC == 2 => {
559 *curr_step = ArithCircuitStep::Const(F::ZERO);
560 }
561 _ => {}
562 }
563 }
564 ArithCircuitStep::Mul(left, right) => {
565 match (&prev_steps[*left], &prev_steps[*right]) {
566 (ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
567 *curr_step = ArithCircuitStep::Const(*left * *right);
568 }
569 (ArithCircuitStep::Const(left), _) if *left == F::ZERO => {
570 *curr_step = ArithCircuitStep::Const(F::ZERO);
571 }
572 (_, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
573 *curr_step = ArithCircuitStep::Const(F::ZERO);
574 }
575 (ArithCircuitStep::Const(left), right) if *left == F::ONE => {
576 *curr_step = *right;
577 }
578 (left, ArithCircuitStep::Const(right)) if *right == F::ONE => {
579 *curr_step = *left;
580 }
581 _ => {}
582 }
583 }
584 ArithCircuitStep::Pow(base, exp) => match prev_steps[*base] {
585 ArithCircuitStep::Const(value) => {
586 *curr_step = ArithCircuitStep::Const(PackedField::pow(value, *exp));
587 }
588 ArithCircuitStep::Pow(base_inner, exp_inner) => {
589 *curr_step = ArithCircuitStep::Pow(base_inner, *exp * exp_inner);
590 }
591 _ => {}
592 },
593 }
594 }
595 }
596
597 pub fn optimize_constants(mut self) -> Self {
599 self.optimize_constants_in_place();
600 self
601 }
602
603 fn deduplicate_steps(&mut self) {
605 let mut step_map = HashMap::new();
606 let mut step_indices = Vec::with_capacity(self.steps.len());
607 for step in 0..self.steps.len() {
608 let node = StepNode {
609 index: step,
610 steps: &self.steps,
611 };
612 match step_map.entry(node) {
613 Entry::Occupied(entry) => {
614 step_indices.push(*entry.get());
615 }
616 Entry::Vacant(entry) => {
617 entry.insert(step);
618 step_indices.push(step);
619 }
620 }
621 }
622
623 for step in &mut self.steps {
624 match step {
625 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
626 *left = step_indices[*left];
627 *right = step_indices[*right];
628 }
629 ArithCircuitStep::Pow(base, _) => *base = step_indices[*base],
630 _ => (),
631 }
632 }
633 }
634
635 fn compress_unused_steps(&mut self) {
637 fn mark_used<F: Field>(step: usize, steps: &[ArithCircuitStep<F>], used: &mut [bool]) {
638 if used[step] {
639 return;
640 }
641 used[step] = true;
642 match steps[step] {
643 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
644 mark_used(left, steps, used);
645 mark_used(right, steps, used);
646 }
647 ArithCircuitStep::Pow(base, _) => mark_used(base, steps, used),
648 _ => (),
649 }
650 }
651
652 let mut used = vec![false; self.steps.len()];
653 mark_used(self.steps.len() - 1, &self.steps, &mut used);
654
655 let mut steps_map = (0..self.steps.len()).collect::<Vec<_>>();
656 let mut target_index = 0;
657 for source_index in 0..self.steps.len() {
658 if used[source_index] {
659 if target_index != source_index {
660 match &mut self.steps[source_index] {
661 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
662 *left = steps_map[*left];
663 *right = steps_map[*right];
664 }
665 ArithCircuitStep::Pow(base, _) => *base = steps_map[*base],
666 _ => (),
667 }
668
669 steps_map[source_index] = target_index;
670 self.steps[target_index] = self.steps[source_index];
671 }
672
673 target_index += 1;
674 }
675 }
676
677 self.steps.truncate(target_index);
678 }
679
680 pub fn optimize_in_place(&mut self) {
685 self.optimize_constants_in_place();
686 self.deduplicate_steps();
687 self.compress_unused_steps();
688 }
689
690 pub fn optimize(mut self) -> Self {
692 self.optimize_constants_in_place();
693 self.deduplicate_steps();
694 self.compress_unused_steps();
695
696 self
697 }
698}
699
700impl<F: Field> From<&ArithExpr<F>> for ArithCircuit<F> {
701 fn from(expr: &ArithExpr<F>) -> Self {
702 fn visit_node<F: Field>(
703 node: &Arc<ArithExpr<F>>,
704 node_to_index: &mut HashMap<*const ArithExpr<F>, usize>,
705 steps: &mut Vec<ArithCircuitStep<F>>,
706 ) -> usize {
707 if let Some(index) = node_to_index.get(&Arc::as_ptr(node)) {
708 return *index;
709 }
710
711 let step = match &**node {
712 ArithExpr::Const(value) => ArithCircuitStep::Const(*value),
713 ArithExpr::Var(index) => ArithCircuitStep::Var(*index),
714 ArithExpr::Add(left, right) => {
715 let left = visit_node(left, node_to_index, steps);
716 let right = visit_node(right, node_to_index, steps);
717 ArithCircuitStep::Add(left, right)
718 }
719 ArithExpr::Mul(left, right) => {
720 let left = visit_node(left, node_to_index, steps);
721 let right = visit_node(right, node_to_index, steps);
722 ArithCircuitStep::Mul(left, right)
723 }
724 ArithExpr::Pow(base, exp) => {
725 let base = visit_node(base, node_to_index, steps);
726 ArithCircuitStep::Pow(base, *exp)
727 }
728 };
729
730 steps.push(step);
731 node_to_index.insert(Arc::as_ptr(node), steps.len() - 1);
732 steps.len() - 1
733 }
734
735 let mut steps = Vec::new();
736 let mut node_to_index = HashMap::new();
737 match expr {
738 ArithExpr::Const(c) => {
739 steps.push(ArithCircuitStep::Const(*c));
740 }
741 ArithExpr::Var(var) => {
742 steps.push(ArithCircuitStep::Var(*var));
743 }
744 ArithExpr::Add(left, right) => {
745 let left = visit_node(left, &mut node_to_index, &mut steps);
746 let right = visit_node(right, &mut node_to_index, &mut steps);
747 steps.push(ArithCircuitStep::Add(left, right));
748 }
749 ArithExpr::Mul(left, right) => {
750 let left = visit_node(left, &mut node_to_index, &mut steps);
751 let right = visit_node(right, &mut node_to_index, &mut steps);
752 steps.push(ArithCircuitStep::Mul(left, right));
753 }
754 ArithExpr::Pow(base, exp) => {
755 let base = visit_node(base, &mut node_to_index, &mut steps);
756 steps.push(ArithCircuitStep::Pow(base, *exp));
757 }
758 }
759
760 Self { steps }
761 }
762}
763
764impl<F: Field> From<ArithExpr<F>> for ArithCircuit<F> {
765 fn from(expr: ArithExpr<F>) -> Self {
766 Self::from(&expr)
767 }
768}
769
770impl<F: Field> From<&ArithCircuit<F>> for ArithExpr<F> {
771 fn from(circuit: &ArithCircuit<F>) -> Self {
772 let mut step_to_node = vec![Option::<Arc<Self>>::None; circuit.steps.len()];
773
774 for step in 0..circuit.steps.len() {
775 let node = match &circuit.steps[step] {
776 ArithCircuitStep::Const(value) => Self::Const(*value),
777 ArithCircuitStep::Var(index) => Self::Var(*index),
778 ArithCircuitStep::Add(left, right) => {
779 let left = step_to_node[*left].clone().expect("step must be present");
780 let right = step_to_node[*right].clone().expect("step must be present");
781 Self::Add(left, right)
782 }
783 ArithCircuitStep::Mul(left, right) => {
784 let left = step_to_node[*left].clone().expect("step must be present");
785 let right = step_to_node[*right].clone().expect("step must be present");
786 Self::Mul(left, right)
787 }
788 ArithCircuitStep::Pow(base, exp) => {
789 let base = step_to_node[*base].clone().expect("step must be present");
790 Self::Pow(base, *exp)
791 }
792 };
793 step_to_node[step] = Some(Arc::new(node));
794 }
795
796 Arc::into_inner(
797 step_to_node
798 .pop()
799 .expect("steps should not be empty")
800 .expect("last step should be initialized"),
801 )
802 .expect("last step must have a single instance")
803 }
804}
805
806impl<F: Field> From<ArithCircuit<F>> for ArithExpr<F> {
807 fn from(circuit: ArithCircuit<F>) -> Self {
808 Self::from(&circuit)
809 }
810}
811
812impl<F: Field> Display for ArithCircuit<F> {
813 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
814 fn display_step<F: Field>(
815 step: usize,
816 steps: &[ArithCircuitStep<F>],
817 f: &mut fmt::Formatter<'_>,
818 ) -> Result<(), fmt::Error> {
819 match &steps[step] {
820 ArithCircuitStep::Const(value) => write!(f, "{value}"),
821 ArithCircuitStep::Var(index) => write!(f, "x{index}"),
822 ArithCircuitStep::Add(left, right) => {
823 write!(f, "(")?;
824 display_step(*left, steps, f)?;
825 write!(f, " + ")?;
826 display_step(*right, steps, f)?;
827 write!(f, ")")
828 }
829 ArithCircuitStep::Mul(left, right) => {
830 write!(f, "(")?;
831 display_step(*left, steps, f)?;
832 write!(f, " * ")?;
833 display_step(*right, steps, f)?;
834 write!(f, ")")
835 }
836 ArithCircuitStep::Pow(base, exp) => {
837 write!(f, "(")?;
838 display_step(*base, steps, f)?;
839 write!(f, ")^{exp}")
840 }
841 }
842 }
843
844 display_step(self.steps.len() - 1, &self.steps, f)
845 }
846}
847
848impl<F: Field> PartialEq for ArithCircuit<F> {
849 fn eq(&self, other: &Self) -> bool {
850 StepNode {
851 index: self.steps.len() - 1,
852 steps: &self.steps,
853 } == StepNode {
854 index: other.steps.len() - 1,
855 steps: &other.steps,
856 }
857 }
858}
859
860impl<F: Field> Add for ArithCircuit<F> {
861 type Output = Self;
862
863 fn add(mut self, rhs: Self) -> Self {
864 self += rhs;
865 self
866 }
867}
868
869impl<F: Field> AddAssign for ArithCircuit<F> {
870 fn add_assign(&mut self, mut rhs: Self) {
871 let old_len = self.steps.len();
872 add_offset(&mut rhs.steps, old_len);
873 self.steps.extend(rhs.steps);
874 self.steps
875 .push(ArithCircuitStep::Add(old_len - 1, self.steps.len() - 1));
876 }
877}
878
879impl<F: Field> Sub for ArithCircuit<F> {
880 type Output = Self;
881
882 fn sub(mut self, rhs: Self) -> Self {
883 self -= rhs;
884 self
885 }
886}
887
888impl<F: Field> SubAssign for ArithCircuit<F> {
889 #[allow(clippy::suspicious_op_assign_impl)]
890 fn sub_assign(&mut self, rhs: Self) {
891 *self += rhs;
892 }
893}
894
895impl<F: Field> Mul for ArithCircuit<F> {
896 type Output = Self;
897
898 fn mul(mut self, rhs: Self) -> Self {
899 self *= rhs;
900 self
901 }
902}
903
904impl<F: Field> MulAssign for ArithCircuit<F> {
905 fn mul_assign(&mut self, mut rhs: Self) {
906 let old_len = self.steps.len();
907 add_offset(&mut rhs.steps, old_len);
908 self.steps.extend(rhs.steps);
909 self.steps
910 .push(ArithCircuitStep::Mul(old_len - 1, self.steps.len() - 1));
911 }
912}
913
914impl<F: Field> Sum for ArithCircuit<F> {
915 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
916 iter.fold(Self::zero(), |sum, item| sum + item)
917 }
918}
919
920impl<F: Field> Product for ArithCircuit<F> {
921 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
922 iter.fold(Self::one(), |product, item| product * item)
923 }
924}
925
926fn add_offset<F: Field>(steps: &mut [ArithCircuitStep<F>], offset: usize) {
927 for step in steps.iter_mut() {
928 match step {
929 ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
930 *left += offset;
931 *right += offset;
932 }
933 ArithCircuitStep::Pow(base, _) => {
934 *base += offset;
935 }
936 _ => (),
937 }
938 }
939}
/// An affine form `constant + sum_i var_coeffs[i] * x_i`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LinearNormalForm<F: Field> {
    /// The constant term.
    pub constant: F,
    /// Dense coefficients: `var_coeffs[i]` multiplies variable `x_i`; unused variables get zero.
    pub var_coeffs: Vec<F>,
}
948
/// Sparse intermediate form used while computing a [`LinearNormalForm`].
struct SparseLinearNormalForm<F: Field> {
    /// The constant term.
    pub constant: F,
    /// Coefficients keyed by variable index; absent keys mean a zero coefficient.
    pub var_coeffs: HashMap<usize, F>,
    /// Length of the dense coefficient vector this converts to: one more than the largest
    /// variable index seen so far, or 0 if no variable has appeared.
    pub dense_linear_form_len: usize,
}
958
959impl<F: Field> From<F> for SparseLinearNormalForm<F> {
960 fn from(value: F) -> Self {
961 Self {
962 constant: value,
963 dense_linear_form_len: 0,
964 var_coeffs: HashMap::new(),
965 }
966 }
967}
968
969impl<F: Field> Add for SparseLinearNormalForm<F> {
970 type Output = Self;
971 fn add(self, rhs: Self) -> Self::Output {
972 let (mut result, consumable) = if self.var_coeffs.len() < rhs.var_coeffs.len() {
973 (rhs, self)
974 } else {
975 (self, rhs)
976 };
977 result.constant += consumable.constant;
978 if consumable.dense_linear_form_len > result.dense_linear_form_len {
979 result.dense_linear_form_len = consumable.dense_linear_form_len;
980 }
981
982 for (index, coeff) in consumable.var_coeffs {
983 result
984 .var_coeffs
985 .entry(index)
986 .and_modify(|res_coeff| {
987 *res_coeff += coeff;
988 })
989 .or_insert(coeff);
990 }
991 result
992 }
993}
994
995impl<F: Field> Mul for SparseLinearNormalForm<F> {
996 type Output = Result<Self, Error>;
997 fn mul(self, rhs: Self) -> Result<Self, Error> {
998 if !self.var_coeffs.is_empty() && !rhs.var_coeffs.is_empty() {
999 return Err(Error::NonLinearExpression);
1000 }
1001 let (mut result, consumable) = if self.var_coeffs.is_empty() {
1002 (rhs, self)
1003 } else {
1004 (self, rhs)
1005 };
1006 result.constant *= consumable.constant;
1007 for coeff in result.var_coeffs.values_mut() {
1008 *coeff *= consumable.constant;
1009 }
1010 Ok(result)
1011 }
1012}
1013
1014impl<F: Field> From<SparseLinearNormalForm<F>> for LinearNormalForm<F> {
1015 fn from(value: SparseLinearNormalForm<F>) -> Self {
1016 let mut var_coeffs = vec![F::ZERO; value.dense_linear_form_len];
1017 for (i, coeff) in value.var_coeffs {
1018 var_coeffs[i] = coeff;
1019 }
1020 Self {
1021 constant: value.constant,
1022 var_coeffs,
1023 }
1024 }
1025}
1026
/// A view of one step of a circuit that is compared and hashed structurally.
///
/// Equality and hashing recurse into the referenced steps rather than comparing their
/// indices. Used as a hash-map key when deduplicating steps and for circuit equality.
#[derive(Eq)]
struct StepNode<'a, F: Field> {
    /// Index of the viewed step within `steps`.
    index: usize,
    /// The circuit's full step list.
    steps: &'a [ArithCircuitStep<F>],
}
1034
1035impl<F: Field> StepNode<'_, F> {
1036 fn prev_step(&self, step: usize) -> Self {
1037 StepNode {
1038 index: step,
1039 steps: self.steps,
1040 }
1041 }
1042}
1043
1044impl<F: Field> PartialEq for StepNode<'_, F> {
1045 #[allow(clippy::suspicious_operation_groupings)] fn eq(&self, other: &Self) -> bool {
1047 match (&self.steps[self.index], &other.steps[other.index]) {
1048 (ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => left == right,
1049 (ArithCircuitStep::Var(left), ArithCircuitStep::Var(right)) => left == right,
1050 (
1051 ArithCircuitStep::Add(left, right),
1052 ArithCircuitStep::Add(other_left, other_right),
1053 )
1054 | (
1055 ArithCircuitStep::Mul(left, right),
1056 ArithCircuitStep::Mul(other_left, other_right),
1057 ) => {
1058 self.prev_step(*left) == other.prev_step(*other_left)
1059 && self.prev_step(*right) == other.prev_step(*other_right)
1060 }
1061 (ArithCircuitStep::Pow(base, exp), ArithCircuitStep::Pow(other_base, other_exp)) => {
1062 self.prev_step(*base) == other.prev_step(*other_base) && exp == other_exp
1063 }
1064 _ => false,
1065 }
1066 }
1067}
1068
1069impl<F: Field> Hash for StepNode<'_, F> {
1070 fn hash<H: Hasher>(&self, state: &mut H) {
1071 match self.steps[self.index] {
1072 ArithCircuitStep::Const(value) => {
1073 0u8.hash(state);
1074 value.hash(state);
1075 }
1076 ArithCircuitStep::Var(index) => {
1077 1u8.hash(state);
1078 index.hash(state);
1079 }
1080 ArithCircuitStep::Add(left, right) => {
1081 2u8.hash(state);
1082 self.prev_step(left).hash(state);
1083 self.prev_step(right).hash(state);
1084 }
1085 ArithCircuitStep::Mul(left, right) => {
1086 3u8.hash(state);
1087 self.prev_step(left).hash(state);
1088 self.prev_step(right).hash(state);
1089 }
1090 ArithCircuitStep::Pow(base, exp) => {
1091 4u8.hash(state);
1092 self.prev_step(base).hash(state);
1093 exp.hash(state);
1094 }
1095 }
1096 }
1097}
1098
/// Operation counts for evaluating a circuit once.
///
/// Derives `Debug`, `Clone`, `Copy`, `PartialEq` and `Eq` in addition to `Default`, so costs
/// can be printed, compared and passed around by value (all fields are `usize`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct EvalCost {
    /// Number of field additions.
    pub n_adds: usize,
    /// Number of general field multiplications.
    pub n_muls: usize,
    /// Number of field squarings.
    pub n_squares: usize,
}
1116
1117impl EvalCost {
1118 pub fn add(n_adds: usize) -> Self {
1119 Self {
1120 n_adds,
1121 ..Self::default()
1122 }
1123 }
1124
1125 pub fn mul(n_muls: usize) -> Self {
1126 Self {
1127 n_muls,
1128 ..Self::default()
1129 }
1130 }
1131
1132 pub fn square(n_squares: usize) -> Self {
1133 Self {
1134 n_squares,
1135 ..Self::default()
1136 }
1137 }
1138
1139 pub fn mult_cost_approx(&self) -> usize {
1145 self.n_muls + self.n_squares.div_ceil(5)
1146 }
1147}
1148
1149impl Add for EvalCost {
1150 type Output = Self;
1151 fn add(self, other: Self) -> Self {
1152 Self {
1153 n_adds: self.n_adds + other.n_adds,
1154 n_muls: self.n_muls + other.n_muls,
1155 n_squares: self.n_squares + other.n_squares,
1156 }
1157 }
1158}
1159
1160#[cfg(test)]
1161mod tests {
1162 use std::collections::HashSet;
1163
1164 use assert_matches::assert_matches;
1165 use binius_field::{BinaryField, BinaryField1b, BinaryField8b, BinaryField128b};
1166 use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes};
1167
1168 use super::*;
1169
1170 #[test]
1171 fn test_degree_with_pow() {
1172 let expr = ArithCircuit::constant(BinaryField8b::new(6)).pow(7);
1173 assert_eq!(expr.degree(), 0);
1174
1175 let expr: ArithCircuit<BinaryField8b> = ArithCircuit::var(0).pow(7);
1176 assert_eq!(expr.degree(), 7);
1177
1178 let expr: ArithCircuit<BinaryField8b> =
1179 (ArithCircuit::var(0) * ArithCircuit::var(1)).pow(7);
1180 assert_eq!(expr.degree(), 14);
1181 }
1182
1183 #[test]
1184 fn test_n_vars() {
1185 type F = BinaryField8b;
1186 let expr = ArithCircuit::<F>::var(0) * ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR)
1187 + ArithCircuit::var(2).pow(2);
1188 assert_eq!(expr.n_vars(), 3);
1189 }
1190
1191 #[test]
1192 fn test_leading_term_with_degree() {
1193 let expr = ArithCircuit::var(0)
1194 * (ArithCircuit::var(1)
1195 * ArithCircuit::var(2)
1196 * ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR)
1197 + ArithCircuit::var(4))
1198 + ArithCircuit::var(5).pow(3)
1199 + ArithCircuit::constant(BinaryField8b::ONE);
1200
1201 let expected_expr = ArithCircuit::var(0)
1202 * ((ArithCircuit::var(1) * ArithCircuit::var(2))
1203 * ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR))
1204 + ArithCircuit::var(5).pow(3);
1205
1206 assert_eq!(expr.leading_term_with_degree(expr.steps().len() - 1), (3, expected_expr));
1207 }
1208
1209 #[test]
1210 fn test_remap_vars_with_too_few_vars() {
1211 type F = BinaryField8b;
1212 let expr =
1213 ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
1214 assert_matches!(expr.remap_vars(&[5]), Err(Error::IncorrectArgumentLength { .. }));
1215 }
1216
1217 #[test]
1218 fn test_remap_vars_works() {
1219 type F = BinaryField8b;
1220 let expr =
1221 ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
1222 let new_expr = expr.remap_vars(&[5, 3]);
1223
1224 let expected =
1225 ((ArithCircuit::var(5) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(3)).pow(3);
1226 assert_eq!(new_expr.unwrap(), expected);
1227 }
1228
1229 #[test]
1230 fn test_optimize_identity_handling() {
1231 type F = BinaryField8b;
1232 let zero = ArithCircuit::<F>::zero();
1233 let one = ArithCircuit::<F>::one();
1234
1235 assert_eq!((zero.clone() * ArithCircuit::<F>::var(0)).optimize(), zero);
1236 assert_eq!((ArithCircuit::<F>::var(0) * zero.clone()).optimize(), zero);
1237
1238 assert_eq!((ArithCircuit::<F>::var(0) * one.clone()).optimize(), ArithCircuit::var(0));
1239 assert_eq!((one * ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));
1240
1241 assert_eq!((ArithCircuit::<F>::var(0) + zero.clone()).optimize(), ArithCircuit::var(0));
1242 assert_eq!((zero.clone() + ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));
1243
1244 assert_eq!((ArithCircuit::<F>::var(0) + ArithCircuit::var(0)).optimize(), zero);
1245 }
1246
1247 #[test]
1248 fn test_const_subst_and_optimize() {
1249 type F = BinaryField8b;
1251 let expr = ArithCircuit::var(0) * ArithCircuit::var(1) + ArithCircuit::one()
1252 - ArithCircuit::var(1);
1253 assert_eq!(expr.const_subst(1, F::ZERO).optimize().get_constant(), Some(F::ONE));
1254 }
1255
#[test]
fn test_expression_upcast() {
    type F8 = BinaryField8b;
    type F = BinaryField128b;

    // The same expression shape is built over both fields; lifting the 8-bit one must
    // reproduce the 128-bit one exactly.
    let narrow = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
        * ArithCircuit::constant(F8::new(222)))
    .pow(3);
    let wide = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
        * ArithCircuit::constant(F::new(222)))
    .pow(3);

    let lifted = narrow.convert_field::<F>();
    assert_eq!(lifted, wide);
}
1270
#[test]
fn test_expression_downcast() {
    type F8 = BinaryField8b;
    type F = BinaryField128b;

    let expr = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
        * ArithCircuit::constant(F::new(222)))
    .pow(3);

    // The constant 222 has no representation in the 1-bit field, so that conversion fails.
    let to_bit = expr.try_convert_field::<BinaryField1b>();
    assert!(to_bit.is_err());

    // Every constant fits into the 8-bit field, so this conversion succeeds.
    let expected = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
        * ArithCircuit::constant(F8::new(222)))
    .pow(3);
    let to_byte = expr.try_convert_field::<F8>().unwrap();
    assert_eq!(to_byte, expected);
}
1287
#[test]
fn test_linear_normal_form() {
    type F = BinaryField128b;

    // Each entry: (expression, expected constant term, expected per-variable coefficients).
    let cases: Vec<(ArithCircuit<F>, F, Vec<F>)> = vec![
        // A lone constant has no variable terms at all.
        (ArithCircuit::constant(F::ONE), F::ONE, vec![]),
        // Constant-only subexpressions fold into the constant term; x0 is scaled by 3.
        (
            (ArithCircuit::constant(F::new(2)) * ArithCircuit::constant(F::new(3))).pow(2)
                + ArithCircuit::constant(F::new(3))
                    * (ArithCircuit::constant(F::new(4)) + ArithCircuit::var(0)),
            (F::new(2) * F::new(3)).pow(2) + F::new(3) * F::new(4),
            vec![F::new(3)],
        ),
        // The unused x1 gets a zero coefficient; constant factors of x3 multiply out.
        (
            ArithCircuit::constant(F::new(133))
                + ArithCircuit::constant(F::new(42)) * ArithCircuit::var(0)
                + ArithCircuit::var(2)
                + ArithCircuit::constant(F::new(11))
                    * ArithCircuit::constant(F::new(37))
                    * ArithCircuit::var(3),
            F::new(133),
            vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)],
        ),
    ];

    for (expr, expected_constant, expected_coeffs) in cases {
        let normal_form = expr.linear_normal_form().unwrap();
        assert_eq!(normal_form.constant, expected_constant);
        assert_eq!(normal_form.var_coeffs, expected_coeffs);
    }
}
1331
1332 fn unique_nodes_count<F: Field>(expr: &ArithCircuit<F>) -> usize {
1333 let mut unique_nodes = HashSet::new();
1334
1335 for step in 0..expr.steps.len() {
1336 unique_nodes.insert(StepNode {
1337 index: step,
1338 steps: &expr.steps,
1339 });
1340 }
1341
1342 unique_nodes.len()
1343 }
1344
1345 fn check_serialize_bytes_roundtrip<F: Field>(expr: ArithCircuit<F>) {
1346 let mut buf = Vec::new();
1347
1348 expr.serialize(&mut buf, SerializationMode::CanonicalTower)
1349 .unwrap();
1350 let deserialized =
1351 ArithCircuit::<F>::deserialize(&buf[..], SerializationMode::CanonicalTower).unwrap();
1352 assert_eq!(expr, deserialized);
1353 assert_eq!(unique_nodes_count(&expr), unique_nodes_count(&deserialized));
1354 }
1355
#[test]
fn test_serialize_bytes_roundtrip() {
    type F = BinaryField128b;
    // x1 * x2 * g, where g is the multiplicative generator.
    let scaled_product = ArithCircuit::var(1)
        * ArithCircuit::var(2)
        * ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR);
    let expr = ArithCircuit::var(0) * (scaled_product + ArithCircuit::var(4))
        + ArithCircuit::var(5).pow(3)
        + ArithCircuit::constant(F::ONE);

    check_serialize_bytes_roundtrip(expr);
}
1369
#[test]
fn test_serialize_bytes_roundtrip_with_duplicates() {
    type F = BinaryField128b;
    // The subexpression (x0 + 1) occurs three times in the circuit below.
    let shifted = || ArithCircuit::<F>::var(0) + ArithCircuit::constant(F::ONE);
    let expr = shifted() * shifted() + shifted() + ArithCircuit::var(1);

    check_serialize_bytes_roundtrip(expr);
}
1380
#[test]
fn test_binary_tower_level() {
    type F = BinaryField128b;
    // The circuit's level is expected to match the generator's minimal tower level.
    let expected_level = F::MULTIPLICATIVE_GENERATOR.min_tower_level();
    let expr =
        ArithCircuit::constant(F::ONE) + ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR);
    assert_eq!(expr.binary_tower_level(), expected_level);
}
1388
#[test]
fn test_arith_circuit_steps() {
    type F = BinaryField8b;
    let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1)) * ArithCircuit::var(2);
    let steps = expr.steps();

    // Operands precede the step that consumes them, and operator steps refer to their
    // operands by step index (left operand first).
    assert_eq!(steps.len(), 5);
    assert_matches!(steps[0], ArithCircuitStep::Var(0));
    assert_matches!(steps[1], ArithCircuitStep::Var(1));
    assert_matches!(steps[2], ArithCircuitStep::Add(0, 1));
    assert_matches!(steps[3], ArithCircuitStep::Var(2));
    assert_matches!(steps[4], ArithCircuitStep::Mul(2, 3));
}
1401
#[test]
fn test_optimize_constants() {
    type F = BinaryField8b;
    // Each term exercises one constant-folding rule:
    // x0 + 0 -> x0, 1 * x2 -> x2, (1^4)^5 -> 1, x5 + x5 -> 0 (characteristic 2).
    let x0_plus_zero = ArithCircuit::<F>::var(0) + ArithCircuit::constant(F::ZERO);
    let one_times_x2 = ArithCircuit::constant(F::ONE) * ArithCircuit::var(2);
    let nested_pow = ArithCircuit::constant(F::ONE).pow(4).pow(5);
    let doubled_x5 = ArithCircuit::<F>::var(5) + ArithCircuit::var(5);

    let mut circuit =
        x0_plus_zero * ArithCircuit::var(1) + one_times_x2 + nested_pow + doubled_x5;
    circuit.optimize_constants_in_place();

    let expected_circuit = ArithCircuit::var(0) * ArithCircuit::var(1)
        + ArithCircuit::var(2)
        + ArithCircuit::constant(F::ONE);
    assert_eq!(circuit, expected_circuit);
}
1418
#[test]
fn test_deduplicate_steps() {
    type F = BinaryField8b;
    let sum = || ArithCircuit::<F>::var(0) + ArithCircuit::var(1);
    let mut circuit = sum() * sum() + sum();
    circuit.deduplicate_steps();

    // All three copies of (x0 + x1) collapse onto step 2.
    let expected_circuit = ArithCircuit::<F> {
        steps: vec![
            ArithCircuitStep::Var(0),
            ArithCircuitStep::Var(1),
            ArithCircuitStep::Add(0, 1),
            ArithCircuitStep::Mul(2, 2),
            ArithCircuitStep::Add(3, 2),
        ],
    };
    assert_eq!(circuit, expected_circuit);
}
1438
#[test]
fn test_compress_unused_steps() {
    type F = BinaryField8b;
    // Only steps 0, 1, 3 and 7 are reachable from the final `Mul`; the others are dead.
    let mut circuit = ArithCircuit::<F> {
        steps: vec![
            ArithCircuitStep::Var(0),
            ArithCircuitStep::Var(1),
            ArithCircuitStep::Var(2),
            ArithCircuitStep::Add(0, 1),
            ArithCircuitStep::Var(3),
            ArithCircuitStep::Const(F::ZERO),
            ArithCircuitStep::Var(2),
            ArithCircuitStep::Mul(3, 3),
        ],
    };
    circuit.compress_unused_steps();

    // Surviving steps are renumbered densely.
    let expected_steps: Vec<ArithCircuitStep<F>> = vec![
        ArithCircuitStep::Var(0),
        ArithCircuitStep::Var(1),
        ArithCircuitStep::Add(0, 1),
        ArithCircuitStep::Mul(2, 2),
    ];
    assert_eq!(circuit.steps, expected_steps);
}
1466
#[test]
fn test_conversion_from_expr_node_doesnt_create_duplicated_steps() {
    type F = BinaryField8b;
    // A single `Arc` node, (x0 + x1), is referenced three times by the expression tree.
    let shared = Arc::new(ArithExpr::<F>::Var(0) + ArithExpr::<F>::Var(1));
    let squared = ArithExpr::Mul(Arc::clone(&shared), Arc::clone(&shared));
    let expr = squared + shared;

    let circuit = ArithCircuit::<F>::from(&expr);
    // Var(0), Var(1), Add, Mul, Add: the shared node must be emitted only once.
    assert_eq!(circuit.steps.len(), 5);
    assert_eq!(unique_nodes_count(&circuit), 5);
}
1476
1477 fn unique_nodes_count_expr<F: Field>(expr: &ArithExpr<F>) -> usize {
1478 fn add_nodes<F: Field>(
1479 expr: &Arc<ArithExpr<F>>,
1480 unique_nodes: &mut HashSet<*const ArithExpr<F>>,
1481 ) {
1482 if !unique_nodes.insert(Arc::as_ptr(expr)) {
1483 return;
1484 }
1485
1486 match expr.as_ref() {
1487 ArithExpr::Const(_) | ArithExpr::Var(_) => {}
1488 ArithExpr::Add(left, right) | ArithExpr::Mul(left, right) => {
1489 add_nodes(left, unique_nodes);
1490 add_nodes(right, unique_nodes);
1491 }
1492 ArithExpr::Pow(base, _) => {
1493 add_nodes(base, unique_nodes);
1494 }
1495 }
1496 }
1497
1498 let mut unique_nodes = HashSet::new();
1499 add_nodes(&Arc::new(expr.clone()), &mut unique_nodes);
1500
1501 unique_nodes.len()
1502 }
1503
#[test]
fn test_conversion_from_circuit_to_expr_node() {
    type F = BinaryField8b;

    // Step 2, (x0 + x1), is consumed by both step 3 and step 4.
    let arith_circuit = ArithCircuit::<F> {
        steps: vec![
            ArithCircuitStep::Var(0),
            ArithCircuitStep::Var(1),
            ArithCircuitStep::Add(0, 1),
            ArithCircuitStep::Mul(2, 2),
            ArithCircuitStep::Add(3, 2),
        ],
    };
    let expr = ArithExpr::from(&arith_circuit);

    let sum = || ArithExpr::<F>::Var(0) + ArithExpr::<F>::Var(1);
    let expected_expr = sum() * sum() + sum();
    assert_eq!(expr, expected_expr);

    // Structural equality ignores sharing. Five distinct allocations means the converted tree
    // reuses a single node for step 2.
    assert_eq!(unique_nodes_count_expr(&expr), 5);
}
1524
#[test]
fn test_evaluate() {
    type F = BinaryField8b;
    let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1))
        * (ArithCircuit::var(2) + ArithCircuit::var(3)).pow(5);

    // Inputs are chosen so that neither sum is 0 or 1. Binary-field addition is bitwise XOR,
    // so inputs like 4 and 5 would sum to 1. The power would then collapse to 1, and the check
    // could not distinguish `(a + b) * c` from `a + b * c`.
    let [a, b, c, d] = [F::new(0x12), F::new(0x34), F::new(0x56), F::new(0x78)];
    let result = expr.evaluate(&[a, b, c, d]).unwrap();

    // The expected value mirrors the expression's structure exactly, including precedence.
    assert_eq!(result, (a + b) * (c + d).pow(5));
}
1535}