binius_math/
arith_expr.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	cmp::Ordering,
5	fmt::{self, Display},
6	iter::{Product, Sum},
7	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
8};
9
10use binius_field::{Field, PackedField, TowerField};
11use binius_macros::{DeserializeBytes, SerializeBytes};
12
13use super::error::Error;
14
/// Arithmetic expressions that can be evaluated symbolically.
///
/// Arithmetic expressions are trees, where the leaves are either constants or variables, and the
/// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are
/// specific representations of multivariate polynomials.
///
/// Equality (`PartialEq`) is derived and therefore purely structural: two trees denoting the same
/// polynomial (e.g. `x0 + x1` and `x1 + x0`) compare unequal.
///
/// NOTE(review): the variant order presumably determines the encoding produced by the
/// `SerializeBytes`/`DeserializeBytes` derives — confirm before reordering or inserting variants.
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
pub enum ArithExpr<F: Field> {
	/// A constant field element.
	Const(F),
	/// The variable with the given index (rendered as `x{index}` by `Display`).
	Var(usize),
	/// The sum of two sub-expressions.
	Add(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
	/// The product of two sub-expressions.
	Mul(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
	/// A sub-expression raised to a non-negative integer power.
	Pow(Box<ArithExpr<F>>, u64),
}
28
29impl<F: Field + Display> Display for ArithExpr<F> {
30	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31		match self {
32			Self::Const(v) => write!(f, "{v}"),
33			Self::Var(i) => write!(f, "x{i}"),
34			Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
35			Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
36			Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
37		}
38	}
39}
40
41impl<F: Field> ArithExpr<F> {
42	/// The number of variables the expression contains.
43	pub fn n_vars(&self) -> usize {
44		match self {
45			Self::Const(_) => 0,
46			Self::Var(index) => *index + 1,
47			Self::Add(left, right) | Self::Mul(left, right) => left.n_vars().max(right.n_vars()),
48			Self::Pow(id, _) => id.n_vars(),
49		}
50	}
51
52	/// The total degree of the polynomial the expression represents.
53	pub fn degree(&self) -> usize {
54		match self {
55			Self::Const(_) => 0,
56			Self::Var(_) => 1,
57			Self::Add(left, right) => left.degree().max(right.degree()),
58			Self::Mul(left, right) => left.degree() + right.degree(),
59			Self::Pow(base, exp) => base.degree() * *exp as usize,
60		}
61	}
62
63	/// Return a new arithmetic expression that contains only the terms of highest degree
64	/// (useful for interpolation at Karatsuba infinity point).
65	pub fn leading_term(&self) -> Self {
66		let (_, expr) = self.leading_term_with_degree();
67		expr
68	}
69
70	/// Same as `leading_term`, but returns the total degree as the first tuple element as well.
71	pub fn leading_term_with_degree(&self) -> (usize, Self) {
72		match self {
73			expr @ Self::Const(_) => (0, expr.clone()),
74			expr @ Self::Var(_) => (1, expr.clone()),
75			Self::Add(left, right) => {
76				let (lhs_degree, lhs) = left.leading_term_with_degree();
77				let (rhs_degree, rhs) = right.leading_term_with_degree();
78				match lhs_degree.cmp(&rhs_degree) {
79					Ordering::Less => (rhs_degree, rhs),
80					Ordering::Equal => (lhs_degree, Self::Add(Box::new(lhs), Box::new(rhs))),
81					Ordering::Greater => (lhs_degree, lhs),
82				}
83			}
84			Self::Mul(left, right) => {
85				let (lhs_degree, lhs) = left.leading_term_with_degree();
86				let (rhs_degree, rhs) = right.leading_term_with_degree();
87				(lhs_degree + rhs_degree, Self::Mul(Box::new(lhs), Box::new(rhs)))
88			}
89			Self::Pow(base, exp) => {
90				let (base_degree, base) = base.leading_term_with_degree();
91				(base_degree * *exp as usize, Self::Pow(Box::new(base), *exp))
92			}
93		}
94	}
95
96	pub fn pow(self, exp: u64) -> Self {
97		Self::Pow(Box::new(self), exp)
98	}
99
100	pub const fn zero() -> Self {
101		Self::Const(F::ZERO)
102	}
103
104	pub const fn one() -> Self {
105		Self::Const(F::ONE)
106	}
107
108	/// Creates a new expression with the variable indices remapped.
109	///
110	/// This recursively replaces the variable sub-expressions with an index `i` with the variable
111	/// `indices[i]`.
112	///
113	/// ## Throws
114	///
115	/// * [`Error::IncorrectArgumentLength`] if indices has length less than the current number of
116	///   variables
117	pub fn remap_vars(self, indices: &[usize]) -> Result<Self, Error> {
118		let expr = match self {
119			Self::Const(_) => self,
120			Self::Var(index) => {
121				let new_index =
122					indices
123						.get(index)
124						.ok_or_else(|| Error::IncorrectArgumentLength {
125							arg: "subset".to_string(),
126							expected: index,
127						})?;
128				Self::Var(*new_index)
129			}
130			Self::Add(left, right) => {
131				let new_left = left.remap_vars(indices)?;
132				let new_right = right.remap_vars(indices)?;
133				Self::Add(Box::new(new_left), Box::new(new_right))
134			}
135			Self::Mul(left, right) => {
136				let new_left = left.remap_vars(indices)?;
137				let new_right = right.remap_vars(indices)?;
138				Self::Mul(Box::new(new_left), Box::new(new_right))
139			}
140			Self::Pow(base, exp) => {
141				let new_base = base.remap_vars(indices)?;
142				Self::Pow(Box::new(new_base), exp)
143			}
144		};
145		Ok(expr)
146	}
147
148	/// Substitute variable with index `var` with a constant `value`
149	pub fn const_subst(self, var: usize, value: F) -> Self {
150		match self {
151			Self::Const(_) => self,
152			Self::Var(index) => {
153				if index == var {
154					Self::Const(value)
155				} else {
156					self
157				}
158			}
159			Self::Add(left, right) => {
160				let new_left = left.const_subst(var, value);
161				let new_right = right.const_subst(var, value);
162				Self::Add(Box::new(new_left), Box::new(new_right))
163			}
164			Self::Mul(left, right) => {
165				let new_left = left.const_subst(var, value);
166				let new_right = right.const_subst(var, value);
167				Self::Mul(Box::new(new_left), Box::new(new_right))
168			}
169			Self::Pow(base, exp) => {
170				let new_base = base.const_subst(var, value);
171				Self::Pow(Box::new(new_base), exp)
172			}
173		}
174	}
175
176	pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithExpr<FTgt> {
177		match self {
178			Self::Const(val) => ArithExpr::Const((*val).into()),
179			Self::Var(index) => ArithExpr::Var(*index),
180			Self::Add(left, right) => {
181				let new_left = left.convert_field();
182				let new_right = right.convert_field();
183				ArithExpr::Add(Box::new(new_left), Box::new(new_right))
184			}
185			Self::Mul(left, right) => {
186				let new_left = left.convert_field();
187				let new_right = right.convert_field();
188				ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
189			}
190			Self::Pow(base, exp) => {
191				let new_base = base.convert_field();
192				ArithExpr::Pow(Box::new(new_base), *exp)
193			}
194		}
195	}
196
197	pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
198		&self,
199	) -> Result<ArithExpr<FTgt>, <FTgt as TryFrom<F>>::Error> {
200		Ok(match self {
201			Self::Const(val) => ArithExpr::Const(FTgt::try_from(*val)?),
202			Self::Var(index) => ArithExpr::Var(*index),
203			Self::Add(left, right) => {
204				let new_left = left.try_convert_field()?;
205				let new_right = right.try_convert_field()?;
206				ArithExpr::Add(Box::new(new_left), Box::new(new_right))
207			}
208			Self::Mul(left, right) => {
209				let new_left = left.try_convert_field()?;
210				let new_right = right.try_convert_field()?;
211				ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
212			}
213			Self::Pow(base, exp) => {
214				let new_base = base.try_convert_field()?;
215				ArithExpr::Pow(Box::new(new_base), *exp)
216			}
217		})
218	}
219
220	/// Whether expression is a composite node, and not a leaf.
221	pub const fn is_composite(&self) -> bool {
222		match self {
223			Self::Const(_) | Self::Var(_) => false,
224			Self::Add(_, _) | Self::Mul(_, _) | Self::Pow(_, _) => true,
225		}
226	}
227
228	/// Returns `Some(F)` if the expression is a constant.
229	pub const fn constant(&self) -> Option<F> {
230		match self {
231			Self::Const(value) => Some(*value),
232			_ => None,
233		}
234	}
235
236	/// Creates a new optimized expression.
237	///
238	/// Recursively rewrites the expression for better evaluation performance. Performs constant folding,
239	/// as well as leverages simple rewriting rules around additive/multiplicative identities and addition
240	/// in characteristic 2.
241	pub fn optimize(&self) -> Self {
242		match self {
243			Self::Const(_) | Self::Var(_) => self.clone(),
244			Self::Add(left, right) => {
245				let left = left.optimize();
246				let right = right.optimize();
247				match (left, right) {
248					// constant folding
249					(Self::Const(left), Self::Const(right)) => Self::Const(left + right),
250					// 0 + a = a + 0 = a
251					(Self::Const(left), right) if left == F::ZERO => right,
252					(left, Self::Const(right)) if right == F::ZERO => left,
253					// a + a = 0 in char 2
254					// REVIEW: relies on precise structural equality, find a better way
255					(left, right) if left == right && F::CHARACTERISTIC == 2 => {
256						Self::Const(F::ZERO)
257					}
258					// fallback
259					(left, right) => Self::Add(Box::new(left), Box::new(right)),
260				}
261			}
262			Self::Mul(left, right) => {
263				let left = left.optimize();
264				let right = right.optimize();
265				match (left, right) {
266					// constant folding
267					(Self::Const(left), Self::Const(right)) => Self::Const(left * right),
268					// 0 * a = a * 0 = 0
269					(left, right)
270						if left == Self::Const(F::ZERO) || right == Self::Const(F::ZERO) =>
271					{
272						Self::Const(F::ZERO)
273					}
274					// 1 * a = a * 1 = a
275					(Self::Const(left), right) if left == F::ONE => right,
276					(left, Self::Const(right)) if right == F::ONE => left,
277					// fallback
278					(left, right) => Self::Mul(Box::new(left), Box::new(right)),
279				}
280			}
281			Self::Pow(id, exp) => {
282				let id = id.optimize();
283				match id {
284					Self::Const(value) => Self::Const(PackedField::pow(value, *exp)),
285					Self::Pow(id_inner, exp_inner) => Self::Pow(id_inner, *exp * exp_inner),
286					id => Self::Pow(Box::new(id), *exp),
287				}
288			}
289		}
290	}
291
292	/// Returns the normal form of an expression if it is linear.
293	///
294	/// ## Throws
295	///
296	/// - [`Error::NonLinearExpression`] if the expression is not linear.
297	pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
298		if self.degree() > 1 {
299			return Err(Error::NonLinearExpression);
300		}
301
302		let n_vars = self.n_vars();
303
304		// Linear normal form: f(x0, x1, ... x{n-1}) = c + a0*x0 + a1*x1 + ... + a{n-1}*x{n-1}
305		// Evaluating with all variables set to 0, should give the constant term
306		let constant = self.evaluate(&vec![F::ZERO; n_vars]);
307
308		// Evaluating with x{k} set to 1 and all other x{i} set to 0, gives us `constant + a{k}`
309		// That means we can subtract the constant from the evaluated expression to get the coefficient a{k}
310		let var_coeffs = (0..n_vars)
311			.map(|i| {
312				let mut vars = vec![F::ZERO; n_vars];
313				vars[i] = F::ONE;
314				self.evaluate(&vars) - constant
315			})
316			.collect();
317		Ok(LinearNormalForm {
318			constant,
319			var_coeffs,
320		})
321	}
322
323	fn evaluate(&self, vars: &[F]) -> F {
324		match self {
325			Self::Const(val) => *val,
326			Self::Var(index) => vars[*index],
327			Self::Add(left, right) => left.evaluate(vars) + right.evaluate(vars),
328			Self::Mul(left, right) => left.evaluate(vars) * right.evaluate(vars),
329			Self::Pow(base, exp) => base.evaluate(vars).pow(*exp),
330		}
331	}
332
333	/// Returns a vector of booleans indicating which variables are used in the expression.
334	///
335	/// The vector is indexed by variable index, and the value at index `i` is `true` if and only
336	/// if the variable is used in the expression.
337	pub fn vars_usage(&self) -> Vec<bool> {
338		let mut usage = vec![false; self.n_vars()];
339		self.mark_vars_usage(&mut usage);
340		usage
341	}
342
343	fn mark_vars_usage(&self, usage: &mut [bool]) {
344		match self {
345			Self::Const(_) => (),
346			Self::Var(index) => usage[*index] = true,
347			Self::Add(left, right) | Self::Mul(left, right) => {
348				left.mark_vars_usage(usage);
349				right.mark_vars_usage(usage);
350			}
351			Self::Pow(base, _) => base.mark_vars_usage(usage),
352		}
353	}
354}
355
356impl<F: TowerField> ArithExpr<F> {
357	pub fn binary_tower_level(&self) -> usize {
358		match self {
359			Self::Const(value) => value.min_tower_level(),
360			Self::Var(_) => 0,
361			Self::Add(left, right) | Self::Mul(left, right) => {
362				left.binary_tower_level().max(right.binary_tower_level())
363			}
364			Self::Pow(base, _) => base.binary_tower_level(),
365		}
366	}
367}
368
369impl<F> Default for ArithExpr<F>
370where
371	F: Field,
372{
373	fn default() -> Self {
374		Self::zero()
375	}
376}
377
378impl<F> Add for ArithExpr<F>
379where
380	F: Field,
381{
382	type Output = Self;
383
384	fn add(self, rhs: Self) -> Self {
385		Self::Add(Box::new(self), Box::new(rhs))
386	}
387}
388
389impl<F> AddAssign for ArithExpr<F>
390where
391	F: Field,
392{
393	fn add_assign(&mut self, rhs: Self) {
394		*self = std::mem::take(self) + rhs;
395	}
396}
397
398impl<F> Sub for ArithExpr<F>
399where
400	F: Field,
401{
402	type Output = Self;
403
404	fn sub(self, rhs: Self) -> Self {
405		Self::Add(Box::new(self), Box::new(rhs))
406	}
407}
408
409impl<F> SubAssign for ArithExpr<F>
410where
411	F: Field,
412{
413	fn sub_assign(&mut self, rhs: Self) {
414		*self = std::mem::take(self) - rhs;
415	}
416}
417
418impl<F> Mul for ArithExpr<F>
419where
420	F: Field,
421{
422	type Output = Self;
423
424	fn mul(self, rhs: Self) -> Self {
425		Self::Mul(Box::new(self), Box::new(rhs))
426	}
427}
428
429impl<F> MulAssign for ArithExpr<F>
430where
431	F: Field,
432{
433	fn mul_assign(&mut self, rhs: Self) {
434		*self = std::mem::take(self) * rhs;
435	}
436}
437
438impl<F: Field> Sum for ArithExpr<F> {
439	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
440		iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
441	}
442}
443
444impl<F: Field> Product for ArithExpr<F> {
445	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
446		iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
447	}
448}
449
450/// A normal form for a linear expression.
451#[derive(Debug, Default, Clone, PartialEq, Eq)]
452pub struct LinearNormalForm<F: Field> {
453	/// The constant offset of the expression.
454	pub constant: F,
455	/// A vector mapping variable indices to their coefficients.
456	pub var_coeffs: Vec<F>,
457}
458
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use binius_field::{BinaryField, BinaryField128b, BinaryField1b, BinaryField8b};

	use super::*;

	#[test]
	fn test_degree_with_pow() {
		type F = BinaryField8b;

		// A constant stays degree 0 regardless of the exponent.
		let const_pow = ArithExpr::Const(F::new(6)).pow(7);
		assert_eq!(const_pow.degree(), 0);

		// Exponentiation multiplies the degree of the base.
		let var_pow = ArithExpr::<F>::Var(0).pow(7);
		assert_eq!(var_pow.degree(), 7);

		let product_pow = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1)).pow(7);
		assert_eq!(product_pow.degree(), 14);
	}

	#[test]
	fn test_leading_term_with_degree() {
		type F = BinaryField8b;
		let var = ArithExpr::<F>::Var;
		let generator = || ArithExpr::Const(F::MULTIPLICATIVE_GENERATOR);

		let expr = var(0) * (var(1) * var(2) * generator() + var(4))
			+ var(5).pow(3)
			+ ArithExpr::Const(F::ONE);

		// The constant and the degree-2 `x0 * x4` part drop out; both degree-3 terms survive.
		let expected = var(0) * (var(1) * var(2) * generator()) + var(5).pow(3);

		assert_eq!(expr.leading_term_with_degree(), (3, expected));
	}

	#[test]
	fn test_remap_vars_with_too_few_vars() {
		type F = BinaryField8b;
		let shifted_x0 = ArithExpr::Var(0) + ArithExpr::Const(F::ONE);
		let expr = (shifted_x0 * ArithExpr::Var(1)).pow(3);

		// Two variables are used, but only one index mapping is supplied.
		let result = expr.remap_vars(&[5]);
		assert_matches!(result, Err(Error::IncorrectArgumentLength { .. }));
	}

	#[test]
	fn test_remap_vars_works() {
		type F = BinaryField8b;
		let expr = ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(1)).pow(3);
		let remapped = expr.remap_vars(&[5, 3]).unwrap();

		let expected = ((ArithExpr::Var(5) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(3)).pow(3);
		assert_eq!(remapped, expected);
	}

	#[test]
	fn test_optimize_identity_handling() {
		type F = BinaryField8b;
		let zero = ArithExpr::<F>::zero();
		let one = ArithExpr::<F>::one();
		let x = || ArithExpr::<F>::Var(0);

		// Zero absorbs multiplication from either side.
		assert_eq!((zero.clone() * x()).optimize(), zero);
		assert_eq!((x() * zero.clone()).optimize(), zero);

		// One is the multiplicative identity.
		assert_eq!((x() * one.clone()).optimize(), x());
		assert_eq!((one * x()).optimize(), x());

		// Zero is the additive identity.
		assert_eq!((x() + zero.clone()).optimize(), x());
		assert_eq!((zero.clone() + x()).optimize(), x());

		// a + a = 0 in characteristic 2.
		assert_eq!((x() + x()).optimize(), zero);
	}

	#[test]
	fn test_const_subst_and_optimize() {
		// NB: this is FlushSumcheckComposition from the constraint_system
		type F = BinaryField8b;
		let flush_composition: ArithExpr<F> =
			ArithExpr::Var(0) * ArithExpr::Var(1) + ArithExpr::one() - ArithExpr::Var(1);

		let folded = flush_composition.const_subst(1, F::ZERO).optimize();
		assert_eq!(folded.constant(), Some(F::ONE));
	}

	#[test]
	fn test_expression_upcast() {
		type F8 = BinaryField8b;
		type F = BinaryField128b;

		let narrow_base = ArithExpr::Var(0) + ArithExpr::Const(F8::ONE);
		let narrow = (narrow_base * ArithExpr::Const(F8::new(222))).pow(3);

		let wide_base = ArithExpr::Var(0) + ArithExpr::Const(F::ONE);
		let wide = (wide_base * ArithExpr::Const(F::new(222))).pow(3);

		assert_eq!(narrow.convert_field::<F>(), wide);
	}

	#[test]
	fn test_expression_downcast() {
		type F8 = BinaryField8b;
		type F = BinaryField128b;

		let wide_base = ArithExpr::Var(0) + ArithExpr::Const(F::ONE);
		let wide = (wide_base * ArithExpr::Const(F::new(222))).pow(3);

		// The constant 222 lies outside the 1-bit field, so that conversion must fail.
		assert!(wide.try_convert_field::<BinaryField1b>().is_err());

		let narrow_base = ArithExpr::Var(0) + ArithExpr::Const(F8::ONE);
		let narrow = (narrow_base * ArithExpr::Const(F8::new(222))).pow(3);
		assert_eq!(wide.try_convert_field::<BinaryField8b>().unwrap(), narrow);
	}

	#[test]
	fn test_linear_normal_form() {
		type F = BinaryField128b;
		use ArithExpr::{Const, Var};

		let expr = Const(F::new(133))
			+ Const(F::new(42)) * Var(0)
			+ Var(2) + Const(F::new(11)) * Const(F::new(37)) * Var(3);

		let LinearNormalForm {
			constant,
			var_coeffs,
		} = expr.linear_normal_form().unwrap();

		assert_eq!(constant, F::new(133));
		// Variable 1 never appears, so its coefficient is zero.
		assert_eq!(var_coeffs, vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)]);
	}
}