1use quote::{quote, ToTokens};
4use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token};
5
/// A parsed composition-polynomial declaration.
///
/// Produced by the `Parse` impl and turned into a unit struct implementing
/// `binius_math::CompositionPoly<P>` by the `ToTokens` impl.
#[derive(Debug)]
pub(crate) struct CompositionPolyItem {
    /// `true` when the declaration had no leading struct name. The emitted struct and impl
    /// are then wrapped in a block expression that evaluates to an instance of the struct.
    pub is_anonymous: bool,
    /// Name of the generated struct (`UnnamedCompositionPoly` for anonymous declarations).
    pub name: syn::Ident,
    /// Variable names. A variable's position in this list is the query index it is read from.
    pub vars: Vec<syn::Ident>,
    /// The polynomial with integer literals rewritten to `P::zero()` / `P::one()`.
    /// Variables are still the user's identifiers; they are substituted when tokens are emitted.
    pub poly_packed: syn::Expr,
    /// The polynomial as an expression that constructs a `binius_math::ArithExpr`.
    pub expr: syn::Expr,
    /// Scalar field the generated impl is bounded over: `P::Scalar` must extend it.
    /// Defaults to `binius_field::BinaryField1b` when not given.
    pub scalar_type: syn::Type,
    /// Degree computed by `poly_degree` at parse time.
    pub degree: usize,
}
16
17impl ToTokens for CompositionPolyItem {
18 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
19 let Self {
20 is_anonymous,
21 name,
22 vars,
23 poly_packed,
24 expr,
25 scalar_type,
26 degree,
27 } = self;
28 let n_vars = vars.len();
29
30 let mut eval_single = poly_packed.clone();
31 subst_vars(&mut eval_single, vars, &|i| parse_quote!(unsafe {*query.get_unchecked(#i)}))
32 .expect("Failed to substitute vars");
33
34 let mut eval_batch = poly_packed.clone();
35 subst_vars(
36 &mut eval_batch,
37 vars,
38 &|i| parse_quote!(unsafe {*batch_query.rows().get_unchecked(#i).get_unchecked(row)}),
39 )
40 .expect("Failed to substitute vars");
41
42 let result = quote! {
43 #[derive(Debug, Clone, Copy)]
44 struct #name;
45
46 impl<P> binius_math::CompositionPoly<P> for #name
47 where
48 P: binius_field::PackedField<Scalar: binius_field::ExtensionField<#scalar_type>>,
49 {
50 fn n_vars(&self) -> usize {
51 #n_vars
52 }
53
54 fn degree(&self) -> usize {
55 #degree
56 }
57
58 fn binary_tower_level(&self) -> usize {
59 0
60 }
61
62 fn expression(&self) -> binius_math::ArithExpr<P::Scalar> {
63 (#expr).convert_field()
64 }
65
66 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
67 if query.len() != #n_vars {
68 return Err(binius_math::Error::IncorrectQuerySize { expected: #n_vars });
69 }
70 Ok(#eval_single)
71 }
72
73 fn batch_evaluate(
74 &self,
75 batch_query: &binius_math::RowsBatchRef<P>,
76 evals: &mut [P],
77 ) -> Result<(), binius_math::Error> {
78 if batch_query.row_len() != #n_vars {
79 return Err(binius_math::Error::IncorrectQuerySize { expected: #n_vars });
80 }
81
82 for row in 0..batch_query.rows()[0].len() {
83 evals[row] = #eval_batch;
84 }
85
86 Ok(())
87 }
88 }
89 };
90
91 if *is_anonymous {
92 tokens.extend(quote! {
95 {
96 #result
97 #name
98 }
99 });
100 } else {
101 tokens.extend(result);
102 }
103 }
104}
105
106impl Parse for CompositionPolyItem {
107 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
108 let name = input.parse::<syn::Ident>();
109 let is_anonymous = name.is_err();
110 let name = name.unwrap_or_else(|_| parse_quote!(UnnamedCompositionPoly));
111 let vars = {
112 let content;
113 bracketed!(content in input);
114 let vars = content.parse_terminated(syn::Ident::parse, Token![,])?;
115 vars.into_iter().collect::<Vec<_>>()
116 };
117 input.parse::<Token![=]>()?;
118 let mut poly_packed = input.parse::<syn::Expr>()?;
119 let mut expr = poly_packed.clone();
120
121 let degree = poly_degree(&poly_packed)?;
122 rewrite_literals(&mut poly_packed, &replace_packed_literals)?;
123
124 subst_vars(&mut expr, &vars, &|i| parse_quote!(binius_math::ArithExpr::Var(#i)))?;
125 rewrite_literals(&mut expr, &replace_expr_literals)?;
126
127 let scalar_type = if input.is_empty() {
128 parse_quote!(binius_field::BinaryField1b)
129 } else {
130 input.parse::<Token![,]>()?;
131
132 input.parse()?
133 };
134
135 Ok(Self {
136 is_anonymous,
137 name,
138 vars,
139 poly_packed,
140 expr,
141 scalar_type,
142 degree,
143 })
144 }
145}
146
147fn poly_degree(expr: &syn::Expr) -> Result<usize, syn::Error> {
150 Ok(match expr.clone() {
151 syn::Expr::Lit(_) => 0,
152 syn::Expr::Path(_) => 1,
153 syn::Expr::Paren(paren) => poly_degree(&paren.expr)?,
154 syn::Expr::Binary(binary) => {
155 let op = binary.op;
156 let left = poly_degree(&binary.left)?;
157 let right = poly_degree(&binary.right)?;
158 match op {
159 syn::BinOp::Add(_) | syn::BinOp::Sub(_) => std::cmp::max(left, right),
160 syn::BinOp::Mul(_) => left + right,
161 expr => {
162 return Err(syn::Error::new(expr.span(), "Unsupported binop"));
163 }
164 }
165 }
166 expr => return Err(syn::Error::new(expr.span(), "Unsupported expression")),
167 })
168}
169
170fn replace_packed_literals(literal: &syn::LitInt) -> Result<syn::Expr, syn::Error> {
172 Ok(match &*literal.to_string() {
173 "0" => parse_quote!(P::zero()),
174 "1" => parse_quote!(P::one()),
175 _ => return Err(syn::Error::new(literal.span(), "Unsupported integer")),
176 })
177}
178
179fn replace_expr_literals(literal: &syn::LitInt) -> Result<syn::Expr, syn::Error> {
181 Ok(match &*literal.to_string() {
182 "0" => parse_quote!(binius_math::ArithExpr::zero()),
183 "1" => parse_quote!(binius_math::ArithExpr::one()),
184 _ => return Err(syn::Error::new(literal.span(), "Unsupported integer")),
185 })
186}
187
188fn rewrite_literals(
190 expr: &mut syn::Expr,
191 f: &impl Fn(&syn::LitInt) -> Result<syn::Expr, syn::Error>,
192) -> Result<(), syn::Error> {
193 match expr {
194 syn::Expr::Lit(exprlit) => {
195 if let syn::Lit::Int(int) = &exprlit.lit {
196 *expr = f(int)?;
197 }
198 }
199 syn::Expr::Paren(paren) => {
200 rewrite_literals(&mut paren.expr, f)?;
201 }
202 syn::Expr::Binary(binary) => {
203 rewrite_literals(&mut binary.left, f)?;
204 rewrite_literals(&mut binary.right, f)?;
205 }
206 _ => {}
207 }
208 Ok(())
209}
210
211fn subst_vars(
213 expr: &mut syn::Expr,
214 vars: &[syn::Ident],
215 f: &impl Fn(usize) -> syn::Expr,
216) -> Result<(), syn::Error> {
217 match expr {
218 syn::Expr::Path(p) => {
219 for (i, var) in vars.iter().enumerate() {
220 if p.path.is_ident(var) {
221 *expr = f(i);
222 return Ok(());
223 }
224 }
225 Err(syn::Error::new(p.span(), "unknown variable"))
226 }
227 syn::Expr::Paren(paren) => subst_vars(&mut paren.expr, vars, f),
228 syn::Expr::Binary(binary) => {
229 subst_vars(&mut binary.left, vars, f)?;
230 subst_vars(&mut binary.right, vars, f)
231 }
232 _ => Ok(()),
233 }
234}