binius_math/
arith_expr.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	cmp::Ordering,
5	collections::{HashMap, hash_map::Entry},
6	fmt::{self, Display},
7	hash::{Hash, Hasher},
8	iter::{Product, Sum},
9	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
10	sync::Arc,
11};
12
13use binius_field::{Field, PackedField, TowerField};
14use binius_macros::{DeserializeBytes, SerializeBytes};
15
16use super::error::Error;
17
18/// A builder for arithmetic expressions.
19///
20/// This is a simple representation of multivariate polynomials. The user can explicitly
21/// specify which subexpressions are shared by using `Arc` to wrap them. The same subexpressions
22/// are guaranteed to be converted to the same steps `ArithCircuit` after the conversion.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub enum ArithExpr<F: Field> {
25	Const(F),
26	Var(usize),
27	Add(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
28	Mul(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
29	Pow(Arc<ArithExpr<F>>, u64),
30}
31
32impl<F: Field + Display> Display for ArithExpr<F> {
33	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34		match self {
35			Self::Const(v) => write!(f, "{v}"),
36			Self::Var(i) => write!(f, "x{i}"),
37			Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
38			Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
39			Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
40		}
41	}
42}
43
44impl<F: Field> ArithExpr<F> {
45	pub fn pow(self, exp: u64) -> Self {
46		Self::Pow(Arc::new(self), exp)
47	}
48
49	pub const fn zero() -> Self {
50		Self::Const(F::ZERO)
51	}
52
53	pub const fn one() -> Self {
54		Self::Const(F::ONE)
55	}
56}
57
58impl<F> Default for ArithExpr<F>
59where
60	F: Field,
61{
62	fn default() -> Self {
63		Self::zero()
64	}
65}
66
67impl<F> Add for ArithExpr<F>
68where
69	F: Field,
70{
71	type Output = Self;
72
73	fn add(self, rhs: Self) -> Self {
74		Self::Add(Arc::new(self), Arc::new(rhs))
75	}
76}
77
78impl<F> Add<Arc<Self>> for ArithExpr<F>
79where
80	F: Field,
81{
82	type Output = Self;
83
84	fn add(self, rhs: Arc<Self>) -> Self {
85		Self::Add(Arc::new(self), rhs)
86	}
87}
88
89impl<F> AddAssign for ArithExpr<F>
90where
91	F: Field,
92{
93	fn add_assign(&mut self, rhs: Self) {
94		*self = std::mem::take(self) + rhs;
95	}
96}
97
98impl<F> AddAssign<Arc<Self>> for ArithExpr<F>
99where
100	F: Field,
101{
102	fn add_assign(&mut self, rhs: Arc<Self>) {
103		*self = std::mem::take(self) + rhs;
104	}
105}
106
107impl<F> Sub for ArithExpr<F>
108where
109	F: Field,
110{
111	type Output = Self;
112
113	fn sub(self, rhs: Self) -> Self {
114		Self::Add(Arc::new(self), Arc::new(rhs))
115	}
116}
117
118impl<F> Sub<Arc<Self>> for ArithExpr<F>
119where
120	F: Field,
121{
122	type Output = Self;
123
124	fn sub(self, rhs: Arc<Self>) -> Self {
125		Self::Add(Arc::new(self), rhs)
126	}
127}
128
129impl<F> SubAssign for ArithExpr<F>
130where
131	F: Field,
132{
133	fn sub_assign(&mut self, rhs: Self) {
134		*self = std::mem::take(self) - rhs;
135	}
136}
137
138impl<F> SubAssign<Arc<Self>> for ArithExpr<F>
139where
140	F: Field,
141{
142	fn sub_assign(&mut self, rhs: Arc<Self>) {
143		*self = std::mem::take(self) - rhs;
144	}
145}
146
147impl<F> Mul for ArithExpr<F>
148where
149	F: Field,
150{
151	type Output = Self;
152
153	fn mul(self, rhs: Self) -> Self {
154		Self::Mul(Arc::new(self), Arc::new(rhs))
155	}
156}
157
158impl<F> Mul<Arc<Self>> for ArithExpr<F>
159where
160	F: Field,
161{
162	type Output = Self;
163
164	fn mul(self, rhs: Arc<Self>) -> Self {
165		Self::Mul(Arc::new(self), rhs)
166	}
167}
168
169impl<F> MulAssign for ArithExpr<F>
170where
171	F: Field,
172{
173	fn mul_assign(&mut self, rhs: Self) {
174		*self = std::mem::take(self) * rhs;
175	}
176}
177
178impl<F> MulAssign<Arc<Self>> for ArithExpr<F>
179where
180	F: Field,
181{
182	fn mul_assign(&mut self, rhs: Arc<Self>) {
183		*self = std::mem::take(self) * rhs;
184	}
185}
186
187impl<F: Field> Sum for ArithExpr<F> {
188	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
189		iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
190	}
191}
192
193impl<F: Field> Product for ArithExpr<F> {
194	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
195		iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
196	}
197}
198
199#[derive(Clone, Copy, Debug, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
200pub enum ArithCircuitStep<F: Field> {
201	Add(usize, usize),
202	Mul(usize, usize),
203	Pow(usize, u64),
204	Const(F),
205	Var(usize),
206}
207
208impl<F: Field> Default for ArithCircuitStep<F> {
209	fn default() -> Self {
210		Self::Const(F::ZERO)
211	}
212}
213
214/// Arithmetic expressions that can be evaluated symbolically.
215///
216/// Arithmetic expressions are trees, where the leaves are either constants or variables, and the
217/// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are
218/// specific representations of multivariate polynomials.
219///
220/// We store a tree in a form of a simple circuit.
221/// This implementation isn't optimized for performance, but rather for simplicity
222/// to allow easy conversion and preservation of the common subexpressions
223#[derive(Clone, Debug, SerializeBytes, DeserializeBytes, Eq)]
224pub struct ArithCircuit<F: Field> {
225	steps: Vec<ArithCircuitStep<F>>,
226}
227
228impl<F: Field> ArithCircuit<F> {
229	pub fn var(index: usize) -> Self {
230		Self {
231			steps: vec![ArithCircuitStep::Var(index)],
232		}
233	}
234
235	pub fn constant(value: F) -> Self {
236		Self {
237			steps: vec![ArithCircuitStep::Const(value)],
238		}
239	}
240
241	pub fn zero() -> Self {
242		Self::constant(F::ZERO)
243	}
244
245	pub fn one() -> Self {
246		Self::constant(F::ONE)
247	}
248
249	pub fn pow(mut self, exp: u64) -> Self {
250		self.steps
251			.push(ArithCircuitStep::Pow(self.steps.len() - 1, exp));
252
253		self
254	}
255
256	/// Steps of the circuit.
257	pub fn steps(&self) -> &[ArithCircuitStep<F>] {
258		&self.steps
259	}
260
261	/// The total degree of the polynomial the expression represents.
262	pub fn degree(&self) -> usize {
263		fn step_degree<F: Field>(step: usize, steps: &[ArithCircuitStep<F>]) -> usize {
264			match steps[step] {
265				ArithCircuitStep::Const(_) => 0,
266				ArithCircuitStep::Var(_) => 1,
267				ArithCircuitStep::Add(left, right) => {
268					step_degree(left, steps).max(step_degree(right, steps))
269				}
270				ArithCircuitStep::Mul(left, right) => {
271					step_degree(left, steps) + step_degree(right, steps)
272				}
273				ArithCircuitStep::Pow(base, exp) => step_degree(base, steps) * (exp as usize),
274			}
275		}
276
277		step_degree(self.steps.len() - 1, &self.steps)
278	}
279
280	/// The number of variables the expression contains.
281	pub fn n_vars(&self) -> usize {
282		self.steps
283			.iter()
284			.map(|step| {
285				if let ArithCircuitStep::Var(index) = step {
286					*index + 1
287				} else {
288					0
289				}
290			})
291			.max()
292			.unwrap_or(0)
293	}
294
295	/// The maximum tower level of the constant terms in the circuit.
296	pub fn binary_tower_level(&self) -> usize
297	where
298		F: TowerField,
299	{
300		self.steps
301			.iter()
302			.map(|step| {
303				if let ArithCircuitStep::Const(value) = step {
304					value.min_tower_level()
305				} else {
306					0
307				}
308			})
309			.max()
310			.unwrap_or(0)
311	}
312
313	/// Returns the evaluation cost of the circuit.
314	pub fn eval_cost(&self) -> EvalCost {
315		self.steps
316			.iter()
317			.fold(EvalCost::default(), |acc, step| match *step {
318				ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => acc,
319				ArithCircuitStep::Add(_, _) => acc + EvalCost::add(1),
320				ArithCircuitStep::Mul(_, _) => acc + EvalCost::mul(1),
321				// To exponentiate b^e the square-and-multiply method is used.
322				//
323				// That takes:
324				// 1. log2(e) squarings.
325				// 1. popcount(e) - 1 multiplications.
326				ArithCircuitStep::Pow(_, exp) => {
327					let n_squares = exp.ilog2() as usize;
328					let n_muls = exp.count_ones().saturating_sub(1) as usize;
329					acc + EvalCost::mul(n_muls) + EvalCost::square(n_squares)
330				}
331			})
332	}
333
334	/// Return a new arithmetic expression that contains only the terms of highest degree
335	/// (useful for interpolation at Karatsuba infinity point).
336	pub fn leading_term(&self) -> Self {
337		let (_, expr) = self.leading_term_with_degree(self.steps.len() - 1);
338		expr
339	}
340
341	/// Same as `leading_term`, but returns the total degree as the first tuple element as well.
342	fn leading_term_with_degree(&self, step: usize) -> (usize, Self) {
343		match &self.steps[step] {
344			ArithCircuitStep::Const(value) => (0, Self::constant(*value)),
345			ArithCircuitStep::Var(index) => (1, Self::var(*index)),
346			ArithCircuitStep::Add(left, right) => {
347				let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
348				let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
349				match lhs_degree.cmp(&rhs_degree) {
350					Ordering::Less => (rhs_degree, rhs),
351					Ordering::Equal => (lhs_degree, lhs + rhs),
352					Ordering::Greater => (lhs_degree, lhs),
353				}
354			}
355			ArithCircuitStep::Mul(left, right) => {
356				let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
357				let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
358				(lhs_degree + rhs_degree, lhs * rhs)
359			}
360			ArithCircuitStep::Pow(base, exp) => {
361				let (base_degree, base) = self.leading_term_with_degree(*base);
362				(base_degree * (*exp as usize), base.pow(*exp))
363			}
364		}
365	}
366
367	pub fn evaluate(&self, query: &[F]) -> Result<F, Error> {
368		let mut step_evals = Vec::<F>::with_capacity(self.steps.len());
369		for step in &self.steps {
370			match step {
371				ArithCircuitStep::Add(left, right) => {
372					step_evals.push(step_evals[*left] + step_evals[*right])
373				}
374				ArithCircuitStep::Mul(left, right) => {
375					step_evals.push(step_evals[*left] * step_evals[*right])
376				}
377				ArithCircuitStep::Pow(base, exp) => step_evals.push(step_evals[*base].pow(*exp)),
378				ArithCircuitStep::Const(value) => step_evals.push(*value),
379				ArithCircuitStep::Var(index) => step_evals.push(query[*index]),
380			}
381		}
382		Ok(step_evals.pop().unwrap_or_default())
383	}
384
385	pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithCircuit<FTgt> {
386		ArithCircuit {
387			steps: self
388				.steps
389				.iter()
390				.map(|step| match step {
391					ArithCircuitStep::Const(value) => ArithCircuitStep::Const((*value).into()),
392					ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
393					ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
394					ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
395					ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
396				})
397				.collect(),
398		}
399	}
400
401	pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
402		&self,
403	) -> Result<ArithCircuit<FTgt>, <FTgt as TryFrom<F>>::Error> {
404		let steps = self
405			.steps
406			.iter()
407			.map(|step| -> Result<ArithCircuitStep<FTgt>, <FTgt as TryFrom<F>>::Error> {
408				let result = match step {
409					ArithCircuitStep::Const(value) => {
410						ArithCircuitStep::Const(FTgt::try_from(*value)?)
411					}
412					ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
413					ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
414					ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
415					ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
416				};
417				Ok(result)
418			})
419			.collect::<Result<Vec<_>, _>>()?;
420
421		Ok(ArithCircuit { steps })
422	}
423
424	/// Creates a new expression with the variable indices remapped.
425	///
426	/// This recursively replaces the variable sub-expressions with an index `i` with the variable
427	/// `indices[i]`.
428	///
429	/// ## Throws
430	///
431	/// * [`Error::IncorrectArgumentLength`] if indices has length less than the current number of
432	///   variables
433	pub fn remap_vars(&self, indices: &[usize]) -> Result<Self, Error> {
434		let steps = self
435			.steps
436			.iter()
437			.map(|step| -> Result<ArithCircuitStep<F>, Error> {
438				if let ArithCircuitStep::Var(index) = step {
439					let new_index = indices.get(*index).copied().ok_or_else(|| {
440						Error::IncorrectArgumentLength {
441							arg: "indices".to_string(),
442							expected: *index,
443						}
444					})?;
445					Ok(ArithCircuitStep::Var(new_index))
446				} else {
447					Ok(*step)
448				}
449			})
450			.collect::<Result<Vec<_>, _>>()?;
451		Ok(Self { steps })
452	}
453
454	/// Substitute variable with index `var` with a constant `value`
455	pub fn const_subst(self, var: usize, value: F) -> Self {
456		let steps = self
457			.steps
458			.iter()
459			.map(|step| match step {
460				ArithCircuitStep::Var(index) if *index == var => ArithCircuitStep::Const(value),
461				_ => *step,
462			})
463			.collect();
464		Self { steps }
465	}
466
467	/// Returns `Some(F)` if the expression is a constant.
468	pub fn get_constant(&self) -> Option<F> {
469		if let ArithCircuitStep::Const(value) =
470			self.steps.last().expect("steps should not be empty")
471		{
472			Some(*value)
473		} else {
474			None
475		}
476	}
477
478	/// Returns the normal form of an expression if it is linear.
479	///
480	/// ## Throws
481	///
482	/// - [`Error::NonLinearExpression`] if the expression is not linear.
483	pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
484		self.sparse_linear_normal_form().map(Into::into)
485	}
486
487	fn sparse_linear_normal_form(&self) -> Result<SparseLinearNormalForm<F>, Error> {
488		fn sparse_linear_normal_form<F: Field>(
489			step: usize,
490			steps: &[ArithCircuitStep<F>],
491		) -> Result<SparseLinearNormalForm<F>, Error> {
492			match &steps[step] {
493				ArithCircuitStep::Const(val) => Ok((*val).into()),
494				ArithCircuitStep::Var(index) => Ok(SparseLinearNormalForm {
495					constant: F::ZERO,
496					dense_linear_form_len: index + 1,
497					var_coeffs: [(*index, F::ONE)].into(),
498				}),
499				ArithCircuitStep::Add(left, right) => {
500					let left = sparse_linear_normal_form(*left, steps)?;
501					let right = sparse_linear_normal_form(*right, steps)?;
502					Ok(left + right)
503				}
504				ArithCircuitStep::Mul(left, right) => {
505					let left = sparse_linear_normal_form(*left, steps)?;
506					let right = sparse_linear_normal_form(*right, steps)?;
507					left * right
508				}
509				ArithCircuitStep::Pow(_, 0) => Ok(F::ONE.into()),
510				ArithCircuitStep::Pow(expr, 1) => sparse_linear_normal_form(*expr, steps),
511				ArithCircuitStep::Pow(expr, pow) => {
512					let linear_form = sparse_linear_normal_form(*expr, steps)?;
513					if linear_form.dense_linear_form_len != 0 {
514						return Err(Error::NonLinearExpression);
515					}
516					Ok(linear_form.constant.pow(*pow).into())
517				}
518			}
519		}
520
521		sparse_linear_normal_form(self.steps.len() - 1, &self.steps)
522	}
523
524	/// Returns a vector of booleans indicating which variables are used in the expression.
525	///
526	/// The vector is indexed by variable index, and the value at index `i` is `true` if and only
527	/// if the variable is used in the expression.
528	pub fn vars_usage(&self) -> Vec<bool> {
529		let mut usage = vec![false; self.n_vars()];
530
531		for step in &self.steps {
532			if let ArithCircuitStep::Var(index) = step {
533				usage[*index] = true;
534			}
535		}
536
537		usage
538	}
539
540	/// Fold constants in the circuit.
541	pub fn optimize_constants_in_place(&mut self) {
542		for step_index in 0..self.steps.len() {
543			let (prev_steps, curr_steps) = self.steps.split_at_mut(step_index);
544			let curr_step = &mut curr_steps[0];
545			match curr_step {
546				ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => {}
547				ArithCircuitStep::Add(left, right) => {
548					match (&prev_steps[*left], &prev_steps[*right]) {
549						(ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
550							*curr_step = ArithCircuitStep::Const(*left + *right);
551						}
552						(ArithCircuitStep::Const(left), right) if *left == F::ZERO => {
553							*curr_step = *right;
554						}
555						(left, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
556							*curr_step = *left;
557						}
558						(left, right) if left == right && F::CHARACTERISTIC == 2 => {
559							*curr_step = ArithCircuitStep::Const(F::ZERO);
560						}
561						_ => {}
562					}
563				}
564				ArithCircuitStep::Mul(left, right) => {
565					match (&prev_steps[*left], &prev_steps[*right]) {
566						(ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
567							*curr_step = ArithCircuitStep::Const(*left * *right);
568						}
569						(ArithCircuitStep::Const(left), _) if *left == F::ZERO => {
570							*curr_step = ArithCircuitStep::Const(F::ZERO);
571						}
572						(_, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
573							*curr_step = ArithCircuitStep::Const(F::ZERO);
574						}
575						(ArithCircuitStep::Const(left), right) if *left == F::ONE => {
576							*curr_step = *right;
577						}
578						(left, ArithCircuitStep::Const(right)) if *right == F::ONE => {
579							*curr_step = *left;
580						}
581						_ => {}
582					}
583				}
584				ArithCircuitStep::Pow(base, exp) => match prev_steps[*base] {
585					ArithCircuitStep::Const(value) => {
586						*curr_step = ArithCircuitStep::Const(PackedField::pow(value, *exp));
587					}
588					ArithCircuitStep::Pow(base_inner, exp_inner) => {
589						*curr_step = ArithCircuitStep::Pow(base_inner, *exp * exp_inner);
590					}
591					_ => {}
592				},
593			}
594		}
595	}
596
597	/// Same as `optimize_constants_in_place`, but returns a new instance of the circuit.
598	pub fn optimize_constants(mut self) -> Self {
599		self.optimize_constants_in_place();
600		self
601	}
602
603	/// Deduplicate steps in the circuit by using a single instance of each step.
604	fn deduplicate_steps(&mut self) {
605		let mut step_map = HashMap::new();
606		let mut step_indices = Vec::with_capacity(self.steps.len());
607		for step in 0..self.steps.len() {
608			let node = StepNode {
609				index: step,
610				steps: &self.steps,
611			};
612			match step_map.entry(node) {
613				Entry::Occupied(entry) => {
614					step_indices.push(*entry.get());
615				}
616				Entry::Vacant(entry) => {
617					entry.insert(step);
618					step_indices.push(step);
619				}
620			}
621		}
622
623		for step in &mut self.steps {
624			match step {
625				ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
626					*left = step_indices[*left];
627					*right = step_indices[*right];
628				}
629				ArithCircuitStep::Pow(base, _) => *base = step_indices[*base],
630				_ => (),
631			}
632		}
633	}
634
635	/// Remove the unused steps from the circuit.
636	fn compress_unused_steps(&mut self) {
637		fn mark_used<F: Field>(step: usize, steps: &[ArithCircuitStep<F>], used: &mut [bool]) {
638			if used[step] {
639				return;
640			}
641			used[step] = true;
642			match steps[step] {
643				ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
644					mark_used(left, steps, used);
645					mark_used(right, steps, used);
646				}
647				ArithCircuitStep::Pow(base, _) => mark_used(base, steps, used),
648				_ => (),
649			}
650		}
651
652		let mut used = vec![false; self.steps.len()];
653		mark_used(self.steps.len() - 1, &self.steps, &mut used);
654
655		let mut steps_map = (0..self.steps.len()).collect::<Vec<_>>();
656		let mut target_index = 0;
657		for source_index in 0..self.steps.len() {
658			if used[source_index] {
659				if target_index != source_index {
660					match &mut self.steps[source_index] {
661						ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
662							*left = steps_map[*left];
663							*right = steps_map[*right];
664						}
665						ArithCircuitStep::Pow(base, _) => *base = steps_map[*base],
666						_ => (),
667					}
668
669					steps_map[source_index] = target_index;
670					self.steps[target_index] = self.steps[source_index];
671				}
672
673				target_index += 1;
674			}
675		}
676
677		self.steps.truncate(target_index);
678	}
679
680	/// Optimize the current instance of the circuit:
681	/// - Fold constants
682	/// - Extract common steps
683	/// - Remove unused steps
684	pub fn optimize_in_place(&mut self) {
685		self.optimize_constants_in_place();
686		self.deduplicate_steps();
687		self.compress_unused_steps();
688	}
689
690	/// Same as `optimize_in_place`, but returns a new instance of the circuit.
691	pub fn optimize(mut self) -> Self {
692		self.optimize_constants_in_place();
693		self.deduplicate_steps();
694		self.compress_unused_steps();
695
696		self
697	}
698}
699
700impl<F: Field> From<&ArithExpr<F>> for ArithCircuit<F> {
701	fn from(expr: &ArithExpr<F>) -> Self {
702		fn visit_node<F: Field>(
703			node: &Arc<ArithExpr<F>>,
704			node_to_index: &mut HashMap<*const ArithExpr<F>, usize>,
705			steps: &mut Vec<ArithCircuitStep<F>>,
706		) -> usize {
707			if let Some(index) = node_to_index.get(&Arc::as_ptr(node)) {
708				return *index;
709			}
710
711			let step = match &**node {
712				ArithExpr::Const(value) => ArithCircuitStep::Const(*value),
713				ArithExpr::Var(index) => ArithCircuitStep::Var(*index),
714				ArithExpr::Add(left, right) => {
715					let left = visit_node(left, node_to_index, steps);
716					let right = visit_node(right, node_to_index, steps);
717					ArithCircuitStep::Add(left, right)
718				}
719				ArithExpr::Mul(left, right) => {
720					let left = visit_node(left, node_to_index, steps);
721					let right = visit_node(right, node_to_index, steps);
722					ArithCircuitStep::Mul(left, right)
723				}
724				ArithExpr::Pow(base, exp) => {
725					let base = visit_node(base, node_to_index, steps);
726					ArithCircuitStep::Pow(base, *exp)
727				}
728			};
729
730			steps.push(step);
731			node_to_index.insert(Arc::as_ptr(node), steps.len() - 1);
732			steps.len() - 1
733		}
734
735		let mut steps = Vec::new();
736		let mut node_to_index = HashMap::new();
737		match expr {
738			ArithExpr::Const(c) => {
739				steps.push(ArithCircuitStep::Const(*c));
740			}
741			ArithExpr::Var(var) => {
742				steps.push(ArithCircuitStep::Var(*var));
743			}
744			ArithExpr::Add(left, right) => {
745				let left = visit_node(left, &mut node_to_index, &mut steps);
746				let right = visit_node(right, &mut node_to_index, &mut steps);
747				steps.push(ArithCircuitStep::Add(left, right));
748			}
749			ArithExpr::Mul(left, right) => {
750				let left = visit_node(left, &mut node_to_index, &mut steps);
751				let right = visit_node(right, &mut node_to_index, &mut steps);
752				steps.push(ArithCircuitStep::Mul(left, right));
753			}
754			ArithExpr::Pow(base, exp) => {
755				let base = visit_node(base, &mut node_to_index, &mut steps);
756				steps.push(ArithCircuitStep::Pow(base, *exp));
757			}
758		}
759
760		Self { steps }
761	}
762}
763
764impl<F: Field> From<ArithExpr<F>> for ArithCircuit<F> {
765	fn from(expr: ArithExpr<F>) -> Self {
766		Self::from(&expr)
767	}
768}
769
770impl<F: Field> From<&ArithCircuit<F>> for ArithExpr<F> {
771	fn from(circuit: &ArithCircuit<F>) -> Self {
772		let mut step_to_node = vec![Option::<Arc<Self>>::None; circuit.steps.len()];
773
774		for step in 0..circuit.steps.len() {
775			let node = match &circuit.steps[step] {
776				ArithCircuitStep::Const(value) => Self::Const(*value),
777				ArithCircuitStep::Var(index) => Self::Var(*index),
778				ArithCircuitStep::Add(left, right) => {
779					let left = step_to_node[*left].clone().expect("step must be present");
780					let right = step_to_node[*right].clone().expect("step must be present");
781					Self::Add(left, right)
782				}
783				ArithCircuitStep::Mul(left, right) => {
784					let left = step_to_node[*left].clone().expect("step must be present");
785					let right = step_to_node[*right].clone().expect("step must be present");
786					Self::Mul(left, right)
787				}
788				ArithCircuitStep::Pow(base, exp) => {
789					let base = step_to_node[*base].clone().expect("step must be present");
790					Self::Pow(base, *exp)
791				}
792			};
793			step_to_node[step] = Some(Arc::new(node));
794		}
795
796		Arc::into_inner(
797			step_to_node
798				.pop()
799				.expect("steps should not be empty")
800				.expect("last step should be initialized"),
801		)
802		.expect("last step must have a single instance")
803	}
804}
805
806impl<F: Field> From<ArithCircuit<F>> for ArithExpr<F> {
807	fn from(circuit: ArithCircuit<F>) -> Self {
808		Self::from(&circuit)
809	}
810}
811
812impl<F: Field> Display for ArithCircuit<F> {
813	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
814		fn display_step<F: Field>(
815			step: usize,
816			steps: &[ArithCircuitStep<F>],
817			f: &mut fmt::Formatter<'_>,
818		) -> Result<(), fmt::Error> {
819			match &steps[step] {
820				ArithCircuitStep::Const(value) => write!(f, "{value}"),
821				ArithCircuitStep::Var(index) => write!(f, "x{index}"),
822				ArithCircuitStep::Add(left, right) => {
823					write!(f, "(")?;
824					display_step(*left, steps, f)?;
825					write!(f, " + ")?;
826					display_step(*right, steps, f)?;
827					write!(f, ")")
828				}
829				ArithCircuitStep::Mul(left, right) => {
830					write!(f, "(")?;
831					display_step(*left, steps, f)?;
832					write!(f, " * ")?;
833					display_step(*right, steps, f)?;
834					write!(f, ")")
835				}
836				ArithCircuitStep::Pow(base, exp) => {
837					write!(f, "(")?;
838					display_step(*base, steps, f)?;
839					write!(f, ")^{exp}")
840				}
841			}
842		}
843
844		display_step(self.steps.len() - 1, &self.steps, f)
845	}
846}
847
848impl<F: Field> PartialEq for ArithCircuit<F> {
849	fn eq(&self, other: &Self) -> bool {
850		StepNode {
851			index: self.steps.len() - 1,
852			steps: &self.steps,
853		} == StepNode {
854			index: other.steps.len() - 1,
855			steps: &other.steps,
856		}
857	}
858}
859
860impl<F: Field> Add for ArithCircuit<F> {
861	type Output = Self;
862
863	fn add(mut self, rhs: Self) -> Self {
864		self += rhs;
865		self
866	}
867}
868
869impl<F: Field> AddAssign for ArithCircuit<F> {
870	fn add_assign(&mut self, mut rhs: Self) {
871		let old_len = self.steps.len();
872		add_offset(&mut rhs.steps, old_len);
873		self.steps.extend(rhs.steps);
874		self.steps
875			.push(ArithCircuitStep::Add(old_len - 1, self.steps.len() - 1));
876	}
877}
878
879impl<F: Field> Sub for ArithCircuit<F> {
880	type Output = Self;
881
882	fn sub(mut self, rhs: Self) -> Self {
883		self -= rhs;
884		self
885	}
886}
887
888impl<F: Field> SubAssign for ArithCircuit<F> {
889	#[allow(clippy::suspicious_op_assign_impl)]
890	fn sub_assign(&mut self, rhs: Self) {
891		*self += rhs;
892	}
893}
894
895impl<F: Field> Mul for ArithCircuit<F> {
896	type Output = Self;
897
898	fn mul(mut self, rhs: Self) -> Self {
899		self *= rhs;
900		self
901	}
902}
903
904impl<F: Field> MulAssign for ArithCircuit<F> {
905	fn mul_assign(&mut self, mut rhs: Self) {
906		let old_len = self.steps.len();
907		add_offset(&mut rhs.steps, old_len);
908		self.steps.extend(rhs.steps);
909		self.steps
910			.push(ArithCircuitStep::Mul(old_len - 1, self.steps.len() - 1));
911	}
912}
913
914impl<F: Field> Sum for ArithCircuit<F> {
915	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
916		iter.fold(Self::zero(), |sum, item| sum + item)
917	}
918}
919
920impl<F: Field> Product for ArithCircuit<F> {
921	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
922		iter.fold(Self::one(), |product, item| product * item)
923	}
924}
925
926fn add_offset<F: Field>(steps: &mut [ArithCircuitStep<F>], offset: usize) {
927	for step in steps.iter_mut() {
928		match step {
929			ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
930				*left += offset;
931				*right += offset;
932			}
933			ArithCircuitStep::Pow(base, _) => {
934				*base += offset;
935			}
936			_ => (),
937		}
938	}
939}
940/// A normal form for a linear expression.
941#[derive(Debug, Default, Clone, PartialEq, Eq)]
942pub struct LinearNormalForm<F: Field> {
943	/// The constant offset of the expression.
944	pub constant: F,
945	/// A vector mapping variable indices to their coefficients.
946	pub var_coeffs: Vec<F>,
947}
948
949struct SparseLinearNormalForm<F: Field> {
950	/// The constant offset of the expression.
951	pub constant: F,
952	/// A map of variable indices to their coefficients.
953	pub var_coeffs: HashMap<usize, F>,
954	/// The `var_coeffs` vector len if converted to [`LinearNormalForm`].
955	/// It is used for optimization of conversion to [`LinearNormalForm`].
956	pub dense_linear_form_len: usize,
957}
958
959impl<F: Field> From<F> for SparseLinearNormalForm<F> {
960	fn from(value: F) -> Self {
961		Self {
962			constant: value,
963			dense_linear_form_len: 0,
964			var_coeffs: HashMap::new(),
965		}
966	}
967}
968
969impl<F: Field> Add for SparseLinearNormalForm<F> {
970	type Output = Self;
971	fn add(self, rhs: Self) -> Self::Output {
972		let (mut result, consumable) = if self.var_coeffs.len() < rhs.var_coeffs.len() {
973			(rhs, self)
974		} else {
975			(self, rhs)
976		};
977		result.constant += consumable.constant;
978		if consumable.dense_linear_form_len > result.dense_linear_form_len {
979			result.dense_linear_form_len = consumable.dense_linear_form_len;
980		}
981
982		for (index, coeff) in consumable.var_coeffs {
983			result
984				.var_coeffs
985				.entry(index)
986				.and_modify(|res_coeff| {
987					*res_coeff += coeff;
988				})
989				.or_insert(coeff);
990		}
991		result
992	}
993}
994
995impl<F: Field> Mul for SparseLinearNormalForm<F> {
996	type Output = Result<Self, Error>;
997	fn mul(self, rhs: Self) -> Result<Self, Error> {
998		if !self.var_coeffs.is_empty() && !rhs.var_coeffs.is_empty() {
999			return Err(Error::NonLinearExpression);
1000		}
1001		let (mut result, consumable) = if self.var_coeffs.is_empty() {
1002			(rhs, self)
1003		} else {
1004			(self, rhs)
1005		};
1006		result.constant *= consumable.constant;
1007		for coeff in result.var_coeffs.values_mut() {
1008			*coeff *= consumable.constant;
1009		}
1010		Ok(result)
1011	}
1012}
1013
1014impl<F: Field> From<SparseLinearNormalForm<F>> for LinearNormalForm<F> {
1015	fn from(value: SparseLinearNormalForm<F>) -> Self {
1016		let mut var_coeffs = vec![F::ZERO; value.dense_linear_form_len];
1017		for (i, coeff) in value.var_coeffs {
1018			var_coeffs[i] = coeff;
1019		}
1020		Self {
1021			constant: value.constant,
1022			var_coeffs,
1023		}
1024	}
1025}
1026
1027/// A helper structure to be used for by-value comparison and hashing of the subexpressions
1028/// regardless if they are in the same ArithCircuit or not.
1029#[derive(Eq)]
1030struct StepNode<'a, F: Field> {
1031	index: usize,
1032	steps: &'a [ArithCircuitStep<F>],
1033}
1034
1035impl<F: Field> StepNode<'_, F> {
1036	fn prev_step(&self, step: usize) -> Self {
1037		StepNode {
1038			index: step,
1039			steps: self.steps,
1040		}
1041	}
1042}
1043
1044impl<F: Field> PartialEq for StepNode<'_, F> {
1045	#[allow(clippy::suspicious_operation_groupings)] // false positive
1046	fn eq(&self, other: &Self) -> bool {
1047		match (&self.steps[self.index], &other.steps[other.index]) {
1048			(ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => left == right,
1049			(ArithCircuitStep::Var(left), ArithCircuitStep::Var(right)) => left == right,
1050			(
1051				ArithCircuitStep::Add(left, right),
1052				ArithCircuitStep::Add(other_left, other_right),
1053			)
1054			| (
1055				ArithCircuitStep::Mul(left, right),
1056				ArithCircuitStep::Mul(other_left, other_right),
1057			) => {
1058				self.prev_step(*left) == other.prev_step(*other_left)
1059					&& self.prev_step(*right) == other.prev_step(*other_right)
1060			}
1061			(ArithCircuitStep::Pow(base, exp), ArithCircuitStep::Pow(other_base, other_exp)) => {
1062				self.prev_step(*base) == other.prev_step(*other_base) && exp == other_exp
1063			}
1064			_ => false,
1065		}
1066	}
1067}
1068
1069impl<F: Field> Hash for StepNode<'_, F> {
1070	fn hash<H: Hasher>(&self, state: &mut H) {
1071		match self.steps[self.index] {
1072			ArithCircuitStep::Const(value) => {
1073				0u8.hash(state);
1074				value.hash(state);
1075			}
1076			ArithCircuitStep::Var(index) => {
1077				1u8.hash(state);
1078				index.hash(state);
1079			}
1080			ArithCircuitStep::Add(left, right) => {
1081				2u8.hash(state);
1082				self.prev_step(left).hash(state);
1083				self.prev_step(right).hash(state);
1084			}
1085			ArithCircuitStep::Mul(left, right) => {
1086				3u8.hash(state);
1087				self.prev_step(left).hash(state);
1088				self.prev_step(right).hash(state);
1089			}
1090			ArithCircuitStep::Pow(base, exp) => {
1091				4u8.hash(state);
1092				self.prev_step(base).hash(state);
1093				exp.hash(state);
1094			}
1095		}
1096	}
1097}
1098
/// The breakdown of the number of operations it takes to evaluate an arithmetic circuit.
///
/// This is a plain bag of counters, so it is `Copy` and can be compared and debug-printed.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct EvalCost {
	/// Number of additions.
	///
	/// Addition is a relatively cheap operation.
	pub n_adds: usize,
	/// Number of multiplications.
	///
	/// Multiplications are a dominating factor in the evaluation cost.
	pub n_muls: usize,
	/// Number of squares.
	///
	/// Squares are a relatively cheap operation. Generally we assume that a squaring costs
	/// 1/5th of a single multiplication.
	pub n_squares: usize,
}
1116
1117impl EvalCost {
1118	pub fn add(n_adds: usize) -> Self {
1119		Self {
1120			n_adds,
1121			..Self::default()
1122		}
1123	}
1124
1125	pub fn mul(n_muls: usize) -> Self {
1126		Self {
1127			n_muls,
1128			..Self::default()
1129		}
1130	}
1131
1132	pub fn square(n_squares: usize) -> Self {
1133		Self {
1134			n_squares,
1135			..Self::default()
1136		}
1137	}
1138
1139	/// Returns an approximate cost to evaluate an expression as a single number.
1140	///
1141	/// Since the evaluation time is dominated by the number of multiplications we can roughly
1142	/// gauge the cost by the number of the multiplications. We count the number of squares
1143	/// as 1/5th of a single multiplication.
1144	pub fn mult_cost_approx(&self) -> usize {
1145		self.n_muls + self.n_squares.div_ceil(5)
1146	}
1147}
1148
1149impl Add for EvalCost {
1150	type Output = Self;
1151	fn add(self, other: Self) -> Self {
1152		Self {
1153			n_adds: self.n_adds + other.n_adds,
1154			n_muls: self.n_muls + other.n_muls,
1155			n_squares: self.n_squares + other.n_squares,
1156		}
1157	}
1158}
1159
#[cfg(test)]
mod tests {
	use std::collections::HashSet;

