1use binius_field::{ExtensionField, Field, TowerField};
4use binius_math::{ArithCircuit, ArithCircuitStep, ArithExpr};
5use getset::{CopyGetters, Getters};
6
7use super::{column::Col, table::TableId};
8
9#[derive(Debug)]
11pub struct ZeroConstraint<F: Field> {
12 pub name: String,
13 pub expr: ArithCircuit<F>,
14 pub tower_level: usize,
15}
16
17#[derive(Debug, Clone, Getters, CopyGetters)]
21pub struct Expr<F: TowerField, const V: usize> {
22 #[get_copy = "pub"]
23 table_id: TableId,
24 #[get = "pub"]
25 expr: ArithExpr<F>,
26}
27
28impl<F: TowerField, const V: usize> Expr<F, V> {
29 pub fn degree(&self) -> usize {
31 ArithCircuit::from(&self.expr).degree()
32 }
33
34 pub fn pow(self, exp: u64) -> Self {
36 Self {
37 table_id: self.table_id,
38 expr: self.expr.pow(exp),
39 }
40 }
41}
42
43impl<F: TowerField, const V: usize> From<Col<F, V>> for Expr<F, V> {
44 fn from(value: Col<F, V>) -> Self {
45 Expr {
46 table_id: value.table_id,
47 expr: ArithExpr::Var(value.partition_index.0),
48 }
49 }
50}
51
52impl<F: TowerField, const V: usize> std::ops::Add<Self> for Col<F, V> {
53 type Output = Expr<F, V>;
54
55 fn add(self, rhs: Self) -> Self::Output {
56 assert_eq!(self.table_id, rhs.table_id);
57
58 let lhs_expr = ArithExpr::Var(self.partition_index.0);
59 let rhs_expr = ArithExpr::Var(rhs.partition_index.0);
60
61 Expr {
62 table_id: self.table_id,
63 expr: lhs_expr + rhs_expr,
64 }
65 }
66}
67
68impl<F: TowerField, const V: usize> std::ops::Add<Col<F, V>> for Expr<F, V> {
69 type Output = Expr<F, V>;
70
71 fn add(self, rhs: Col<F, V>) -> Self::Output {
72 assert_eq!(self.table_id, rhs.table_id);
73
74 let rhs_expr = ArithExpr::Var(rhs.partition_index.0);
75 Expr {
76 table_id: self.table_id,
77 expr: self.expr + rhs_expr,
78 }
79 }
80}
81
82impl<F: TowerField, const V: usize> std::ops::Add<Expr<F, V>> for Expr<F, V> {
83 type Output = Expr<F, V>;
84
85 fn add(self, rhs: Expr<F, V>) -> Self::Output {
86 assert_eq!(self.table_id, rhs.table_id);
87 Expr {
88 table_id: self.table_id,
89 expr: self.expr + rhs.expr,
90 }
91 }
92}
93
94impl<F: TowerField, const V: usize> std::ops::Add<F> for Expr<F, V> {
95 type Output = Expr<F, V>;
96
97 fn add(self, rhs: F) -> Self::Output {
98 Expr {
99 table_id: self.table_id,
100 expr: self.expr + ArithExpr::Const(rhs),
101 }
102 }
103}
104
105impl<F: TowerField, const V: usize> std::ops::Add<Expr<F, V>> for Col<F, V> {
106 type Output = Expr<F, V>;
107
108 fn add(self, rhs: Expr<F, V>) -> Self::Output {
109 Expr::from(self) + rhs
110 }
111}
112
113impl<F: TowerField, const V: usize> std::ops::Add<F> for Col<F, V> {
114 type Output = Expr<F, V>;
115
116 fn add(self, rhs: F) -> Self::Output {
117 Expr::from(self) + rhs
118 }
119}
120
121impl<F: TowerField, const V: usize> std::ops::Sub<Self> for Col<F, V> {
122 type Output = Expr<F, V>;
123
124 fn sub(self, rhs: Self) -> Self::Output {
125 assert_eq!(self.table_id, rhs.table_id);
126 let lhs_expr = ArithExpr::Var(self.partition_index.0);
127 let rhs_expr = ArithExpr::Var(rhs.partition_index.0);
128
129 Expr {
130 table_id: self.table_id,
131 expr: lhs_expr - rhs_expr,
132 }
133 }
134}
135
136impl<F: TowerField, const V: usize> std::ops::Sub<Col<F, V>> for Expr<F, V> {
137 type Output = Expr<F, V>;
138
139 fn sub(self, rhs: Col<F, V>) -> Self::Output {
140 self - Expr::from(rhs)
141 }
142}
143
144impl<F: TowerField, const V: usize> std::ops::Sub<Expr<F, V>> for Expr<F, V> {
145 type Output = Expr<F, V>;
146
147 fn sub(self, rhs: Expr<F, V>) -> Self::Output {
148 assert_eq!(self.table_id, rhs.table_id);
149 Expr {
150 table_id: self.table_id,
151 expr: self.expr - rhs.expr,
152 }
153 }
154}
155
156impl<F: TowerField, const V: usize> std::ops::Sub<F> for Expr<F, V> {
157 type Output = Expr<F, V>;
158
159 fn sub(self, rhs: F) -> Self::Output {
160 Expr {
161 table_id: self.table_id,
162 expr: self.expr - ArithExpr::Const(rhs),
163 }
164 }
165}
166
167impl<F: TowerField, const V: usize> std::ops::Sub<Expr<F, V>> for Col<F, V> {
168 type Output = Expr<F, V>;
169
170 fn sub(self, rhs: Expr<F, V>) -> Self::Output {
171 Expr::from(self) - rhs
172 }
173}
174
175impl<F: TowerField, const V: usize> std::ops::Sub<F> for Col<F, V> {
176 type Output = Expr<F, V>;
177
178 fn sub(self, rhs: F) -> Self::Output {
179 Expr::from(self) - rhs
180 }
181}
182
183impl<F: TowerField, const V: usize> std::ops::Mul<Self> for Col<F, V> {
184 type Output = Expr<F, V>;
185
186 fn mul(self, rhs: Self) -> Self::Output {
187 Expr::from(self) * Expr::from(rhs)
188 }
189}
190
191impl<F: TowerField, const V: usize> std::ops::Mul<Col<F, V>> for Expr<F, V> {
192 type Output = Expr<F, V>;
193
194 fn mul(self, rhs: Col<F, V>) -> Self::Output {
195 self * Expr::from(rhs)
196 }
197}
198
199impl<F: TowerField, const V: usize> std::ops::Mul<Expr<F, V>> for Expr<F, V> {
200 type Output = Expr<F, V>;
201
202 fn mul(self, rhs: Expr<F, V>) -> Self::Output {
203 assert_eq!(self.table_id, rhs.table_id);
204 Expr {
205 table_id: self.table_id,
206 expr: self.expr * rhs.expr,
207 }
208 }
209}
210
211impl<F: TowerField, const V: usize> std::ops::Mul<F> for Expr<F, V> {
212 type Output = Expr<F, V>;
213
214 fn mul(self, rhs: F) -> Self::Output {
215 Expr {
216 table_id: self.table_id,
217 expr: self.expr * ArithExpr::Const(rhs),
218 }
219 }
220}
221
222impl<F: TowerField, const V: usize> std::ops::Mul<Expr<F, V>> for Col<F, V> {
223 type Output = Expr<F, V>;
224
225 fn mul(self, rhs: Expr<F, V>) -> Self::Output {
226 Expr::from(self) * rhs
227 }
228}
229
230impl<F: TowerField, const V: usize> std::ops::Mul<F> for Col<F, V> {
231 type Output = Expr<F, V>;
232
233 fn mul(self, rhs: F) -> Self::Output {
234 Expr::from(self) * rhs
235 }
236}
237
238pub fn upcast_expr<F, FSub, const V: usize>(expr: Expr<FSub, V>) -> Expr<F, V>
240where
241 FSub: TowerField,
242 F: TowerField + ExtensionField<FSub>,
243{
244 let Expr { table_id, expr } = expr;
245 Expr {
246 table_id,
247 expr: ArithCircuit::from(&expr).convert_field().into(),
248 }
249}
250
251pub struct ArithExprNamedVars<'a, F: TowerField>(pub &'a ArithCircuit<F>, pub &'a [String]);
253
254impl<F: TowerField> std::fmt::Display for ArithExprNamedVars<'_, F> {
255 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
256 fn write_step<F: TowerField>(
257 f: &mut std::fmt::Formatter<'_>,
258 step: usize,
259 steps: &[ArithCircuitStep<F>],
260 names: &[String],
261 ) -> std::fmt::Result {
262 match &steps[step] {
263 ArithCircuitStep::Const(v) => write!(f, "{v}"),
264 ArithCircuitStep::Var(i) => write!(f, "{}", names[*i]),
265 ArithCircuitStep::Add(x, y) => {
266 write_step(f, *x, steps, names)?;
267 write!(f, " + ")?;
268 write_step(f, *y, steps, names)
269 }
270 ArithCircuitStep::Mul(x, y) => {
271 write!(f, "(")?;
272 write_step(f, *x, steps, names)?;
273 write!(f, ") * (")?;
274 write_step(f, *y, steps, names)?;
275 write!(f, ")")
276 }
277 ArithCircuitStep::Pow(x, p) => {
278 write!(f, "(")?;
279 write_step(f, *x, steps, names)?;
280 write!(f, ")^{p}")
281 }
282 }
283 }
284
285 write_step(f, 0, self.0.steps(), self.1)
286 }
287}