binius_math/
arith_expr.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	cmp::Ordering,
5	collections::{hash_map::Entry, HashMap},
6	fmt::{self, Display},
7	hash::{Hash, Hasher},
8	iter::{Product, Sum},
9	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
10	sync::Arc,
11};
12
13use binius_field::{Field, PackedField, TowerField};
14use binius_macros::{DeserializeBytes, SerializeBytes};
15
16use super::error::Error;
17
/// A builder for arithmetic expressions.
///
/// This is a simple representation of multivariate polynomials. The user can explicitly
/// specify which subexpressions are shared by using `Arc` to wrap them. Subexpressions shared
/// through the same `Arc` (the same allocation) are guaranteed to be converted to the same step
/// of the resulting `ArithCircuit`.
///
/// Note that the derived `PartialEq`/`Hash` compare subexpressions by value, not by `Arc`
/// identity.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithExpr<F: Field> {
	/// A constant field element.
	Const(F),
	/// The variable `x_i` with the given index `i`.
	Var(usize),
	/// The sum of two subexpressions.
	Add(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
	/// The product of two subexpressions.
	Mul(Arc<ArithExpr<F>>, Arc<ArithExpr<F>>),
	/// A subexpression raised to the given power.
	Pow(Arc<ArithExpr<F>>, u64),
}
31
32impl<F: Field + Display> Display for ArithExpr<F> {
33	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34		match self {
35			Self::Const(v) => write!(f, "{v}"),
36			Self::Var(i) => write!(f, "x{i}"),
37			Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
38			Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
39			Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
40		}
41	}
42}
43
44impl<F: Field> ArithExpr<F> {
45	pub fn pow(self, exp: u64) -> Self {
46		Self::Pow(Arc::new(self), exp)
47	}
48
49	pub const fn zero() -> Self {
50		Self::Const(F::ZERO)
51	}
52
53	pub const fn one() -> Self {
54		Self::Const(F::ONE)
55	}
56}
57
58impl<F> Default for ArithExpr<F>
59where
60	F: Field,
61{
62	fn default() -> Self {
63		Self::zero()
64	}
65}
66
67impl<F> Add for ArithExpr<F>
68where
69	F: Field,
70{
71	type Output = Self;
72
73	fn add(self, rhs: Self) -> Self {
74		Self::Add(Arc::new(self), Arc::new(rhs))
75	}
76}
77
78impl<F> Add<Arc<Self>> for ArithExpr<F>
79where
80	F: Field,
81{
82	type Output = Self;
83
84	fn add(self, rhs: Arc<Self>) -> Self {
85		Self::Add(Arc::new(self), rhs)
86	}
87}
88
89impl<F> AddAssign for ArithExpr<F>
90where
91	F: Field,
92{
93	fn add_assign(&mut self, rhs: Self) {
94		*self = std::mem::take(self) + rhs;
95	}
96}
97
98impl<F> AddAssign<Arc<Self>> for ArithExpr<F>
99where
100	F: Field,
101{
102	fn add_assign(&mut self, rhs: Arc<Self>) {
103		*self = std::mem::take(self) + rhs;
104	}
105}
106
107impl<F> Sub for ArithExpr<F>
108where
109	F: Field,
110{
111	type Output = Self;
112
113	fn sub(self, rhs: Self) -> Self {
114		Self::Add(Arc::new(self), Arc::new(rhs))
115	}
116}
117
118impl<F> Sub<Arc<Self>> for ArithExpr<F>
119where
120	F: Field,
121{
122	type Output = Self;
123
124	fn sub(self, rhs: Arc<Self>) -> Self {
125		Self::Add(Arc::new(self), rhs)
126	}
127}
128
129impl<F> SubAssign for ArithExpr<F>
130where
131	F: Field,
132{
133	fn sub_assign(&mut self, rhs: Self) {
134		*self = std::mem::take(self) - rhs;
135	}
136}
137
138impl<F> SubAssign<Arc<Self>> for ArithExpr<F>
139where
140	F: Field,
141{
142	fn sub_assign(&mut self, rhs: Arc<Self>) {
143		*self = std::mem::take(self) - rhs;
144	}
145}
146
147impl<F> Mul for ArithExpr<F>
148where
149	F: Field,
150{
151	type Output = Self;
152
153	fn mul(self, rhs: Self) -> Self {
154		Self::Mul(Arc::new(self), Arc::new(rhs))
155	}
156}
157
158impl<F> Mul<Arc<Self>> for ArithExpr<F>
159where
160	F: Field,
161{
162	type Output = Self;
163
164	fn mul(self, rhs: Arc<Self>) -> Self {
165		Self::Mul(Arc::new(self), rhs)
166	}
167}
168
169impl<F> MulAssign for ArithExpr<F>
170where
171	F: Field,
172{
173	fn mul_assign(&mut self, rhs: Self) {
174		*self = std::mem::take(self) * rhs;
175	}
176}
177
178impl<F> MulAssign<Arc<Self>> for ArithExpr<F>
179where
180	F: Field,
181{
182	fn mul_assign(&mut self, rhs: Arc<Self>) {
183		*self = std::mem::take(self) * rhs;
184	}
185}
186
187impl<F: Field> Sum for ArithExpr<F> {
188	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
189		iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
190	}
191}
192
193impl<F: Field> Product for ArithExpr<F> {
194	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
195		iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
196	}
197}
198
/// A single step of an [`ArithCircuit`].
///
/// Operand indices refer to other steps of the same circuit by position. Evaluation
/// (`ArithCircuit::evaluate`) requires every operand to precede the step that uses it.
#[derive(Clone, Copy, Debug, SerializeBytes, DeserializeBytes, PartialEq, Eq)]
pub enum ArithCircuitStep<F: Field> {
	/// Sum of the values of the two referenced steps.
	Add(usize, usize),
	/// Product of the values of the two referenced steps.
	Mul(usize, usize),
	/// Value of the referenced step raised to the given power.
	Pow(usize, u64),
	/// A constant field element.
	Const(F),
	/// The variable with the given index, i.e. an index into the evaluation query.
	Var(usize),
}
207
208impl<F: Field> Default for ArithCircuitStep<F> {
209	fn default() -> Self {
210		Self::Const(F::ZERO)
211	}
212}
213
/// Arithmetic expressions that can be evaluated symbolically.
///
/// Arithmetic expressions are trees, where the leaves are either constants or variables, and the
/// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are
/// specific representations of multivariate polynomials.
///
/// The tree is stored as a simple circuit: a flat list of steps in which every operation refers
/// to its operands by step index, and the last step is the root of the expression.
/// This implementation isn't optimized for performance, but rather for simplicity,
/// to allow easy conversion and preservation of the common subexpressions.
#[derive(Clone, Debug, SerializeBytes, DeserializeBytes, Eq)]
pub struct ArithCircuit<F: Field> {
	/// The circuit steps; the last one is the root of the expression.
	steps: Vec<ArithCircuitStep<F>>,
}
227
228impl<F: Field> ArithCircuit<F> {
229	pub fn var(index: usize) -> Self {
230		Self {
231			steps: vec![ArithCircuitStep::Var(index)],
232		}
233	}
234
235	pub fn constant(value: F) -> Self {
236		Self {
237			steps: vec![ArithCircuitStep::Const(value)],
238		}
239	}
240
241	pub fn zero() -> Self {
242		Self::constant(F::ZERO)
243	}
244
245	pub fn one() -> Self {
246		Self::constant(F::ONE)
247	}
248
249	pub fn pow(mut self, exp: u64) -> Self {
250		self.steps
251			.push(ArithCircuitStep::Pow(self.steps.len() - 1, exp));
252
253		self
254	}
255
256	/// Steps of the circuit.
257	pub fn steps(&self) -> &[ArithCircuitStep<F>] {
258		&self.steps
259	}
260
261	/// The total degree of the polynomial the expression represents.
262	pub fn degree(&self) -> usize {
263		fn step_degree<F: Field>(step: usize, steps: &[ArithCircuitStep<F>]) -> usize {
264			match steps[step] {
265				ArithCircuitStep::Const(_) => 0,
266				ArithCircuitStep::Var(_) => 1,
267				ArithCircuitStep::Add(left, right) => {
268					step_degree(left, steps).max(step_degree(right, steps))
269				}
270				ArithCircuitStep::Mul(left, right) => {
271					step_degree(left, steps) + step_degree(right, steps)
272				}
273				ArithCircuitStep::Pow(base, exp) => step_degree(base, steps) * (exp as usize),
274			}
275		}
276
277		step_degree(self.steps.len() - 1, &self.steps)
278	}
279
280	/// The number of variables the expression contains.
281	pub fn n_vars(&self) -> usize {
282		self.steps
283			.iter()
284			.map(|step| {
285				if let ArithCircuitStep::Var(index) = step {
286					*index + 1
287				} else {
288					0
289				}
290			})
291			.max()
292			.unwrap_or(0)
293	}
294
295	/// The maximum tower level of the constant terms in the circuit.
296	pub fn binary_tower_level(&self) -> usize
297	where
298		F: TowerField,
299	{
300		self.steps
301			.iter()
302			.map(|step| {
303				if let ArithCircuitStep::Const(value) = step {
304					value.min_tower_level()
305				} else {
306					0
307				}
308			})
309			.max()
310			.unwrap_or(0)
311	}
312
313	/// Returns the evaluation cost of the circuit.
314	pub fn eval_cost(&self) -> EvalCost {
315		self.steps
316			.iter()
317			.fold(EvalCost::default(), |acc, step| match *step {
318				ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => acc,
319				ArithCircuitStep::Add(_, _) => acc + EvalCost::add(1),
320				ArithCircuitStep::Mul(_, _) => acc + EvalCost::mul(1),
321				// To exponentiate b^e the square-and-multiply method is used.
322				//
323				// That takes:
324				// 1. log2(e) squarings.
325				// 1. popcount(e) - 1 multiplications.
326				ArithCircuitStep::Pow(_, exp) => {
327					let n_squares = exp.ilog2() as usize;
328					let n_muls = exp.count_ones().saturating_sub(1) as usize;
329					acc + EvalCost::mul(n_muls) + EvalCost::square(n_squares)
330				}
331			})
332	}
333
334	/// Return a new arithmetic expression that contains only the terms of highest degree
335	/// (useful for interpolation at Karatsuba infinity point).
336	pub fn leading_term(&self) -> Self {
337		let (_, expr) = self.leading_term_with_degree(self.steps.len() - 1);
338		expr
339	}
340
341	/// Same as `leading_term`, but returns the total degree as the first tuple element as well.
342	fn leading_term_with_degree(&self, step: usize) -> (usize, Self) {
343		match &self.steps[step] {
344			ArithCircuitStep::Const(value) => (0, Self::constant(*value)),
345			ArithCircuitStep::Var(index) => (1, Self::var(*index)),
346			ArithCircuitStep::Add(left, right) => {
347				let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
348				let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
349				match lhs_degree.cmp(&rhs_degree) {
350					Ordering::Less => (rhs_degree, rhs),
351					Ordering::Equal => (lhs_degree, lhs + rhs),
352					Ordering::Greater => (lhs_degree, lhs),
353				}
354			}
355			ArithCircuitStep::Mul(left, right) => {
356				let (lhs_degree, lhs) = self.leading_term_with_degree(*left);
357				let (rhs_degree, rhs) = self.leading_term_with_degree(*right);
358				(lhs_degree + rhs_degree, lhs * rhs)
359			}
360			ArithCircuitStep::Pow(base, exp) => {
361				let (base_degree, base) = self.leading_term_with_degree(*base);
362				(base_degree * (*exp as usize), base.pow(*exp))
363			}
364		}
365	}
366
367	pub fn evaluate(&self, query: &[F]) -> Result<F, Error> {
368		let mut step_evals = Vec::<F>::with_capacity(self.steps.len());
369		for step in &self.steps {
370			match step {
371				ArithCircuitStep::Add(left, right) => {
372					step_evals.push(step_evals[*left] + step_evals[*right])
373				}
374				ArithCircuitStep::Mul(left, right) => {
375					step_evals.push(step_evals[*left] * step_evals[*right])
376				}
377				ArithCircuitStep::Pow(base, exp) => step_evals.push(step_evals[*base].pow(*exp)),
378				ArithCircuitStep::Const(value) => step_evals.push(*value),
379				ArithCircuitStep::Var(index) => step_evals.push(query[*index]),
380			}
381		}
382		Ok(step_evals.pop().unwrap_or_default())
383	}
384
385	pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithCircuit<FTgt> {
386		ArithCircuit {
387			steps: self
388				.steps
389				.iter()
390				.map(|step| match step {
391					ArithCircuitStep::Const(value) => ArithCircuitStep::Const((*value).into()),
392					ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
393					ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
394					ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
395					ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
396				})
397				.collect(),
398		}
399	}
400
401	pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
402		&self,
403	) -> Result<ArithCircuit<FTgt>, <FTgt as TryFrom<F>>::Error> {
404		let steps = self
405			.steps
406			.iter()
407			.map(|step| -> Result<ArithCircuitStep<FTgt>, <FTgt as TryFrom<F>>::Error> {
408				let result = match step {
409					ArithCircuitStep::Const(value) => {
410						ArithCircuitStep::Const(FTgt::try_from(*value)?)
411					}
412					ArithCircuitStep::Var(index) => ArithCircuitStep::Var(*index),
413					ArithCircuitStep::Add(left, right) => ArithCircuitStep::Add(*left, *right),
414					ArithCircuitStep::Mul(left, right) => ArithCircuitStep::Mul(*left, *right),
415					ArithCircuitStep::Pow(base, exp) => ArithCircuitStep::Pow(*base, *exp),
416				};
417				Ok(result)
418			})
419			.collect::<Result<Vec<_>, _>>()?;
420
421		Ok(ArithCircuit { steps })
422	}
423
424	/// Creates a new expression with the variable indices remapped.
425	///
426	/// This recursively replaces the variable sub-expressions with an index `i` with the variable
427	/// `indices[i]`.
428	///
429	/// ## Throws
430	///
431	/// * [`Error::IncorrectArgumentLength`] if indices has length less than the current number of
432	///   variables
433	pub fn remap_vars(&self, indices: &[usize]) -> Result<Self, Error> {
434		let steps = self
435			.steps
436			.iter()
437			.map(|step| -> Result<ArithCircuitStep<F>, Error> {
438				if let ArithCircuitStep::Var(index) = step {
439					let new_index = indices.get(*index).copied().ok_or_else(|| {
440						Error::IncorrectArgumentLength {
441							arg: "indices".to_string(),
442							expected: *index,
443						}
444					})?;
445					Ok(ArithCircuitStep::Var(new_index))
446				} else {
447					Ok(*step)
448				}
449			})
450			.collect::<Result<Vec<_>, _>>()?;
451		Ok(Self { steps })
452	}
453
454	/// Substitute variable with index `var` with a constant `value`
455	pub fn const_subst(self, var: usize, value: F) -> Self {
456		let steps = self
457			.steps
458			.iter()
459			.map(|step| match step {
460				ArithCircuitStep::Var(index) if *index == var => ArithCircuitStep::Const(value),
461				_ => *step,
462			})
463			.collect();
464		Self { steps }
465	}
466
467	/// Returns `Some(F)` if the expression is a constant.
468	pub fn get_constant(&self) -> Option<F> {
469		if let ArithCircuitStep::Const(value) =
470			self.steps.last().expect("steps should not be empty")
471		{
472			Some(*value)
473		} else {
474			None
475		}
476	}
477
478	/// Returns the normal form of an expression if it is linear.
479	///
480	/// ## Throws
481	///
482	/// - [`Error::NonLinearExpression`] if the expression is not linear.
483	pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
484		self.sparse_linear_normal_form().map(Into::into)
485	}
486
487	fn sparse_linear_normal_form(&self) -> Result<SparseLinearNormalForm<F>, Error> {
488		fn sparse_linear_normal_form<F: Field>(
489			step: usize,
490			steps: &[ArithCircuitStep<F>],
491		) -> Result<SparseLinearNormalForm<F>, Error> {
492			match &steps[step] {
493				ArithCircuitStep::Const(val) => Ok((*val).into()),
494				ArithCircuitStep::Var(index) => Ok(SparseLinearNormalForm {
495					constant: F::ZERO,
496					dense_linear_form_len: index + 1,
497					var_coeffs: [(*index, F::ONE)].into(),
498				}),
499				ArithCircuitStep::Add(left, right) => {
500					let left = sparse_linear_normal_form(*left, steps)?;
501					let right = sparse_linear_normal_form(*right, steps)?;
502					Ok(left + right)
503				}
504				ArithCircuitStep::Mul(left, right) => {
505					let left = sparse_linear_normal_form(*left, steps)?;
506					let right = sparse_linear_normal_form(*right, steps)?;
507					left * right
508				}
509				ArithCircuitStep::Pow(_, 0) => Ok(F::ONE.into()),
510				ArithCircuitStep::Pow(expr, 1) => sparse_linear_normal_form(*expr, steps),
511				ArithCircuitStep::Pow(expr, pow) => {
512					let linear_form = sparse_linear_normal_form(*expr, steps)?;
513					if linear_form.dense_linear_form_len != 0 {
514						return Err(Error::NonLinearExpression);
515					}
516					Ok(linear_form.constant.pow(*pow).into())
517				}
518			}
519		}
520
521		sparse_linear_normal_form(self.steps.len() - 1, &self.steps)
522	}
523
524	/// Returns a vector of booleans indicating which variables are used in the expression.
525	///
526	/// The vector is indexed by variable index, and the value at index `i` is `true` if and only
527	/// if the variable is used in the expression.
528	pub fn vars_usage(&self) -> Vec<bool> {
529		let mut usage = vec![false; self.n_vars()];
530
531		for step in &self.steps {
532			if let ArithCircuitStep::Var(index) = step {
533				usage[*index] = true;
534			}
535		}
536
537		usage
538	}
539
540	/// Fold constants in the circuit.
541	fn optimize_constants(&mut self) {
542		for step_index in 0..self.steps.len() {
543			let (prev_steps, curr_steps) = self.steps.split_at_mut(step_index);
544			let curr_step = &mut curr_steps[0];
545			match curr_step {
546				ArithCircuitStep::Const(_) | ArithCircuitStep::Var(_) => {}
547				ArithCircuitStep::Add(left, right) => {
548					match (&prev_steps[*left], &prev_steps[*right]) {
549						(ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
550							*curr_step = ArithCircuitStep::Const(*left + *right);
551						}
552						(ArithCircuitStep::Const(left), right) if *left == F::ZERO => {
553							*curr_step = *right;
554						}
555						(left, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
556							*curr_step = *left;
557						}
558						(left, right) if left == right && F::CHARACTERISTIC == 2 => {
559							*curr_step = ArithCircuitStep::Const(F::ZERO);
560						}
561						_ => {}
562					}
563				}
564				ArithCircuitStep::Mul(left, right) => {
565					match (&prev_steps[*left], &prev_steps[*right]) {
566						(ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => {
567							*curr_step = ArithCircuitStep::Const(*left * *right);
568						}
569						(ArithCircuitStep::Const(left), _) if *left == F::ZERO => {
570							*curr_step = ArithCircuitStep::Const(F::ZERO);
571						}
572						(_, ArithCircuitStep::Const(right)) if *right == F::ZERO => {
573							*curr_step = ArithCircuitStep::Const(F::ZERO);
574						}
575						(ArithCircuitStep::Const(left), right) if *left == F::ONE => {
576							*curr_step = *right;
577						}
578						(left, ArithCircuitStep::Const(right)) if *right == F::ONE => {
579							*curr_step = *left;
580						}
581						_ => {}
582					}
583				}
584				ArithCircuitStep::Pow(base, exp) => match prev_steps[*base] {
585					ArithCircuitStep::Const(value) => {
586						*curr_step = ArithCircuitStep::Const(PackedField::pow(value, *exp));
587					}
588					ArithCircuitStep::Pow(base_inner, exp_inner) => {
589						*curr_step = ArithCircuitStep::Pow(base_inner, *exp * exp_inner);
590					}
591					_ => {}
592				},
593			}
594		}
595	}
596
597	/// Deduplicate steps in the circuit by using a single instance of each step.
598	fn deduplicate_steps(&mut self) {
599		let mut step_map = HashMap::new();
600		let mut step_indices = Vec::with_capacity(self.steps.len());
601		for step in 0..self.steps.len() {
602			let node = StepNode {
603				index: step,
604				steps: &self.steps,
605			};
606			match step_map.entry(node) {
607				Entry::Occupied(entry) => {
608					step_indices.push(*entry.get());
609				}
610				Entry::Vacant(entry) => {
611					entry.insert(step);
612					step_indices.push(step);
613				}
614			}
615		}
616
617		for step in &mut self.steps {
618			match step {
619				ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
620					*left = step_indices[*left];
621					*right = step_indices[*right];
622				}
623				ArithCircuitStep::Pow(base, _) => *base = step_indices[*base],
624				_ => (),
625			}
626		}
627	}
628
629	/// Remove the unused steps from the circuit.
630	fn compress_unused_steps(&mut self) {
631		fn mark_used<F: Field>(step: usize, steps: &[ArithCircuitStep<F>], used: &mut [bool]) {
632			if used[step] {
633				return;
634			}
635			used[step] = true;
636			match steps[step] {
637				ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
638					mark_used(left, steps, used);
639					mark_used(right, steps, used);
640				}
641				ArithCircuitStep::Pow(base, _) => mark_used(base, steps, used),
642				_ => (),
643			}
644		}
645
646		let mut used = vec![false; self.steps.len()];
647		mark_used(self.steps.len() - 1, &self.steps, &mut used);
648
649		let mut steps_map = (0..self.steps.len()).collect::<Vec<_>>();
650		let mut target_index = 0;
651		for source_index in 0..self.steps.len() {
652			if used[source_index] {
653				if target_index != source_index {
654					match &mut self.steps[source_index] {
655						ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
656							*left = steps_map[*left];
657							*right = steps_map[*right];
658						}
659						ArithCircuitStep::Pow(base, _) => *base = steps_map[*base],
660						_ => (),
661					}
662
663					steps_map[source_index] = target_index;
664					self.steps[target_index] = self.steps[source_index];
665				}
666
667				target_index += 1;
668			}
669		}
670
671		self.steps.truncate(target_index);
672	}
673
674	/// Optimize the current instance of the circuit:
675	/// - Fold constants
676	/// - Extract common steps
677	/// - Remove unused steps
678	pub fn optimize_in_place(&mut self) {
679		self.optimize_constants();
680		self.deduplicate_steps();
681		self.compress_unused_steps();
682	}
683
684	/// Same as `optimize_in_place`, but returns a new instance of the circuit.
685	pub fn optimize(mut self) -> Self {
686		self.optimize_constants();
687		self.deduplicate_steps();
688		self.compress_unused_steps();
689
690		self
691	}
692}
693
694impl<F: Field> From<&ArithExpr<F>> for ArithCircuit<F> {
695	fn from(expr: &ArithExpr<F>) -> Self {
696		fn visit_node<F: Field>(
697			node: &Arc<ArithExpr<F>>,
698			node_to_index: &mut HashMap<*const ArithExpr<F>, usize>,
699			steps: &mut Vec<ArithCircuitStep<F>>,
700		) -> usize {
701			if let Some(index) = node_to_index.get(&Arc::as_ptr(node)) {
702				return *index;
703			}
704
705			let step = match &**node {
706				ArithExpr::Const(value) => ArithCircuitStep::Const(*value),
707				ArithExpr::Var(index) => ArithCircuitStep::Var(*index),
708				ArithExpr::Add(left, right) => {
709					let left = visit_node(left, node_to_index, steps);
710					let right = visit_node(right, node_to_index, steps);
711					ArithCircuitStep::Add(left, right)
712				}
713				ArithExpr::Mul(left, right) => {
714					let left = visit_node(left, node_to_index, steps);
715					let right = visit_node(right, node_to_index, steps);
716					ArithCircuitStep::Mul(left, right)
717				}
718				ArithExpr::Pow(base, exp) => {
719					let base = visit_node(base, node_to_index, steps);
720					ArithCircuitStep::Pow(base, *exp)
721				}
722			};
723
724			steps.push(step);
725			node_to_index.insert(Arc::as_ptr(node), steps.len() - 1);
726			steps.len() - 1
727		}
728
729		let mut steps = Vec::new();
730		let mut node_to_index = HashMap::new();
731		match expr {
732			ArithExpr::Const(c) => {
733				steps.push(ArithCircuitStep::Const(*c));
734			}
735			ArithExpr::Var(var) => {
736				steps.push(ArithCircuitStep::Var(*var));
737			}
738			ArithExpr::Add(left, right) => {
739				let left = visit_node(left, &mut node_to_index, &mut steps);
740				let right = visit_node(right, &mut node_to_index, &mut steps);
741				steps.push(ArithCircuitStep::Add(left, right));
742			}
743			ArithExpr::Mul(left, right) => {
744				let left = visit_node(left, &mut node_to_index, &mut steps);
745				let right = visit_node(right, &mut node_to_index, &mut steps);
746				steps.push(ArithCircuitStep::Mul(left, right));
747			}
748			ArithExpr::Pow(base, exp) => {
749				let base = visit_node(base, &mut node_to_index, &mut steps);
750				steps.push(ArithCircuitStep::Pow(base, *exp));
751			}
752		}
753
754		Self { steps }
755	}
756}
757
758impl<F: Field> From<ArithExpr<F>> for ArithCircuit<F> {
759	fn from(expr: ArithExpr<F>) -> Self {
760		Self::from(&expr)
761	}
762}
763
764impl<F: Field> From<&ArithCircuit<F>> for ArithExpr<F> {
765	fn from(circuit: &ArithCircuit<F>) -> Self {
766		let mut step_to_node = vec![Option::<Arc<Self>>::None; circuit.steps.len()];
767
768		for step in 0..circuit.steps.len() {
769			let node = match &circuit.steps[step] {
770				ArithCircuitStep::Const(value) => Self::Const(*value),
771				ArithCircuitStep::Var(index) => Self::Var(*index),
772				ArithCircuitStep::Add(left, right) => {
773					let left = step_to_node[*left].clone().expect("step must be present");
774					let right = step_to_node[*right].clone().expect("step must be present");
775					Self::Add(left, right)
776				}
777				ArithCircuitStep::Mul(left, right) => {
778					let left = step_to_node[*left].clone().expect("step must be present");
779					let right = step_to_node[*right].clone().expect("step must be present");
780					Self::Mul(left, right)
781				}
782				ArithCircuitStep::Pow(base, exp) => {
783					let base = step_to_node[*base].clone().expect("step must be present");
784					Self::Pow(base, *exp)
785				}
786			};
787			step_to_node[step] = Some(Arc::new(node));
788		}
789
790		Arc::into_inner(
791			step_to_node
792				.pop()
793				.expect("steps should not be empty")
794				.expect("last step should be initialized"),
795		)
796		.expect("last step must have a single instance")
797	}
798}
799
800impl<F: Field> From<ArithCircuit<F>> for ArithExpr<F> {
801	fn from(circuit: ArithCircuit<F>) -> Self {
802		Self::from(&circuit)
803	}
804}
805
806impl<F: Field> Display for ArithCircuit<F> {
807	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
808		fn display_step<F: Field>(
809			step: usize,
810			steps: &[ArithCircuitStep<F>],
811			f: &mut fmt::Formatter<'_>,
812		) -> Result<(), fmt::Error> {
813			match &steps[step] {
814				ArithCircuitStep::Const(value) => write!(f, "{value}"),
815				ArithCircuitStep::Var(index) => write!(f, "x{index}"),
816				ArithCircuitStep::Add(left, right) => {
817					write!(f, "(")?;
818					display_step(*left, steps, f)?;
819					write!(f, " + ")?;
820					display_step(*right, steps, f)?;
821					write!(f, ")")
822				}
823				ArithCircuitStep::Mul(left, right) => {
824					write!(f, "(")?;
825					display_step(*left, steps, f)?;
826					write!(f, " * ")?;
827					display_step(*right, steps, f)?;
828					write!(f, ")")
829				}
830				ArithCircuitStep::Pow(base, exp) => {
831					write!(f, "(")?;
832					display_step(*base, steps, f)?;
833					write!(f, ")^{exp}")
834				}
835			}
836		}
837
838		display_step(self.steps.len() - 1, &self.steps, f)
839	}
840}
841
842impl<F: Field> PartialEq for ArithCircuit<F> {
843	fn eq(&self, other: &Self) -> bool {
844		StepNode {
845			index: self.steps.len() - 1,
846			steps: &self.steps,
847		} == StepNode {
848			index: other.steps.len() - 1,
849			steps: &other.steps,
850		}
851	}
852}
853
854impl<F: Field> Add for ArithCircuit<F> {
855	type Output = Self;
856
857	fn add(mut self, rhs: Self) -> Self {
858		self += rhs;
859		self
860	}
861}
862
863impl<F: Field> AddAssign for ArithCircuit<F> {
864	fn add_assign(&mut self, mut rhs: Self) {
865		let old_len = self.steps.len();
866		add_offset(&mut rhs.steps, old_len);
867		self.steps.extend(rhs.steps);
868		self.steps
869			.push(ArithCircuitStep::Add(old_len - 1, self.steps.len() - 1));
870	}
871}
872
873impl<F: Field> Sub for ArithCircuit<F> {
874	type Output = Self;
875
876	fn sub(mut self, rhs: Self) -> Self {
877		self -= rhs;
878		self
879	}
880}
881
882impl<F: Field> SubAssign for ArithCircuit<F> {
883	#[allow(clippy::suspicious_op_assign_impl)]
884	fn sub_assign(&mut self, rhs: Self) {
885		*self += rhs;
886	}
887}
888
889impl<F: Field> Mul for ArithCircuit<F> {
890	type Output = Self;
891
892	fn mul(mut self, rhs: Self) -> Self {
893		self *= rhs;
894		self
895	}
896}
897
898impl<F: Field> MulAssign for ArithCircuit<F> {
899	fn mul_assign(&mut self, mut rhs: Self) {
900		let old_len = self.steps.len();
901		add_offset(&mut rhs.steps, old_len);
902		self.steps.extend(rhs.steps);
903		self.steps
904			.push(ArithCircuitStep::Mul(old_len - 1, self.steps.len() - 1));
905	}
906}
907
908impl<F: Field> Sum for ArithCircuit<F> {
909	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
910		iter.fold(Self::zero(), |sum, item| sum + item)
911	}
912}
913
914impl<F: Field> Product for ArithCircuit<F> {
915	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
916		iter.fold(Self::one(), |product, item| product * item)
917	}
918}
919
920fn add_offset<F: Field>(steps: &mut [ArithCircuitStep<F>], offset: usize) {
921	for step in steps.iter_mut() {
922		match step {
923			ArithCircuitStep::Add(left, right) | ArithCircuitStep::Mul(left, right) => {
924				*left += offset;
925				*right += offset;
926			}
927			ArithCircuitStep::Pow(base, _) => {
928				*base += offset;
929			}
930			_ => (),
931		}
932	}
933}