	use assert_matches::assert_matches;
	use binius_field::{BinaryField, BinaryField1b, BinaryField8b, BinaryField128b};
	use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes};

	use super::*;

	#[test]
	fn test_degree_with_pow() {
		let expr = ArithCircuit::constant(BinaryField8b::new(6)).pow(7);
		assert_eq!(expr.degree(), 0);

		let expr: ArithCircuit<BinaryField8b> = ArithCircuit::var(0).pow(7);
		assert_eq!(expr.degree(), 7);

		let expr: ArithCircuit<BinaryField8b> =
			(ArithCircuit::var(0) * ArithCircuit::var(1)).pow(7);
		assert_eq!(expr.degree(), 14);
	}

	#[test]
	fn test_n_vars() {
		type F = BinaryField8b;
		let expr = ArithCircuit::<F>::var(0) * ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR)
			+ ArithCircuit::var(2).pow(2);
		assert_eq!(expr.n_vars(), 3);
	}

	#[test]
	fn test_leading_term_with_degree() {
		let expr = ArithCircuit::var(0)
			* (ArithCircuit::var(1)
				* ArithCircuit::var(2)
				* ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR)
				+ ArithCircuit::var(4))
			+ ArithCircuit::var(5).pow(3)
			+ ArithCircuit::constant(BinaryField8b::ONE);

		let expected_expr = ArithCircuit::var(0)
			* ((ArithCircuit::var(1) * ArithCircuit::var(2))
				* ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR))
			+ ArithCircuit::var(5).pow(3);

		assert_eq!(expr.leading_term_with_degree(expr.steps().len() - 1), (3, expected_expr));
	}

	#[test]
	fn test_remap_vars_with_too_few_vars() {
		type F = BinaryField8b;
		let expr =
			((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
		assert_matches!(expr.remap_vars(&[5]), Err(Error::IncorrectArgumentLength { .. }));
	}

	#[test]
	fn test_remap_vars_works() {
		type F = BinaryField8b;
		let expr =
			((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
		let new_expr = expr.remap_vars(&[5, 3]);

		let expected =
			((ArithCircuit::var(5) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(3)).pow(3);
		assert_eq!(new_expr.unwrap(), expected);
	}

	#[test]
	fn test_optimize_identity_handling() {
		type F = BinaryField8b;
		let zero = ArithCircuit::<F>::zero();
		let one = ArithCircuit::<F>::one();

		assert_eq!((zero.clone() * ArithCircuit::<F>::var(0)).optimize(), zero);
		assert_eq!((ArithCircuit::<F>::var(0) * zero.clone()).optimize(), zero);

		assert_eq!((ArithCircuit::<F>::var(0) * one.clone()).optimize(), ArithCircuit::var(0));
		assert_eq!((one * ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));

		assert_eq!((ArithCircuit::<F>::var(0) + zero.clone()).optimize(), ArithCircuit::var(0));
		assert_eq!((zero.clone() + ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));

		assert_eq!((ArithCircuit::<F>::var(0) + ArithCircuit::var(0)).optimize(), zero);
	}

	#[test]
	fn test_const_subst_and_optimize() {
		// NB: this is FlushSumcheckComposition from the constraint_system
		type F = BinaryField8b;
		let expr = ArithCircuit::var(0) * ArithCircuit::var(1) + ArithCircuit::one()
			- ArithCircuit::var(1);
		assert_eq!(expr.const_subst(1, F::ZERO).optimize().get_constant(), Some(F::ONE));
	}

	#[test]
	fn test_expression_upcast() {
		type F8 = BinaryField8b;
		type F = BinaryField128b;

		let expr = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
			* ArithCircuit::constant(F8::new(222)))
		.pow(3);

		let expected = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			* ArithCircuit::constant(F::new(222)))
		.pow(3);
		assert_eq!(expr.convert_field::<F>(), expected);
	}

	#[test]
	fn test_expression_downcast() {
		type F8 = BinaryField8b;
		type F = BinaryField128b;

		let expr = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			* ArithCircuit::constant(F::new(222)))
		.pow(3);

		assert!(expr.try_convert_field::<BinaryField1b>().is_err());

		let expected = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
			* ArithCircuit::constant(F8::new(222)))
		.pow(3);
		assert_eq!(expr.try_convert_field::<BinaryField8b>().unwrap(), expected);
	}

	#[test]
	fn test_linear_normal_form() {
		type F = BinaryField128b;
		struct Case {
			expr: ArithCircuit<F>,
			expected: LinearNormalForm<F>,
		}
		let cases = vec![
			Case {
				expr: ArithCircuit::constant(F::ONE),
				expected: LinearNormalForm {
					constant: F::ONE,
					var_coeffs: vec![],
				},
			},
			Case {
				expr: (ArithCircuit::constant(F::new(2)) * ArithCircuit::constant(F::new(3)))
					.pow(2) + ArithCircuit::constant(F::new(3))
					* (ArithCircuit::constant(F::new(4)) + ArithCircuit::var(0)),
				expected: LinearNormalForm {
					constant: (F::new(2) * F::new(3)).pow(2) + F::new(3) * F::new(4),
					var_coeffs: vec![F::new(3)],
				},
			},
			Case {
				expr: ArithCircuit::constant(F::new(133))
					+ ArithCircuit::constant(F::new(42)) * ArithCircuit::var(0)
					+ ArithCircuit::var(2)
					+ ArithCircuit::constant(F::new(11))
						* ArithCircuit::constant(F::new(37))
						* ArithCircuit::var(3),
				expected: LinearNormalForm {
					constant: F::new(133),
					var_coeffs: vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)],
				},
			},
		];
		for Case { expr, expected } in cases {
			let normal_form = expr.linear_normal_form().unwrap();
			assert_eq!(normal_form.constant, expected.constant);
			assert_eq!(normal_form.var_coeffs, expected.var_coeffs);
		}
	}

	/// Counts the structurally distinct subexpressions among the steps of `expr`, using the
	/// by-value equality and hashing of `StepNode`.
	fn unique_nodes_count<F: Field>(expr: &ArithCircuit<F>) -> usize {
		let mut unique_nodes = HashSet::new();

		for step in 0..expr.steps.len() {
			unique_nodes.insert(StepNode {
				index: step,
				steps: &expr.steps,
			});
		}

		unique_nodes.len()
	}

	/// Serializes `expr` and deserializes it back. Checks that the result equals the original
	/// and has the same number of structurally distinct steps.
	fn check_serialize_bytes_roundtrip<F: Field>(expr: ArithCircuit<F>) {
		let mut buf = Vec::new();

		expr.serialize(&mut buf, SerializationMode::CanonicalTower)
			.unwrap();
		let deserialized =
			ArithCircuit::<F>::deserialize(&buf[..], SerializationMode::CanonicalTower).unwrap();
		assert_eq!(expr, deserialized);
		assert_eq!(unique_nodes_count(&expr), unique_nodes_count(&deserialized));
	}

	#[test]
	fn test_serialize_bytes_roundtrip() {
		type F = BinaryField128b;
		let expr = ArithCircuit::var(0)
			* (ArithCircuit::var(1)
				* ArithCircuit::var(2)
				* ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR)
				+ ArithCircuit::var(4))
			+ ArithCircuit::var(5).pow(3)
			+ ArithCircuit::constant(F::ONE);

		check_serialize_bytes_roundtrip(expr);
	}

	#[test]
	fn test_serialize_bytes_roundtrip_with_duplicates() {
		type F = BinaryField128b;
		let expr = (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			* (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			+ (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			+ ArithCircuit::var(1);

		check_serialize_bytes_roundtrip(expr);
	}

	#[test]
	fn test_binary_tower_level() {
		type F = BinaryField128b;
		let expr =
			ArithCircuit::constant(F::ONE) + ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR);
		assert_eq!(expr.binary_tower_level(), F::MULTIPLICATIVE_GENERATOR.min_tower_level());
	}

	#[test]
	fn test_arith_circuit_steps() {
		type F = BinaryField8b;
		let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1)) * ArithCircuit::var(2);
		let steps = expr.steps();
		assert_eq!(steps.len(), 5); // 3 variables, 1 addition, 1 multiplication
		assert!(matches!(steps[0], ArithCircuitStep::Var(0)));
		assert!(matches!(steps[1], ArithCircuitStep::Var(1)));
		assert!(matches!(steps[2], ArithCircuitStep::Add(_, _)));
		assert!(matches!(steps[3], ArithCircuitStep::Var(2)));
		assert!(matches!(steps[4], ArithCircuitStep::Mul(_, _)));
	}

	#[test]
	fn test_optimize_constants() {
		type F = BinaryField8b;
		let mut circuit = (ArithCircuit::<F>::var(0) + ArithCircuit::constant(F::ZERO))
			* ArithCircuit::var(1)
			+ ArithCircuit::constant(F::ONE) * ArithCircuit::var(2)
			+ ArithCircuit::constant(F::ONE).pow(4).pow(5)
			+ (ArithCircuit::var(5) + ArithCircuit::var(5));
		circuit.optimize_constants_in_place();

		let expected_circuit = ArithCircuit::var(0) * ArithCircuit::var(1)
			+ ArithCircuit::var(2)
			+ ArithCircuit::constant(F::ONE);

		assert_eq!(circuit, expected_circuit);
	}

	#[test]
	fn test_deduplicate_steps() {
		type F = BinaryField8b;
		let mut circuit = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1))
			* (ArithCircuit::var(0) + ArithCircuit::var(1))
			+ (ArithCircuit::var(0) + ArithCircuit::var(1));
		circuit.deduplicate_steps();

		let expected_circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Mul(2, 2),
				ArithCircuitStep::Add(3, 2),
			],
		};
		assert_eq!(circuit, expected_circuit);
	}

	#[test]
	fn test_compress_unused_steps() {
		type F = BinaryField8b;
		let mut circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Var(2),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Var(3),
				ArithCircuitStep::Const(F::ZERO),
				ArithCircuitStep::Var(2),
				ArithCircuitStep::Mul(3, 3),
			],
		};
		circuit.compress_unused_steps();

		let expected_circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Mul(2, 2),
			],
		};
		assert_eq!(circuit.steps, expected_circuit.steps);
	}

	#[test]
	fn test_conversion_from_expr_node_doesnt_create_duplicated_steps() {
		type F = BinaryField8b;
		let sub_expr = Arc::new(ArithExpr::<F>::Var(0) + ArithExpr::<F>::Var(1));
		let expr = ArithExpr::Mul(sub_expr.clone(), sub_expr.clone()) + sub_expr;
		let circuit = ArithCircuit::<F>::from(&expr);
		assert_eq!(circuit.steps.len(), 5);
		assert_eq!(unique_nodes_count(&circuit), 5);
	}

	/// Counts the distinct `Arc` allocations reachable from `expr`, including a fresh root
	/// allocation for `expr` itself.
	///
	/// Nodes are identified by pointer, not by value, so this measures how much of the tree is
	/// actually shared.
	fn unique_nodes_count_expr<F: Field>(expr: &ArithExpr<F>) -> usize {
		fn add_nodes<F: Field>(
			expr: &Arc<ArithExpr<F>>,
			unique_nodes: &mut HashSet<*const ArithExpr<F>>,
		) {
			if !unique_nodes.insert(Arc::as_ptr(expr)) {
				return;
			}

			match expr.as_ref() {
				ArithExpr::Const(_) | ArithExpr::Var(_) => {}
				ArithExpr::Add(left, right) | ArithExpr::Mul(left, right) => {
					add_nodes(left, unique_nodes);
					add_nodes(right, unique_nodes);
				}
				ArithExpr::Pow(base, _) => {
					add_nodes(base, unique_nodes);
				}
			}
		}

		let mut unique_nodes = HashSet::new();
		add_nodes(&Arc::new(expr.clone()), &mut unique_nodes);

		unique_nodes.len()
	}

	#[test]
	fn test_conversion_from_circuit_to_expr_node() {
		type F = BinaryField8b;

		let arith_circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Mul(2, 2),
				ArithCircuitStep::Add(3, 2),
			],
		};
		let expr = ArithExpr::from(&arith_circuit);
		let expected_expr = (ArithExpr::Var(0) + ArithExpr::Var(1))
			* (ArithExpr::Var(0) + ArithExpr::Var(1))
			+ (ArithExpr::Var(0) + ArithExpr::Var(1));
		assert_eq!(expr, expected_expr);
		assert_eq!(unique_nodes_count_expr(&expr), 5);
	}

	#[test]
	fn test_evaluate() {
		type F = BinaryField8b;
		let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1))
			* (ArithCircuit::var(2) + ArithCircuit::var(3)).pow(5);
		// Pick inputs where neither sum collapses to `F::ONE`. In characteristic 2,
		// `2 + 3 == 1` and `4 + 5 == 1`, which would hide a mis-grouped expected value.
		let (x0, x1, x2, x3) = (F::new(2), F::new(5), F::new(4), F::new(9));
		let result = expr.evaluate(&[x0, x1, x2, x3]).unwrap();
		// The expected value mirrors the circuit exactly: `(x0 + x1) * (x2 + x3)^5`.
		assert_eq!(result, (x0 + x1) * (x2 + x3).pow(5));
	}
}