1use std::{
4 cmp::Ordering,
5 fmt::{self, Display},
6 iter::{Product, Sum},
7 ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
8};
9
10use binius_field::{Field, PackedField, TowerField};
11use binius_macros::{DeserializeBytes, SerializeBytes};
12
13use super::error::Error;
14
#[derive(Debug, Clone, PartialEq, Eq, SerializeBytes, DeserializeBytes)]
/// A symbolic arithmetic expression over the field `F`, stored as a tree.
///
/// Leaves are constants and variables; interior nodes are sums, products and powers with a
/// constant `u64` exponent. Expressions are usually assembled with the `+`, `-`, `*` operators
/// and [`ArithExpr::pow`].
pub enum ArithExpr<F: Field> {
    /// A field constant.
    Const(F),
    /// The variable `x_i`, identified by its index `i`.
    Var(usize),
    /// The sum of two subexpressions.
    Add(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
    /// The product of two subexpressions.
    Mul(Box<ArithExpr<F>>, Box<ArithExpr<F>>),
    /// A subexpression raised to a constant exponent.
    Pow(Box<ArithExpr<F>>, u64),
}
28
29impl<F: Field + Display> Display for ArithExpr<F> {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 Self::Const(v) => write!(f, "{v}"),
33 Self::Var(i) => write!(f, "x{i}"),
34 Self::Add(x, y) => write!(f, "({} + {})", &**x, &**y),
35 Self::Mul(x, y) => write!(f, "({} * {})", &**x, &**y),
36 Self::Pow(x, p) => write!(f, "({})^{p}", &**x),
37 }
38 }
39}
40
41impl<F: Field> ArithExpr<F> {
42 pub fn n_vars(&self) -> usize {
44 match self {
45 Self::Const(_) => 0,
46 Self::Var(index) => *index + 1,
47 Self::Add(left, right) | Self::Mul(left, right) => left.n_vars().max(right.n_vars()),
48 Self::Pow(id, _) => id.n_vars(),
49 }
50 }
51
52 pub fn degree(&self) -> usize {
54 match self {
55 Self::Const(_) => 0,
56 Self::Var(_) => 1,
57 Self::Add(left, right) => left.degree().max(right.degree()),
58 Self::Mul(left, right) => left.degree() + right.degree(),
59 Self::Pow(base, exp) => base.degree() * *exp as usize,
60 }
61 }
62
63 pub fn leading_term(&self) -> Self {
66 let (_, expr) = self.leading_term_with_degree();
67 expr
68 }
69
70 pub fn leading_term_with_degree(&self) -> (usize, Self) {
72 match self {
73 expr @ Self::Const(_) => (0, expr.clone()),
74 expr @ Self::Var(_) => (1, expr.clone()),
75 Self::Add(left, right) => {
76 let (lhs_degree, lhs) = left.leading_term_with_degree();
77 let (rhs_degree, rhs) = right.leading_term_with_degree();
78 match lhs_degree.cmp(&rhs_degree) {
79 Ordering::Less => (rhs_degree, rhs),
80 Ordering::Equal => (lhs_degree, Self::Add(Box::new(lhs), Box::new(rhs))),
81 Ordering::Greater => (lhs_degree, lhs),
82 }
83 }
84 Self::Mul(left, right) => {
85 let (lhs_degree, lhs) = left.leading_term_with_degree();
86 let (rhs_degree, rhs) = right.leading_term_with_degree();
87 (lhs_degree + rhs_degree, Self::Mul(Box::new(lhs), Box::new(rhs)))
88 }
89 Self::Pow(base, exp) => {
90 let (base_degree, base) = base.leading_term_with_degree();
91 (base_degree * *exp as usize, Self::Pow(Box::new(base), *exp))
92 }
93 }
94 }
95
96 pub fn pow(self, exp: u64) -> Self {
97 Self::Pow(Box::new(self), exp)
98 }
99
100 pub const fn zero() -> Self {
101 Self::Const(F::ZERO)
102 }
103
104 pub const fn one() -> Self {
105 Self::Const(F::ONE)
106 }
107
108 pub fn remap_vars(self, indices: &[usize]) -> Result<Self, Error> {
118 let expr = match self {
119 Self::Const(_) => self,
120 Self::Var(index) => {
121 let new_index =
122 indices
123 .get(index)
124 .ok_or_else(|| Error::IncorrectArgumentLength {
125 arg: "subset".to_string(),
126 expected: index,
127 })?;
128 Self::Var(*new_index)
129 }
130 Self::Add(left, right) => {
131 let new_left = left.remap_vars(indices)?;
132 let new_right = right.remap_vars(indices)?;
133 Self::Add(Box::new(new_left), Box::new(new_right))
134 }
135 Self::Mul(left, right) => {
136 let new_left = left.remap_vars(indices)?;
137 let new_right = right.remap_vars(indices)?;
138 Self::Mul(Box::new(new_left), Box::new(new_right))
139 }
140 Self::Pow(base, exp) => {
141 let new_base = base.remap_vars(indices)?;
142 Self::Pow(Box::new(new_base), exp)
143 }
144 };
145 Ok(expr)
146 }
147
148 pub fn convert_field<FTgt: Field + From<F>>(&self) -> ArithExpr<FTgt> {
149 match self {
150 Self::Const(val) => ArithExpr::Const((*val).into()),
151 Self::Var(index) => ArithExpr::Var(*index),
152 Self::Add(left, right) => {
153 let new_left = left.convert_field();
154 let new_right = right.convert_field();
155 ArithExpr::Add(Box::new(new_left), Box::new(new_right))
156 }
157 Self::Mul(left, right) => {
158 let new_left = left.convert_field();
159 let new_right = right.convert_field();
160 ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
161 }
162 Self::Pow(base, exp) => {
163 let new_base = base.convert_field();
164 ArithExpr::Pow(Box::new(new_base), *exp)
165 }
166 }
167 }
168
169 pub fn try_convert_field<FTgt: Field + TryFrom<F>>(
170 &self,
171 ) -> Result<ArithExpr<FTgt>, <FTgt as TryFrom<F>>::Error> {
172 Ok(match self {
173 Self::Const(val) => ArithExpr::Const(FTgt::try_from(*val)?),
174 Self::Var(index) => ArithExpr::Var(*index),
175 Self::Add(left, right) => {
176 let new_left = left.try_convert_field()?;
177 let new_right = right.try_convert_field()?;
178 ArithExpr::Add(Box::new(new_left), Box::new(new_right))
179 }
180 Self::Mul(left, right) => {
181 let new_left = left.try_convert_field()?;
182 let new_right = right.try_convert_field()?;
183 ArithExpr::Mul(Box::new(new_left), Box::new(new_right))
184 }
185 Self::Pow(base, exp) => {
186 let new_base = base.try_convert_field()?;
187 ArithExpr::Pow(Box::new(new_base), *exp)
188 }
189 })
190 }
191
192 pub const fn is_composite(&self) -> bool {
194 match self {
195 Self::Const(_) | Self::Var(_) => false,
196 Self::Add(_, _) | Self::Mul(_, _) | Self::Pow(_, _) => true,
197 }
198 }
199
200 pub fn optimize(&self) -> Self {
204 match self {
205 Self::Const(_) | Self::Var(_) => self.clone(),
206 Self::Add(left, right) => {
207 let left = left.optimize();
208 let right = right.optimize();
209 match (left, right) {
210 (Self::Const(left), Self::Const(right)) => Self::Const(left + right),
211 (left, right) => Self::Add(Box::new(left), Box::new(right)),
212 }
213 }
214 Self::Mul(left, right) => {
215 let left = left.optimize();
216 let right = right.optimize();
217 match (left, right) {
218 (Self::Const(left), Self::Const(right)) => Self::Const(left * right),
219 (left, right) => Self::Mul(Box::new(left), Box::new(right)),
220 }
221 }
222 Self::Pow(id, exp) => {
223 let id = id.optimize();
224 match id {
225 Self::Const(value) => Self::Const(PackedField::pow(value, *exp)),
226 Self::Pow(id_inner, exp_inner) => Self::Pow(id_inner, *exp * exp_inner),
227 id => Self::Pow(Box::new(id), *exp),
228 }
229 }
230 }
231 }
232
233 pub fn linear_normal_form(&self) -> Result<LinearNormalForm<F>, Error> {
239 if self.degree() > 1 {
240 return Err(Error::NonLinearExpression);
241 }
242
243 let mut normal_form = LinearNormalForm::default();
244 let n_vars = self.n_vars();
245
246 let constant = self.evaluate(&vec![F::ZERO; n_vars]);
249
250 for i in 0..n_vars {
253 let mut vars = vec![F::ZERO; n_vars];
254 vars[i] = F::ONE;
255 normal_form.var_coeffs.push(self.evaluate(&vars) - constant);
256 }
257 Ok(normal_form)
258 }
259
260 fn evaluate(&self, vars: &[F]) -> F {
261 match self {
262 Self::Const(val) => *val,
263 Self::Var(index) => vars[*index],
264 Self::Add(left, right) => left.evaluate(vars) + right.evaluate(vars),
265 Self::Mul(left, right) => left.evaluate(vars) * right.evaluate(vars),
266 Self::Pow(base, exp) => base.evaluate(vars).pow(*exp),
267 }
268 }
269}
270
271impl<F: TowerField> ArithExpr<F> {
272 pub fn binary_tower_level(&self) -> usize {
273 match self {
274 Self::Const(value) => value.min_tower_level(),
275 Self::Var(_) => 0,
276 Self::Add(left, right) | Self::Mul(left, right) => {
277 left.binary_tower_level().max(right.binary_tower_level())
278 }
279 Self::Pow(base, _) => base.binary_tower_level(),
280 }
281 }
282}
283
284impl<F> Default for ArithExpr<F>
285where
286 F: Field,
287{
288 fn default() -> Self {
289 Self::zero()
290 }
291}
292
293impl<F> Add for ArithExpr<F>
294where
295 F: Field,
296{
297 type Output = Self;
298
299 fn add(self, rhs: Self) -> Self {
300 Self::Add(Box::new(self), Box::new(rhs))
301 }
302}
303
304impl<F> AddAssign for ArithExpr<F>
305where
306 F: Field,
307{
308 fn add_assign(&mut self, rhs: Self) {
309 *self = std::mem::take(self) + rhs;
310 }
311}
312
313impl<F> Sub for ArithExpr<F>
314where
315 F: Field,
316{
317 type Output = Self;
318
319 fn sub(self, rhs: Self) -> Self {
320 Self::Add(Box::new(self), Box::new(rhs))
321 }
322}
323
324impl<F> SubAssign for ArithExpr<F>
325where
326 F: Field,
327{
328 fn sub_assign(&mut self, rhs: Self) {
329 *self = std::mem::take(self) - rhs;
330 }
331}
332
333impl<F> Mul for ArithExpr<F>
334where
335 F: Field,
336{
337 type Output = Self;
338
339 fn mul(self, rhs: Self) -> Self {
340 Self::Mul(Box::new(self), Box::new(rhs))
341 }
342}
343
344impl<F> MulAssign for ArithExpr<F>
345where
346 F: Field,
347{
348 fn mul_assign(&mut self, rhs: Self) {
349 *self = std::mem::take(self) * rhs;
350 }
351}
352
353impl<F: Field> Sum for ArithExpr<F> {
354 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
355 iter.reduce(|acc, item| acc + item).unwrap_or(Self::zero())
356 }
357}
358
359impl<F: Field> Product for ArithExpr<F> {
360 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
361 iter.reduce(|acc, item| acc * item).unwrap_or(Self::one())
362 }
363}
364
#[derive(Debug, Default, Clone, PartialEq, Eq)]
/// An affine expression `constant + Σ var_coeffs[i] * x_i`.
pub struct LinearNormalForm<F: Field> {
    /// The constant term of the affine expression.
    pub constant: F,
    /// `var_coeffs[i]` is the coefficient of the variable `x_i`.
    pub var_coeffs: Vec<F>,
}
373
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use binius_field::{BinaryField, BinaryField128b, BinaryField1b, BinaryField8b};

    use super::*;

    #[test]
    fn test_degree_with_pow() {
        let expr = ArithExpr::Const(BinaryField8b::new(6)).pow(7);
        assert_eq!(expr.degree(), 0);

        let expr: ArithExpr<BinaryField8b> = ArithExpr::Var(0).pow(7);
        assert_eq!(expr.degree(), 7);

        let expr: ArithExpr<BinaryField8b> = (ArithExpr::Var(0) * ArithExpr::Var(1)).pow(7);
        assert_eq!(expr.degree(), 14);
    }

    #[test]
    fn test_leading_term_with_degree() {
        let expr = ArithExpr::Var(0)
            * (ArithExpr::Var(1)
                * ArithExpr::Var(2)
                * ArithExpr::Const(BinaryField8b::MULTIPLICATIVE_GENERATOR)
                + ArithExpr::Var(4))
            + ArithExpr::Var(5).pow(3)
            + ArithExpr::Const(BinaryField8b::ONE);

        let expected_expr = ArithExpr::Var(0)
            * ((ArithExpr::Var(1) * ArithExpr::Var(2))
                * ArithExpr::Const(BinaryField8b::MULTIPLICATIVE_GENERATOR))
            + ArithExpr::Var(5).pow(3);

        assert_eq!(expr.leading_term_with_degree(), (3, expected_expr));
    }

    #[test]
    fn test_remap_vars_with_too_few_vars() {
        type F = BinaryField8b;
        let expr = ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(1)).pow(3);
        assert_matches!(expr.remap_vars(&[5]), Err(Error::IncorrectArgumentLength { .. }));
    }

    #[test]
    fn test_remap_vars_works() {
        type F = BinaryField8b;
        let expr = ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(1)).pow(3);
        let new_expr = expr.remap_vars(&[5, 3]);

        let expected = ((ArithExpr::Var(5) + ArithExpr::Const(F::ONE)) * ArithExpr::Var(3)).pow(3);
        assert_eq!(new_expr.unwrap(), expected);
    }

    #[test]
    fn test_expression_upcast() {
        type F8 = BinaryField8b;
        type F = BinaryField128b;

        let expr = ((ArithExpr::Var(0) + ArithExpr::Const(F8::ONE))
            * ArithExpr::Const(F8::new(222)))
        .pow(3);

        let expected =
            ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Const(F::new(222))).pow(3);
        assert_eq!(expr.convert_field::<F>(), expected);
    }

    #[test]
    fn test_expression_downcast() {
        type F8 = BinaryField8b;
        type F = BinaryField128b;

        let expr =
            ((ArithExpr::Var(0) + ArithExpr::Const(F::ONE)) * ArithExpr::Const(F::new(222))).pow(3);

        assert!(expr.try_convert_field::<BinaryField1b>().is_err());

        let expected = ((ArithExpr::Var(0) + ArithExpr::Const(F8::ONE))
            * ArithExpr::Const(F8::new(222)))
        .pow(3);
        assert_eq!(expr.try_convert_field::<BinaryField8b>().unwrap(), expected);
    }

    #[test]
    fn test_optimize_merges_nested_powers() {
        type F = BinaryField8b;
        let expr: ArithExpr<F> = ArithExpr::Var(0).pow(2).pow(3);
        assert_eq!(expr.optimize(), ArithExpr::Var(0).pow(6));
    }

    #[test]
    fn test_linear_normal_form() {
        type F = BinaryField128b;
        use ArithExpr::{Const, Var};
        let expr = Const(F::new(133))
            + Const(F::new(42)) * Var(0)
            + Var(2)
            + Const(F::new(11)) * Const(F::new(37)) * Var(3);
        let normal_form = expr.linear_normal_form().unwrap();
        // The constant term is the value of the expression with every variable set to zero.
        assert_eq!(normal_form.constant, F::new(133));
        assert_eq!(
            normal_form.var_coeffs,
            vec![F::new(42), F::ZERO, F::ONE, F::new(11) * F::new(37)]
        );
    }
}