1use std::{
4 cmp::Ordering,
5 fmt::{self, Display},
6 iter::{Product, Sum},
7 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
8};
9
10use binius_field::{Field, PackedField, TowerField};
11use binius_macros::{DeserializeBytes, SerializeBytes};
12
13use super::error::Error;
14
15#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
21pub enum ArithExpr<F: Field> {
22 Const(F),
23 Var(usize),
24 Add(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
25 Mul(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
26 Pow(Box<ArithExpr<F>>, u64),
27}
28
29impl<F: Field + Display> Display for ArithExpr<F> {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 Self::Const(v) => write!(f, "{v}"),
33 Self::Var(i) => write!(f, "x{i}"),
34 Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
35 Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
36 Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
37 }
38 }
39}
40
41impl<F: Field> ArithExpr<F> {
42 pub fn n_vars(&self) -> usize {
44 match self {
45 Self::Const(_) => 0,
46 Self::Var(index) => *index + 1,
47 Self::Add(left, right) | Self::Mul(left, right) => left.n_vars().max(right.n_vars()),
48 Self::Pow(id, _) => id.n_vars(),
49 }
50 }
51
52 pub fn degree(&self) -> usize {
54 match self {
55 Self::Const(_) => 0,
56 Self::Var(_) => 1,
57 Self::Add(left, right) => left.degree().max(right.degree()),
58 Self::Mul(left, right) => left.degree() + right.degree(),
59 Self::Pow(base, exp) => base.degree() * *exp as usize,
60 }
61 }
62
63 pub fn leading_term(&self) -> Self {
66 let (_, expr) = self.leading_term_with_degree();
67 expr
68 }
69
70 pub fn leading_term_with_degree(&self) -> (usize, Self) {
72 match self {
73 expr @ Self::Const(_) => (0, expr.clone()),
74 expr @ Self::Var(_) => (1, expr.clone()),
75 Self::Add(left, right) => {
76 let (lhs_degree, lhs) = left.leading_term_with_degree();
77 let (rhs_degree, rhs) = right.leading_term_with_degree();
78 match lhs_degree.cmp(&rhs_degree) {
79 Ordering::Less => (rhs_degree, rhs),
80 Ordering::Equal => (lhs_degree, Self::Add(Box::new(lhs), Box::new(rhs))),
81 Ordering::Greater => (lhs_degree, lhs),
82 }
83 }
84 Self::Mul(left, right) => {
85 let (lhs_degree, lhs) = left.leading_term_with_degree();
86 let (rhs_degree, rhs) = right.leading_term_with_degree();
87 (lhs_degree + rhs_degree, Self::Mul(Box::new(lhs), Box::new(rhs)))
88 }
89 Self::Pow(base, exp) => {
90 let (base_degree, base) = base.leading_term_with_degree();
91 (base_degree * *exp as usize, Self::Pow(Box::new(base), *exp))
92 }
93 }
94 }
95
96 pub fn pow(self, exp: u64) -> Self {
97 Self::Pow(Box::new(self), exp)
98 }
99
100 pub const fn zero() -> Self {
101 Self::Const(F::ZERO)
102 }
103
104 pub const fn one() -> Self {
105 Self::Const(F::ONE)
106 }
107
108 pub fn remap_vars(self, indices: &[usize]) -> Result<Self, Error> {
118 let expr = match self {
119 Self::Const(_) => self,
120 Self::Var(index) => {
121 let new_index =
122 indices
123 .get(index)
124 .ok_or_else(|| Error::IncorrectArgumentLength {
125 arg: "subset".to_string(),
126 expected: index,
127 })?;
128 Self::Var(*new_index)
129 }
130 Self::Add(left, right) => {
131 let new_left = left.remap_vars(indices)?;
132 let new_right = right.remap_vars(indices)?;
133 Self::Add(Box::new(new_left), Box::new(new_right))
134 }
135 Self::Mul(left, right) => {
136 let new_left = left.remap_vars(indices)?;
137 let new_right = right.remap_vars(indices)?;
138 Self::Mul(Box::new(new_left), Box::new(new_right))
139 }
140 Self::Pow(base, exp) => {
141 let new_base = base.remap_vars(indices)?;
142 Self::Pow(Box::new(new_base), exp)
143 }
144 };
145 Ok(expr)
146 }
147
148 pub fn const_subst(self, var: usize, value: F) -> Self {
150 match self {
151 Self::Const(_) => self,
152 Self::Var(index) => {
153 if index == var {
154 Self::Const(value)
155 } else {
156 self
157 }
158 }
159 Self::Add(left, right) => {
160 let new_left = left.const_subst(var, value);
161 let new_right = right.const_subst(var, value);
162 Self::Add(Box::new(new_left), Box::new(new_right))
163 }
164 Self::Mul(left, right) => {
165 let new_left = left.const_subst(var, value);
166 let new_right = right.const_subst(var, value);
167 Self::Mul(Box::new(new_left), Box::new(new_right))
168 }
169 Self::Pow(base, exp) => {
170 let new_base = base.const_subst(var, value);
171 Self::Pow(Box::new(new_base), exp)
172 }
173 }
174 }
175
176 pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithExpr<FTgt> {
177 match self {
178 Self::Const(val) => ArithExpr::Const((*val).into()),
179 Self::Var(index) => ArithExpr::Var(*index),
180 Self::Add(left, right) => {
181 let new_left = left.convert_field();
182 let new_right = right.convert_field();
183 ArithExpr::Add(Box::new(new_left), Box::new(new_right))
184 }
185 Self::Mul(left, right) => {
186 let new_left = left.convert_field();
187 let new_right = right.convert_field();
188 ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
189 }
190 Self::Pow(base, exp) => {
191 let new_base = base.convert_field();
192 ArithExpr::Pow(Box::new(new_base), *exp)
193 }
194 }
195 }
196
197 pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
198 &self,
199 ) -> Result<ArithExpr<FTgt>, <FTgt as TryFrom<F>>::Error> {
200 Ok(match self {
201 Self::Const(val) => ArithExpr::Const(FTgt::try_from(*val)?),
202 Self::Var(index) => ArithExpr::Var(*index),
203 Self::Add(left, right) => {
204 let new_left = left.try_convert_field()?;
205 let new_right = right.try_convert_field()?;
206 ArithExpr::Add(Box::new(new_left), Box::new(new_right))
207 }
208 Self::Mul(left, right) => {
209 let new_left = left.try_convert_field()?;
210 let new_right = right.try_convert_field()?;
211 ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
212 }
213 Self::Pow(base, exp) => {
214 let new_base = base.try_convert_field()?;
215 ArithExpr::Pow(Box::new(new_base), *exp)
216 }
217 })
218 }
219
220 pub const fn is_composite(&self) -> bool {
222 match self {
223 Self::Const(_) | Self::Var(_) => false,
224 Self::Add(_, _) | Self::Mul(_, _) | Self::Pow(_, _) => true,
225 }
226 }
227
228 pub const fn constant(&self) -> Option<F> {
230 match self {
231 Self::Const(value) => Some(*value),
232 _ => None,
233 }
234 }
235
236 pub fn optimize(&self) -> Self {
242 match self {
243 Self::Const(_) | Self::Var(_) => self.clone(),
244 Self::Add(left, right) => {
245 let left = left.optimize();
246 let right = right.optimize();
247 match (left, right) {
248 (Self::Const(left), Self::Const(right)) => Self::Const(left + right),
250 (Self::Const(left), right) if left == F::ZERO => right,
252 (left, Self::Const(right)) if right == F::ZERO => left,
253 (left, right) if left == right && F::CHARACTERISTIC == 2 => {
256 Self::Const(F::ZERO)
257 }
258 (left, right) => Self::Add(Box::new(left), Box::new(right)),
260 }
261 }
262 Self::Mul(left, right) => {
263 let left = left.optimize();
264 let right = right.optimize();
265 match (left, right) {
266 (Self::Const(left), Self::Const(right)) => Self::Const(left * right),
268 (left, right)
270 if left == Self::Const(F::ZERO) || right == Self::Const(F::ZERO) =>
271 {
272 Self::Const(F::ZERO)
273 }
274 (Self::Const(left), right) if left == F::ONE => right,
276 (left, Self::Const(right)) if right == F::ONE => left,
277 (left, right) => Self::Mul(Box::new(left), Box::new(right)),
279 }
280 }
281 Self::Pow(id, exp) => {
282 let id = id.optimize();
283 match id {
284 Self::Const(value) => Self::Const(PackedField::pow(value, *exp)),
285 Self::Pow(id_inner, exp_inner) => Self::Pow(id_inner, *exp * exp_inner),
286 id => Self::Pow(Box::new(id), *exp),
287 }
288 }
289 }
290 }
291
292 pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
298 if self.degree() > 1 {
299 return Err(Error::NonLinearExpression);
300 }
301
302 let n_vars = self.n_vars();
303
304 let constant = self.evaluate(&vec![F::ZERO; n_vars]);
307
308 let var_coeffs = (0..n_vars)
311 .map(|i| {
312 let mut vars = vec![F::ZERO; n_vars];
313 vars[i] = F::ONE;
314 self.evaluate(&vars) - constant
315 })
316 .collect();
317 Ok(LinearNormalForm {
318 constant,
319 var_coeffs,
320 })
321 }
322
323 fn evaluate(&self, vars: &[F]) -> F {
324 match self {
325 Self::Const(val) => *val,
326 Self::Var(index) => vars[*index],
327 Self::Add(left, right) => left.evaluate(vars) + right.evaluate(vars),
328 Self::Mul(left, right) => left.evaluate(vars) * right.evaluate(vars),
329 Self::Pow(base, exp) => base.evaluate(vars).pow(*exp),
330 }
331 }
332
333 pub fn vars_usage(&self) -> Vec<bool> {
338 let mut usage = vec![false; self.n_vars()];
339 self.mark_vars_usage(&mut usage);
340 usage
341 }
342
343 fn mark_vars_usage(&self, usage: &mut [bool]) {
344 match self {
345 Self::Const(_) => (),
346 Self::Var(index) => usage[*index] = true,
347 Self::Add(left, right) | Self::Mul(left, right) => {
348 left.mark_vars_usage(usage);
349 right.mark_vars_usage(usage);
350 }
351 Self::Pow(base, _) => base.mark_vars_usage(usage),
352 }
353 }
354}
355
356impl<F: TowerField> ArithExpr<F> {
357 pub fn binary_tower_level(&self) -> usize {
358 match self {
359 Self::Const(value) => value.min_tower_level(),
360 Self::Var(_) => 0,
361 Self::Add(left, right) | Self::Mul(left, right) => {
362 left.binary_tower_level().max(right.binary_tower_level())
363 }
364 Self::Pow(base, _) => base.binary_tower_level(),
365 }
366 }
367}
368
369impl<F> Default for ArithExpr<F>
370where
371 F: Field,
372{
373 fn default() -> Self {
374 Self::zero()
375 }
376}
377
378impl<F> Add for ArithExpr<F>
379where
380 F: Field,
381{
382 type Output = Self;
383
384 fn add(self, rhs: Self) -> Self {
385 Self::Add(Box::new(self), Box::new(rhs))
386 }
387}
388
389impl<F> AddAssign for ArithExpr<F>
390where
391 F: Field,
392{
393 fn add_assign(&mut self, rhs: Self) {
394 *self = std::mem::take(self) + rhs;
395 }
396}
397
398impl<F> Sub for ArithExpr<F>
399where
400 F: Field,
401{
402 type Output = Self;
403
404 fn sub(self, rhs: Self) -> Self {
405 Self::Add(Box::new(self), Box::new(rhs))
406 }
407}
408
409impl<F> SubAssign for ArithExpr<F>
410where
411 F: Field,
412{
413 fn sub_assign(&mut self, rhs: Self) {
414 *self = std::mem::take(self) - rhs;
415 }
416}
417
418impl<F> Mul for ArithExpr<F>
419where
420 F: Field,
421{
422 type Output = Self;
423
424 fn mul(self, rhs: Self) -> Self {
425 Self::Mul(Box::new(self), Box::new(rhs))
426 }
427}
428
429impl<F> MulAssign for ArithExpr<F>
430where
431 F: Field,
432{
433 fn mul_assign(&mut self, rhs: Self) {
434 *self = std::mem::take(self) * rhs;
435 }
436}
437
438impl<F: Field> Sum for ArithExpr<F> {
439 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
440 iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
441 }
442}
443
444impl<F: Field> Product for ArithExpr<F> {
445 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
446 iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
447 }
448}
449
450#[derive(Debug, Default, Clone, PartialEq, Eq)]
452pub struct LinearNormalForm<F: Field> {
453 pub constant: F,
455 pub var_coeffs: Vec<F>,
457}
458
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{BinaryField, BinaryField128b, BinaryField1b, BinaryField8b};

    use super::*;

    /// `degree` multiplies the degree of the base by the exponent.
    #[test]
    fn test_degree_with_pow() {
        type F = BinaryField8b;

        let const_pow = ArithExpr::Const(F::new(6)).pow(7);
        assert_eq!(const_pow.degree(), 0);

        let var_pow = ArithExpr::<F>::Var(0).pow(7);
        assert_eq!(var_pow.degree(), 7);

        let product_pow = (ArithExpr::<F>::Var(0) * ArithExpr::Var(1)).pow(7);
        assert_eq!(product_pow.degree(), 14);
    }

    /// Only the degree-3 terms survive; the linear and constant terms are dropped.
    #[test]
    fn test_leading_term_with_degree() {
        type F = BinaryField8b;

        let generator = ArithExpr::Const(F::MULTIPLICATIVE_GENERATOR);
        let scaled_product = ArithExpr::Var(1) * ArithExpr::Var(2) * generator;

        let expr = ArithExpr::Var(0) * (scaled_product.clone() + ArithExpr::Var(4))
            + ArithExpr::Var(5).pow(3)
            + ArithExpr::Const(F::ONE);

        let expected_expr = ArithExpr::Var(0) * scaled_product + ArithExpr::Var(5).pow(3);

        assert_eq!(expr.leading_term_with_degree(), (3, expected_expr));
    }

    /// Remapping fails when a referenced variable has no entry in the index slice.
    #[test]
    fn test_remap_vars_with_too_few_vars() {
        type F = BinaryField8b;

        let sum = ArithExpr::Var(0) + ArithExpr::Const(F::ONE);
        let expr = (sum * ArithExpr::Var(1)).pow(3);

        assert_matches!(expr.remap_vars(&[5]), Err(Error::IncorrectArgumentLength { .. }));
    }

    /// Every variable index `i` is replaced by `indices[i]`.
    #[test]
    fn test_remap_vars_works() {
        type F = BinaryField8b;

        let build = |first: usize, second: usize| {
            ((ArithExpr::Var(first) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(second)).pow(3)
        };

        let remapped = build(0, 1).remap_vars(&[5, 3]);
        assert_eq!(remapped.unwrap(), build(5, 3));
    }

    /// Additive and multiplicative identities are removed by `optimize`.
    #[test]
    fn test_optimize_identity_handling() {
        type F = BinaryField8b;

        let zero = ArithExpr::<F>::zero();
        let one = ArithExpr::<F>::one();
        let x0 = || ArithExpr::<F>::Var(0);

        assert_eq!((zero.clone() * x0()).optimize(), zero);
        assert_eq!((x0() * zero.clone()).optimize(), zero);

        assert_eq!((x0() * one.clone()).optimize(), x0());
        assert_eq!((one * x0()).optimize(), x0());

        assert_eq!((x0() + zero.clone()).optimize(), x0());
        assert_eq!((zero.clone() + x0()).optimize(), x0());

        assert_eq!((x0() + x0()).optimize(), zero);
    }

    /// Substituting a constant and optimizing collapses the expression to a constant.
    #[test]
    fn test_const_subst_and_optimize() {
        type F = BinaryField8b;

        let expr = ArithExpr::Var(0) * ArithExpr::Var(1) + ArithExpr::one() - ArithExpr::Var(1);
        let reduced = expr.const_subst(1, F::ZERO).optimize();

        assert_eq!(reduced.constant(), Some(F::ONE));
    }

    /// Converting into a larger field keeps the structure and upcasts the constants.
    #[test]
    fn test_expression_upcast() {
        type F8 = BinaryField8b;
        type F = BinaryField128b;

        let small = ((ArithExpr::Var(0) + ArithExpr::Const(F8::ONE))
            * ArithExpr::Const(F8::new(222)))
        .pow(3);

        let large = ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE))
            * ArithExpr::Const(F::new(222)))
        .pow(3);

        assert_eq!(small.convert_field::<F>(), large);
    }

    /// Converting into a smaller field succeeds only if every constant fits.
    #[test]
    fn test_expression_downcast() {
        type F8 = BinaryField8b;
        type F = BinaryField128b;

        let large = ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE))
            * ArithExpr::Const(F::new(222)))
        .pow(3);

        assert!(large.try_convert_field::<BinaryField1b>().is_err());

        let small = ((ArithExpr::Var(0) + ArithExpr::Const(F8::ONE))
            * ArithExpr::Const(F8::new(222)))
        .pow(3);

        assert_eq!(large.try_convert_field::<BinaryField8b>().unwrap(), small);
    }

    /// The normal form lists the constant and one coefficient per variable index.
    #[test]
    fn test_linear_normal_form() {
        type F = BinaryField128b;
        use ArithExpr::{Const, Var};

        let expr = Const(F::new(133))
            + Const(F::new(42)) * Var(0)
            + Var(2)
            + Const(F::new(11)) * Const(F::new(37)) * Var(3);

        let normal_form = expr.linear_normal_form().unwrap();
        let expected_coeffs = vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)];

        assert_eq!(normal_form.constant, F::new(133));
        assert_eq!(normal_form.var_coeffs, expected_coeffs);
    }
}