binius_math/
arith_expr.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	cmp::Ordering,
5	fmt::{self, Display},
6	iter::{Product, Sum},
7	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
8};
9
10use binius_field::{Field, PackedField, TowerField};
11use binius_macros::{DeserializeBytes, SerializeBytes};
12
13use super::error::Error;
14
15/// Arithmetic expressions that can be evaluated symbolically.
16///
17/// Arithmetic expressions are trees, where the leaves are either constants or variables, and the
18/// non-leaf nodes are arithmetic operations, such as addition, multiplication, etc. They are
19/// specific representations of multivariate polynomials.
20#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
21pub enum ArithExpr<F: Field> {
22	Const(F),
23	Var(usize),
24	Add(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
25	Mul(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
26	Pow(Box<ArithExpr<F>>, u64),
27}
28
29impl<F: Field + Display> Display for ArithExpr<F> {
30	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31		match self {
32			Self::Const(v) => write!(f, "{v}"),
33			Self::Var(i) => write!(f, "x{i}"),
34			Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
35			Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
36			Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
37		}
38	}
39}
40
41impl<F: Field> ArithExpr<F> {
42	/// The number of variables the expression contains.
43	pub fn n_vars(&self) -> usize {
44		match self {
45			Self::Const(_) => 0,
46			Self::Var(index) => *index + 1,
47			Self::Add(left, right) | Self::Mul(left, right) => left.n_vars().max(right.n_vars()),
48			Self::Pow(id, _) => id.n_vars(),
49		}
50	}
51
52	/// The total degree of the polynomial the expression represents.
53	pub fn degree(&self) -> usize {
54		match self {
55			Self::Const(_) => 0,
56			Self::Var(_) => 1,
57			Self::Add(left, right) => left.degree().max(right.degree()),
58			Self::Mul(left, right) => left.degree() + right.degree(),
59			Self::Pow(base, exp) => base.degree() * *exp as usize,
60		}
61	}
62
63	/// Return a new arithmetic expression that contains only the terms of highest degree
64	/// (useful for interpolation at Karatsuba infinity point).
65	pub fn leading_term(&self) -> Self {
66		let (_, expr) = self.leading_term_with_degree();
67		expr
68	}
69
70	/// Same as `leading_term`, but returns the total degree as the first tuple element as well.
71	pub fn leading_term_with_degree(&self) -> (usize, Self) {
72		match self {
73			expr @ Self::Const(_) => (0, expr.clone()),
74			expr @ Self::Var(_) => (1, expr.clone()),
75			Self::Add(left, right) => {
76				let (lhs_degree, lhs) = left.leading_term_with_degree();
77				let (rhs_degree, rhs) = right.leading_term_with_degree();
78				match lhs_degree.cmp(&rhs_degree) {
79					Ordering::Less => (rhs_degree, rhs),
80					Ordering::Equal => (lhs_degree, Self::Add(Box::new(lhs), Box::new(rhs))),
81					Ordering::Greater => (lhs_degree, lhs),
82				}
83			}
84			Self::Mul(left, right) => {
85				let (lhs_degree, lhs) = left.leading_term_with_degree();
86				let (rhs_degree, rhs) = right.leading_term_with_degree();
87				(lhs_degree + rhs_degree, Self::Mul(Box::new(lhs), Box::new(rhs)))
88			}
89			Self::Pow(base, exp) => {
90				let (base_degree, base) = base.leading_term_with_degree();
91				(base_degree * *exp as usize, Self::Pow(Box::new(base), *exp))
92			}
93		}
94	}
95
96	pub fn pow(self, exp: u64) -> Self {
97		Self::Pow(Box::new(self), exp)
98	}
99
100	pub const fn zero() -> Self {
101		Self::Const(F::ZERO)
102	}
103
104	pub const fn one() -> Self {
105		Self::Const(F::ONE)
106	}
107
108	/// Creates a new expression with the variable indices remapped.
109	///
110	/// This recursively replaces the variable sub-expressions with an index `i` with the variable
111	/// `indices[i]`.
112	///
113	/// ## Throws
114	///
115	/// * [`Error::IncorrectArgumentLength`] if indices has length less than the current number of
116	///   variables
117	pub fn remap_vars(self, indices: &[usize]) -> Result<Self, Error> {
118		let expr = match self {
119			Self::Const(_) => self,
120			Self::Var(index) => {
121				let new_index =
122					indices
123						.get(index)
124						.ok_or_else(|| Error::IncorrectArgumentLength {
125							arg: "subset".to_string(),
126							expected: index,
127						})?;
128				Self::Var(*new_index)
129			}
130			Self::Add(left, right) => {
131				let new_left = left.remap_vars(indices)?;
132				let new_right = right.remap_vars(indices)?;
133				Self::Add(Box::new(new_left), Box::new(new_right))
134			}
135			Self::Mul(left, right) => {
136				let new_left = left.remap_vars(indices)?;
137				let new_right = right.remap_vars(indices)?;
138				Self::Mul(Box::new(new_left), Box::new(new_right))
139			}
140			Self::Pow(base, exp) => {
141				let new_base = base.remap_vars(indices)?;
142				Self::Pow(Box::new(new_base), exp)
143			}
144		};
145		Ok(expr)
146	}
147
148	pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithExpr<FTgt> {
149		match self {
150			Self::Const(val) => ArithExpr::Const((*val).into()),
151			Self::Var(index) => ArithExpr::Var(*index),
152			Self::Add(left, right) => {
153				let new_left = left.convert_field();
154				let new_right = right.convert_field();
155				ArithExpr::Add(Box::new(new_left), Box::new(new_right))
156			}
157			Self::Mul(left, right) => {
158				let new_left = left.convert_field();
159				let new_right = right.convert_field();
160				ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
161			}
162			Self::Pow(base, exp) => {
163				let new_base = base.convert_field();
164				ArithExpr::Pow(Box::new(new_base), *exp)
165			}
166		}
167	}
168
169	pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
170		&self,
171	) -> Result<ArithExpr<FTgt>, <FTgt as TryFrom<F>>::Error> {
172		Ok(match self {
173			Self::Const(val) => ArithExpr::Const(FTgt::try_from(*val)?),
174			Self::Var(index) => ArithExpr::Var(*index),
175			Self::Add(left, right) => {
176				let new_left = left.try_convert_field()?;
177				let new_right = right.try_convert_field()?;
178				ArithExpr::Add(Box::new(new_left), Box::new(new_right))
179			}
180			Self::Mul(left, right) => {
181				let new_left = left.try_convert_field()?;
182				let new_right = right.try_convert_field()?;
183				ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
184			}
185			Self::Pow(base, exp) => {
186				let new_base = base.try_convert_field()?;
187				ArithExpr::Pow(Box::new(new_base), *exp)
188			}
189		})
190	}
191
192	/// Whether expression is a composite node, and not a leaf.
193	pub const fn is_composite(&self) -> bool {
194		match self {
195			Self::Const(_) | Self::Var(_) => false,
196			Self::Add(_, _) | Self::Mul(_, _) | Self::Pow(_, _) => true,
197		}
198	}
199
200	/// Creates a new optimized expression.
201	///
202	/// Recursively rewrites expression for better evaluation performance.
203	pub fn optimize(&self) -> Self {
204		match self {
205			Self::Const(_) | Self::Var(_) => self.clone(),
206			Self::Add(left, right) => {
207				let left = left.optimize();
208				let right = right.optimize();
209				match (left, right) {
210					(Self::Const(left), Self::Const(right)) => Self::Const(left + right),
211					(left, right) => Self::Add(Box::new(left), Box::new(right)),
212				}
213			}
214			Self::Mul(left, right) => {
215				let left = left.optimize();
216				let right = right.optimize();
217				match (left, right) {
218					(Self::Const(left), Self::Const(right)) => Self::Const(left * right),
219					(left, right) => Self::Mul(Box::new(left), Box::new(right)),
220				}
221			}
222			Self::Pow(id, exp) => {
223				let id = id.optimize();
224				match id {
225					Self::Const(value) => Self::Const(PackedField::pow(value, *exp)),
226					Self::Pow(id_inner, exp_inner) => Self::Pow(id_inner, *exp * exp_inner),
227					id => Self::Pow(Box::new(id), *exp),
228				}
229			}
230		}
231	}
232
233	/// Returns the normal form of an expression if it is linear.
234	///
235	/// ## Throws
236	///
237	/// - [`Error::NonLinearExpression`] if the expression is not linear.
238	pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
239		if self.degree() > 1 {
240			return Err(Error::NonLinearExpression);
241		}
242
243		let mut normal_form = LinearNormalForm::default();
244		let n_vars = self.n_vars();
245
246		// Linear normal form: f(x0, x1, ... x{n-1}) = c + a0*x0 + a1*x1 + ... + a{n-1}*x{n-1}
247		// Evaluating with all variables set to 0, should give the constant term
248		let constant = self.evaluate(&vec![F::ZERO; n_vars]);
249
250		// Evaluating with x{k} set to 1 and all other x{i} set to 0, gives us `constant + a{k}`
251		// That means we can subtract the constant from the evaluated expression to get the coefficient a{k}
252		for i in 0..n_vars {
253			let mut vars = vec![F::ZERO; n_vars];
254			vars[i] = F::ONE;
255			normal_form.var_coeffs.push(self.evaluate(&vars) - constant);
256		}
257		Ok(normal_form)
258	}
259
260	fn evaluate(&self, vars: &[F]) -> F {
261		match self {
262			Self::Const(val) => *val,
263			Self::Var(index) => vars[*index],
264			Self::Add(left, right) => left.evaluate(vars) + right.evaluate(vars),
265			Self::Mul(left, right) => left.evaluate(vars) * right.evaluate(vars),
266			Self::Pow(base, exp) => base.evaluate(vars).pow(*exp),
267		}
268	}
269}
270
271impl<F: TowerField> ArithExpr<F> {
272	pub fn binary_tower_level(&self) -> usize {
273		match self {
274			Self::Const(value) => value.min_tower_level(),
275			Self::Var(_) => 0,
276			Self::Add(left, right) | Self::Mul(left, right) => {
277				left.binary_tower_level().max(right.binary_tower_level())
278			}
279			Self::Pow(base, _) => base.binary_tower_level(),
280		}
281	}
282}
283
284impl<F> Default for ArithExpr<F>
285where
286	F: Field,
287{
288	fn default() -> Self {
289		Self::zero()
290	}
291}
292
293impl<F> Add for ArithExpr<F>
294where
295	F: Field,
296{
297	type Output = Self;
298
299	fn add(self, rhs: Self) -> Self {
300		Self::Add(Box::new(self), Box::new(rhs))
301	}
302}
303
304impl<F> AddAssign for ArithExpr<F>
305where
306	F: Field,
307{
308	fn add_assign(&mut self, rhs: Self) {
309		*self = std::mem::take(self) + rhs;
310	}
311}
312
313impl<F> Sub for ArithExpr<F>
314where
315	F: Field,
316{
317	type Output = Self;
318
319	fn sub(self, rhs: Self) -> Self {
320		Self::Add(Box::new(self), Box::new(rhs))
321	}
322}
323
324impl<F> SubAssign for ArithExpr<F>
325where
326	F: Field,
327{
328	fn sub_assign(&mut self, rhs: Self) {
329		*self = std::mem::take(self) - rhs;
330	}
331}
332
333impl<F> Mul for ArithExpr<F>
334where
335	F: Field,
336{
337	type Output = Self;
338
339	fn mul(self, rhs: Self) -> Self {
340		Self::Mul(Box::new(self), Box::new(rhs))
341	}
342}
343
344impl<F> MulAssign for ArithExpr<F>
345where
346	F: Field,
347{
348	fn mul_assign(&mut self, rhs: Self) {
349		*self = std::mem::take(self) * rhs;
350	}
351}
352
353impl<F: Field> Sum for ArithExpr<F> {
354	fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
355		iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
356	}
357}
358
359impl<F: Field> Product for ArithExpr<F> {
360	fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
361		iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
362	}
363}
364
365/// A normal form for a linear expression.
366#[derive(Debug, Default, Clone, PartialEq, Eq)]
367pub struct LinearNormalForm<F: Field> {
368	/// The constant offset of the expression.
369	pub constant: F,
370	/// A vector mapping variable indices to their coefficients.
371	pub var_coeffs: Vec<F>,
372}
373
374#[cfg(test)]
375mod tests {
376	use assert_matches::assert_matches;
377	use binius_field::{BinaryField, BinaryField128b, BinaryField1b, BinaryField8b};
378
379	use super::*;
380
381	#[test]
382	fn test_degree_with_pow() {
383		let expr = ArithExpr::Const(BinaryField8b::new(6)).pow(7);
384		assert_eq!(expr.degree(), 0);
385
386		let expr: ArithExpr<BinaryField8b> = ArithExpr::Var(0).pow(7);
387		assert_eq!(expr.degree(), 7);
388
389		let expr: ArithExpr<BinaryField8b> = (ArithExpr::Var(0) * ArithExpr::Var(1)).pow(7);
390		assert_eq!(expr.degree(), 14);
391	}
392
393	#[test]
394	fn test_leading_term_with_degree() {
395		let expr = ArithExpr::Var(0)
396			* (ArithExpr::Var(1)
397				* ArithExpr::Var(2)
398				* ArithExpr::Const(BinaryField8b::MULTIPLICATIVE_GENERATOR)
399				+ ArithExpr::Var(4))
400			+ ArithExpr::Var(5).pow(3)
401			+ ArithExpr::Const(BinaryField8b::ONE);
402
403		let expected_expr = ArithExpr::Var(0)
404			* ((ArithExpr::Var(1) * ArithExpr::Var(2))
405				* ArithExpr::Const(BinaryField8b::MULTIPLICATIVE_GENERATOR))
406			+ ArithExpr::Var(5).pow(3);
407
408		assert_eq!(expr.leading_term_with_degree(), (3, expected_expr));
409	}
410
411	#[test]
412	fn test_remap_vars_with_too_few_vars() {
413		type F = BinaryField8b;
414		let expr = ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(1)).pow(3);
415		assert_matches!(expr.remap_vars(&[5]), Err(Error::IncorrectArgumentLength { .. }));
416	}
417
418	#[test]
419	fn test_remap_vars_works() {
420		type F = BinaryField8b;
421		let expr = ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(1)).pow(3);
422		let new_expr = expr.remap_vars(&[5, 3]);
423
424		let expected = ((ArithExpr::Var(5) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(3)).pow(3);
425		assert_eq!(new_expr.unwrap(), expected);
426	}
427
428	#[test]
429	fn test_expression_upcast() {
430		type F8 = BinaryField8b;
431		type F = BinaryField128b;
432
433		let expr = ((ArithExpr::Var(0) + ArithExpr::Const(F8::ONE))
434			* ArithExpr::Const(F8::new(222)))
435		.pow(3);
436
437		let expected =
438			((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Const(F::new(222))).pow(3);
439		assert_eq!(expr.convert_field::<F>(), expected);
440	}
441
442	#[test]
443	fn test_expression_downcast() {
444		type F8 = BinaryField8b;
445		type F = BinaryField128b;
446
447		let expr =
448			((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Const(F::new(222))).pow(3);
449
450		assert!(expr.try_convert_field::<BinaryField1b>().is_err());
451
452		let expected = ((ArithExpr::Var(0) + ArithExpr::Const(F8::ONE))
453			* ArithExpr::Const(F8::new(222)))
454		.pow(3);
455		assert_eq!(expr.try_convert_field::<BinaryField8b>().unwrap(), expected);
456	}
457
458	#[test]
459	fn test_linear_normal_form() {
460		type F = BinaryField128b;
461		use ArithExpr::{Const, Var};
462		let expr = Const(F::new(133))
463			+ Const(F::new(42)) * Var(0)
464			+ Var(2) + Const(F::new(11)) * Const(F::new(37)) * Var(3);
465		let normal_form = expr.linear_normal_form().unwrap();
466		assert_eq!(normal_form.constant, F::ZERO);
467		assert_eq!(
468			normal_form.var_coeffs,
469			vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)]
470		);
471	}
472}