/// A normal form for a linear expression: `constant + sum_i var_coeffs[i] * x_i`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LinearNormalForm<F: Field> {
	/// The constant offset of the expression.
	pub constant: F,
	/// A vector mapping variable indices to their coefficients.
	///
	/// Entry `i` is the coefficient of `x_i`; variables without a term hold zero.
	pub var_coeffs: Vec<F>,
}
942
/// Sparse counterpart of [`LinearNormalForm`], used while folding a circuit into linear form.
///
/// Only variables that appeared in a term have an entry in `var_coeffs`. An entry may still hold
/// a zero coefficient if terms cancelled out, because addition never removes entries.
struct SparseLinearNormalForm<F: Field> {
	/// The constant offset of the expression.
	pub constant: F,
	/// A map of variable indices to their coefficients.
	pub var_coeffs: HashMap<usize, F>,
	/// The `var_coeffs` vector len if converted to [`LinearNormalForm`]: one more than the
	/// largest variable index with an entry in `var_coeffs`, or `0` when there is none.
	/// It is used for optimization of conversion to [`LinearNormalForm`].
	pub dense_linear_form_len: usize,
}
952
953impl<F: Field> From<F> for SparseLinearNormalForm<F> {
954	fn from(value: F) -> Self {
955		Self {
956			constant: value,
957			dense_linear_form_len: 0,
958			var_coeffs: HashMap::new(),
959		}
960	}
961}
962
963impl<F: Field> Add for SparseLinearNormalForm<F> {
964	type Output = Self;
965	fn add(self, rhs: Self) -> Self::Output {
966		let (mut result, consumable) = if self.var_coeffs.len() < rhs.var_coeffs.len() {
967			(rhs, self)
968		} else {
969			(self, rhs)
970		};
971		result.constant += consumable.constant;
972		if consumable.dense_linear_form_len > result.dense_linear_form_len {
973			result.dense_linear_form_len = consumable.dense_linear_form_len;
974		}
975
976		for (index, coeff) in consumable.var_coeffs {
977			result
978				.var_coeffs
979				.entry(index)
980				.and_modify(|res_coeff| {
981					*res_coeff += coeff;
982				})
983				.or_insert(coeff);
984		}
985		result
986	}
987}
988
989impl<F: Field> Mul for SparseLinearNormalForm<F> {
990	type Output = Result<Self, Error>;
991	fn mul(self, rhs: Self) -> Result<Self, Error> {
992		if !self.var_coeffs.is_empty() && !rhs.var_coeffs.is_empty() {
993			return Err(Error::NonLinearExpression);
994		}
995		let (mut result, consumable) = if self.var_coeffs.is_empty() {
996			(rhs, self)
997		} else {
998			(self, rhs)
999		};
1000		result.constant *= consumable.constant;
1001		for coeff in result.var_coeffs.values_mut() {
1002			*coeff *= consumable.constant;
1003		}
1004		Ok(result)
1005	}
1006}
1007
1008impl<F: Field> From<SparseLinearNormalForm<F>> for LinearNormalForm<F> {
1009	fn from(value: SparseLinearNormalForm<F>) -> Self {
1010		let mut var_coeffs = vec![F::ZERO; value.dense_linear_form_len];
1011		for (i, coeff) in value.var_coeffs {
1012			var_coeffs[i] = coeff;
1013		}
1014		Self {
1015			constant: value.constant,
1016			var_coeffs,
1017		}
1018	}
1019}
1020
/// A helper structure used for by-value (structural) comparison and hashing of subexpressions,
/// regardless of whether they belong to the same `ArithCircuit`.
///
/// `Eq` is derived as a marker on top of the hand-written structural `PartialEq` impl.
#[derive(Eq)]
struct StepNode<'a, F: Field> {
	/// Index of the step this node represents.
	index: usize,
	/// All steps of the circuit the node belongs to; operand indices are resolved against it.
	steps: &'a [ArithCircuitStep<F>],
}
1028
1029impl<F: Field> StepNode<'_, F> {
1030	fn prev_step(&self, step: usize) -> Self {
1031		StepNode {
1032			index: step,
1033			steps: self.steps,
1034		}
1035	}
1036}
1037
1038impl<F: Field> PartialEq for StepNode<'_, F> {
1039	#[allow(clippy::suspicious_operation_groupings)] // false positive
1040	fn eq(&self, other: &Self) -> bool {
1041		match (&self.steps[self.index], &other.steps[other.index]) {
1042			(ArithCircuitStep::Const(left), ArithCircuitStep::Const(right)) => left == right,
1043			(ArithCircuitStep::Var(left), ArithCircuitStep::Var(right)) => left == right,
1044			(
1045				ArithCircuitStep::Add(left, right),
1046				ArithCircuitStep::Add(other_left, other_right),
1047			)
1048			| (
1049				ArithCircuitStep::Mul(left, right),
1050				ArithCircuitStep::Mul(other_left, other_right),
1051			) => {
1052				self.prev_step(*left) == other.prev_step(*other_left)
1053					&& self.prev_step(*right) == other.prev_step(*other_right)
1054			}
1055			(ArithCircuitStep::Pow(base, exp), ArithCircuitStep::Pow(other_base, other_exp)) => {
1056				self.prev_step(*base) == other.prev_step(*other_base) && exp == other_exp
1057			}
1058			_ => false,
1059		}
1060	}
1061}
1062
1063impl<F: Field> Hash for StepNode<'_, F> {
1064	fn hash<H: Hasher>(&self, state: &mut H) {
1065		match self.steps[self.index] {
1066			ArithCircuitStep::Const(value) => {
1067				0u8.hash(state);
1068				value.hash(state);
1069			}
1070			ArithCircuitStep::Var(index) => {
1071				1u8.hash(state);
1072				index.hash(state);
1073			}
1074			ArithCircuitStep::Add(left, right) => {
1075				2u8.hash(state);
1076				self.prev_step(left).hash(state);
1077				self.prev_step(right).hash(state);
1078			}
1079			ArithCircuitStep::Mul(left, right) => {
1080				3u8.hash(state);
1081				self.prev_step(left).hash(state);
1082				self.prev_step(right).hash(state);
1083			}
1084			ArithCircuitStep::Pow(base, exp) => {
1085				4u8.hash(state);
1086				self.prev_step(base).hash(state);
1087				exp.hash(state);
1088			}
1089		}
1090	}
1091}
1092
/// The breakdown of the number of operations it takes to evaluate this circuit.
///
/// Costs of independent parts can be combined with `+`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct EvalCost {
	/// Number of additions.
	///
	/// Addition is a relatively cheap operation.
	pub n_adds: usize,
	/// Number of multiplications.
	///
	/// Multiplications are a dominating factor in the evaluation cost.
	pub n_muls: usize,
	/// Number of squares.
	///
	/// Squares are a relatively cheap operation. Generally we assume that a squaring costs about
	/// one fifth of a single multiplication.
	pub n_squares: usize,
}

impl EvalCost {
	/// Returns a cost consisting of `n_adds` additions and nothing else.
	pub fn add(n_adds: usize) -> Self {
		Self {
			n_adds,
			..Self::default()
		}
	}

	/// Returns a cost consisting of `n_muls` multiplications and nothing else.
	pub fn mul(n_muls: usize) -> Self {
		Self {
			n_muls,
			..Self::default()
		}
	}

	/// Returns a cost consisting of `n_squares` squarings and nothing else.
	pub fn square(n_squares: usize) -> Self {
		Self {
			n_squares,
			..Self::default()
		}
	}

	/// Returns an approximate cost to evaluate an expression as a single number.
	///
	/// Since the evaluation time is dominated by the number of multiplications we can roughly
	/// gauge the cost by the number of the multiplications. Squares are counted as one fifth of
	/// a multiplication, rounded up (e.g. 1 to 5 squares count as one multiplication). Additions
	/// are ignored.
	pub fn mult_cost_approx(&self) -> usize {
		self.n_muls + self.n_squares.div_ceil(5)
	}
}

impl Add for EvalCost {
	type Output = Self;

	/// Component-wise sum of two operation counts.
	fn add(self, other: Self) -> Self {
		Self {
			n_adds: self.n_adds + other.n_adds,
			n_muls: self.n_muls + other.n_muls,
			n_squares: self.n_squares + other.n_squares,
		}
	}
}
1153
#[cfg(test)]
mod tests {
	use std::collections::HashSet;

	use assert_matches::assert_matches;
	use binius_field::{BinaryField, BinaryField128b, BinaryField1b, BinaryField8b};
	use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes};

	use super::*;

	#[test]
	fn test_degree_with_pow() {
		let expr = ArithCircuit::constant(BinaryField8b::new(6)).pow(7);
		assert_eq!(expr.degree(), 0);

		let expr: ArithCircuit<BinaryField8b> = ArithCircuit::var(0).pow(7);
		assert_eq!(expr.degree(), 7);

		let expr: ArithCircuit<BinaryField8b> =
			(ArithCircuit::var(0) * ArithCircuit::var(1)).pow(7);
		assert_eq!(expr.degree(), 14);
	}

	#[test]
	fn test_n_vars() {
		type F = BinaryField8b;
		let expr = ArithCircuit::<F>::var(0) * ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR)
			+ ArithCircuit::var(2).pow(2);
		assert_eq!(expr.n_vars(), 3);
	}

	#[test]
	fn test_leading_term_with_degree() {
		let expr = ArithCircuit::var(0)
			* (ArithCircuit::var(1)
				* ArithCircuit::var(2)
				* ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR)
				+ ArithCircuit::var(4))
			+ ArithCircuit::var(5).pow(3)
			+ ArithCircuit::constant(BinaryField8b::ONE);

		let expected_expr = ArithCircuit::var(0)
			* ((ArithCircuit::var(1) * ArithCircuit::var(2))
				* ArithCircuit::constant(BinaryField8b::MULTIPLICATIVE_GENERATOR))
			+ ArithCircuit::var(5).pow(3);

		assert_eq!(expr.leading_term_with_degree(expr.steps().len() - 1), (3, expected_expr));
	}

	#[test]
	fn test_remap_vars_with_too_few_vars() {
		type F = BinaryField8b;
		let expr =
			((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
		assert_matches!(expr.remap_vars(&[5]), Err(Error::IncorrectArgumentLength { .. }));
	}

	#[test]
	fn test_remap_vars_works() {
		type F = BinaryField8b;
		let expr =
			((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(1)).pow(3);
		let new_expr = expr.remap_vars(&[5, 3]);

		let expected =
			((ArithCircuit::var(5) + ArithCircuit::constant(F::ONE)) * ArithCircuit::var(3)).pow(3);
		assert_eq!(new_expr.unwrap(), expected);
	}

	#[test]
	fn test_optimize_identity_handling() {
		type F = BinaryField8b;
		let zero = ArithCircuit::<F>::zero();
		let one = ArithCircuit::<F>::one();

		assert_eq!((zero.clone() * ArithCircuit::<F>::var(0)).optimize(), zero);
		assert_eq!((ArithCircuit::<F>::var(0) * zero.clone()).optimize(), zero);

		assert_eq!((ArithCircuit::<F>::var(0) * one.clone()).optimize(), ArithCircuit::var(0));
		assert_eq!((one * ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));

		assert_eq!((ArithCircuit::<F>::var(0) + zero.clone()).optimize(), ArithCircuit::var(0));
		assert_eq!((zero.clone() + ArithCircuit::<F>::var(0)).optimize(), ArithCircuit::var(0));

		// In characteristic 2, x + x == 0.
		assert_eq!((ArithCircuit::<F>::var(0) + ArithCircuit::var(0)).optimize(), zero);
	}

	#[test]
	fn test_const_subst_and_optimize() {
		// NB: this is FlushSumcheckComposition from the constraint_system
		type F = BinaryField8b;
		let expr = ArithCircuit::var(0) * ArithCircuit::var(1) + ArithCircuit::one()
			- ArithCircuit::var(1);
		assert_eq!(expr.const_subst(1, F::ZERO).optimize().get_constant(), Some(F::ONE));
	}

	#[test]
	fn test_expression_upcast() {
		type F8 = BinaryField8b;
		type F = BinaryField128b;

		let expr = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
			* ArithCircuit::constant(F8::new(222)))
		.pow(3);

		let expected = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			* ArithCircuit::constant(F::new(222)))
		.pow(3);
		assert_eq!(expr.convert_field::<F>(), expected);
	}

	#[test]
	fn test_expression_downcast() {
		type F8 = BinaryField8b;
		type F = BinaryField128b;

		let expr = ((ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			* ArithCircuit::constant(F::new(222)))
		.pow(3);

		assert!(expr.try_convert_field::<BinaryField1b>().is_err());

		let expected = ((ArithCircuit::var(0) + ArithCircuit::constant(F8::ONE))
			* ArithCircuit::constant(F8::new(222)))
		.pow(3);
		assert_eq!(expr.try_convert_field::<BinaryField8b>().unwrap(), expected);
	}

	#[test]
	fn test_linear_normal_form() {
		type F = BinaryField128b;
		struct Case {
			expr: ArithCircuit<F>,
			expected: LinearNormalForm<F>,
		}
		let cases = vec![
			Case {
				expr: ArithCircuit::constant(F::ONE),
				expected: LinearNormalForm {
					constant: F::ONE,
					var_coeffs: vec![],
				},
			},
			Case {
				expr: (ArithCircuit::constant(F::new(2)) * ArithCircuit::constant(F::new(3)))
					.pow(2) + ArithCircuit::constant(F::new(3))
					* (ArithCircuit::constant(F::new(4)) + ArithCircuit::var(0)),
				expected: LinearNormalForm {
					constant: (F::new(2) * F::new(3)).pow(2) + F::new(3) * F::new(4),
					var_coeffs: vec![F::new(3)],
				},
			},
			Case {
				expr: ArithCircuit::constant(F::new(133))
					+ ArithCircuit::constant(F::new(42)) * ArithCircuit::var(0)
					+ ArithCircuit::var(2)
					+ ArithCircuit::constant(F::new(11))
						* ArithCircuit::constant(F::new(37))
						* ArithCircuit::var(3),
				expected: LinearNormalForm {
					constant: F::new(133),
					var_coeffs: vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)],
				},
			},
		];
		for Case { expr, expected } in cases {
			let normal_form = expr.linear_normal_form().unwrap();
			assert_eq!(normal_form.constant, expected.constant);
			assert_eq!(normal_form.var_coeffs, expected.var_coeffs);
		}
	}

	/// Counts the structurally distinct sub-expressions among all steps of `expr`.
	fn unique_nodes_count<F: Field>(expr: &ArithCircuit<F>) -> usize {
		let mut unique_nodes = HashSet::new();

		for step in 0..expr.steps.len() {
			unique_nodes.insert(StepNode {
				index: step,
				steps: &expr.steps,
			});
		}

		unique_nodes.len()
	}

	/// Serializes and deserializes `expr`, checking that both the value and the amount of
	/// structural sharing survive the roundtrip.
	fn check_serialize_bytes_roundtrip<F: Field>(expr: ArithCircuit<F>) {
		let mut buf = Vec::new();

		expr.serialize(&mut buf, SerializationMode::CanonicalTower)
			.unwrap();
		let deserialized =
			ArithCircuit::<F>::deserialize(&buf[..], SerializationMode::CanonicalTower).unwrap();
		assert_eq!(expr, deserialized);
		assert_eq!(unique_nodes_count(&expr), unique_nodes_count(&deserialized));
	}

	#[test]
	fn test_serialize_bytes_roundtrip() {
		type F = BinaryField128b;
		let expr = ArithCircuit::var(0)
			* (ArithCircuit::var(1)
				* ArithCircuit::var(2)
				* ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR)
				+ ArithCircuit::var(4))
			+ ArithCircuit::var(5).pow(3)
			+ ArithCircuit::constant(F::ONE);

		check_serialize_bytes_roundtrip(expr);
	}

	#[test]
	fn test_serialize_bytes_roundtrip_with_duplicates() {
		type F = BinaryField128b;
		let expr = (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			* (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			+ (ArithCircuit::var(0) + ArithCircuit::constant(F::ONE))
			+ ArithCircuit::var(1);

		check_serialize_bytes_roundtrip(expr);
	}

	#[test]
	fn test_binary_tower_level() {
		type F = BinaryField128b;
		let expr =
			ArithCircuit::constant(F::ONE) + ArithCircuit::constant(F::MULTIPLICATIVE_GENERATOR);
		assert_eq!(expr.binary_tower_level(), F::MULTIPLICATIVE_GENERATOR.min_tower_level());
	}

	#[test]
	fn test_arith_circuit_steps() {
		type F = BinaryField8b;
		let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1)) * ArithCircuit::var(2);
		let steps = expr.steps();
		assert_eq!(steps.len(), 5); // 3 variables, 1 addition, 1 multiplication
		assert_matches!(steps[0], ArithCircuitStep::Var(0));
		assert_matches!(steps[1], ArithCircuitStep::Var(1));
		assert_matches!(steps[2], ArithCircuitStep::Add(_, _));
		assert_matches!(steps[3], ArithCircuitStep::Var(2));
		assert_matches!(steps[4], ArithCircuitStep::Mul(_, _));
	}

	#[test]
	fn test_optimize_constants() {
		type F = BinaryField8b;
		let mut circuit = (ArithCircuit::<F>::var(0) + ArithCircuit::constant(F::ZERO))
			* ArithCircuit::var(1)
			+ ArithCircuit::constant(F::ONE) * ArithCircuit::var(2)
			+ ArithCircuit::constant(F::ONE).pow(4).pow(5)
			+ (ArithCircuit::var(5) + ArithCircuit::var(5));
		circuit.optimize_constants();

		let expected_circuit = ArithCircuit::var(0) * ArithCircuit::var(1)
			+ ArithCircuit::var(2)
			+ ArithCircuit::constant(F::ONE);

		assert_eq!(circuit, expected_circuit);
	}

	#[test]
	fn test_deduplicate_steps() {
		type F = BinaryField8b;
		let mut circuit = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1))
			* (ArithCircuit::var(0) + ArithCircuit::var(1))
			+ (ArithCircuit::var(0) + ArithCircuit::var(1));
		circuit.deduplicate_steps();

		let expected_circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Mul(2, 2),
				ArithCircuitStep::Add(3, 2),
			],
		};
		assert_eq!(circuit, expected_circuit);
	}

	#[test]
	fn test_compress_unused_steps() {
		type F = BinaryField8b;
		let mut circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Var(2),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Var(3),
				ArithCircuitStep::Const(F::ZERO),
				ArithCircuitStep::Var(2),
				ArithCircuitStep::Mul(3, 3),
			],
		};
		circuit.compress_unused_steps();

		let expected_circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Mul(2, 2),
			],
		};
		assert_eq!(circuit.steps, expected_circuit.steps);
	}

	#[test]
	fn test_conversion_from_expr_node_doesnt_create_duplicated_steps() {
		type F = BinaryField8b;
		let sub_expr = Arc::new(ArithExpr::<F>::Var(0) + ArithExpr::<F>::Var(1));
		let expr = ArithExpr::Mul(sub_expr.clone(), sub_expr.clone()) + sub_expr;
		let circuit = ArithCircuit::<F>::from(&expr);
		assert_eq!(circuit.steps.len(), 5);
		assert_eq!(unique_nodes_count(&circuit), 5);
	}

	/// Counts the distinct `Arc` allocations reachable from `expr` (including a fresh root),
	/// i.e. how many nodes the expression actually shares rather than duplicates.
	fn unique_nodes_count_expr<F: Field>(expr: &ArithExpr<F>) -> usize {
		fn add_nodes<F: Field>(
			expr: &Arc<ArithExpr<F>>,
			unique_nodes: &mut HashSet<*const ArithExpr<F>>,
		) {
			if !unique_nodes.insert(Arc::as_ptr(expr)) {
				return;
			}

			match expr.as_ref() {
				ArithExpr::Const(_) | ArithExpr::Var(_) => {}
				ArithExpr::Add(left, right) | ArithExpr::Mul(left, right) => {
					add_nodes(left, unique_nodes);
					add_nodes(right, unique_nodes);
				}
				ArithExpr::Pow(base, _) => {
					add_nodes(base, unique_nodes);
				}
			}
		}

		let mut unique_nodes = HashSet::new();
		add_nodes(&Arc::new(expr.clone()), &mut unique_nodes);

		unique_nodes.len()
	}

	#[test]
	fn test_conversion_from_circuit_to_expr_node() {
		type F = BinaryField8b;

		let arith_circuit = ArithCircuit::<F> {
			steps: vec![
				ArithCircuitStep::Var(0),
				ArithCircuitStep::Var(1),
				ArithCircuitStep::Add(0, 1),
				ArithCircuitStep::Mul(2, 2),
				ArithCircuitStep::Add(3, 2),
			],
		};
		let expr = ArithExpr::from(&arith_circuit);
		let expected_expr = (ArithExpr::Var(0) + ArithExpr::Var(1))
			* (ArithExpr::Var(0) + ArithExpr::Var(1))
			+ (ArithExpr::Var(0) + ArithExpr::Var(1));
		assert_eq!(expr, expected_expr);
		assert_eq!(unique_nodes_count_expr(&expr), 5);
	}

	#[test]
	fn test_evaluate() {
		type F = BinaryField8b;
		let expr = (ArithCircuit::<F>::var(0) + ArithCircuit::var(1))
			* (ArithCircuit::var(2) + ArithCircuit::var(3)).pow(5);
		// Inputs are chosen so that neither sum collapses to ONE in characteristic 2
		// (e.g. 2 + 3 == 4 + 5 == 1), which would mask operator-precedence mistakes in the
		// expected value.
		let [a, b, c, d] = [F::new(2), F::new(7), F::new(4), F::new(9)];
		let result = expr.evaluate(&[a, b, c, d]).unwrap();
		assert_eq!(result, (a + b) * (c + d).pow(5));
	}
}