binius_macros/composition_poly.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use quote::{quote, ToTokens};
4use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token};
5
6#[derive(Debug)]
7pub(crate) struct CompositionPolyItem {
8	pub is_anonymous: bool,
9	pub name: syn::Ident,
10	pub vars: Vec<syn::Ident>,
11	pub poly_packed: syn::Expr,
12	pub expr: syn::Expr,
13	pub scalar_type: syn::Type,
14	pub degree: usize,
15}
16
17impl ToTokens for CompositionPolyItem {
18	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
19		let Self {
20			is_anonymous,
21			name,
22			vars,
23			poly_packed,
24			expr,
25			scalar_type,
26			degree,
27		} = self;
28		let n_vars = vars.len();
29
30		let mut eval_single = poly_packed.clone();
31		subst_vars(&mut eval_single, vars, &|i| parse_quote!(unsafe {*query.get_unchecked(#i)}))
32			.expect("Failed to substitute vars");
33
34		let mut eval_batch = poly_packed.clone();
35		subst_vars(
36			&mut eval_batch,
37			vars,
38			&|i| parse_quote!(unsafe {*batch_query.rows().get_unchecked(#i).get_unchecked(row)}),
39		)
40		.expect("Failed to substitute vars");
41
42		let result = quote! {
43			#[derive(Debug, Clone, Copy)]
44			struct #name;
45
46			impl<P> binius_math::CompositionPoly<P> for #name
47			where
48				P: binius_field::PackedField<Scalar: binius_field::ExtensionField<#scalar_type>>,
49			{
50				fn n_vars(&self) -> usize {
51					#n_vars
52				}
53
54				fn degree(&self) -> usize {
55					#degree
56				}
57
58				fn binary_tower_level(&self) -> usize {
59					0
60				}
61
62				fn expression(&self) -> binius_math::ArithExpr<P::Scalar> {
63					(#expr).convert_field()
64				}
65
66				fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
67					if query.len() != #n_vars {
68						return Err(binius_math::Error::IncorrectQuerySize { expected: #n_vars });
69					}
70					Ok(#eval_single)
71				}
72
73				fn batch_evaluate(
74					&self,
75					batch_query: &binius_math::RowsBatchRef<P>,
76					evals: &mut [P],
77				) -> Result<(), binius_math::Error> {
78					if batch_query.row_len() != #n_vars {
79						return Err(binius_math::Error::IncorrectQuerySize { expected: #n_vars });
80					}
81
82					for row in 0..batch_query.rows()[0].len() {
83						evals[row] = #eval_batch;
84					}
85
86					Ok(())
87				}
88			}
89		};
90
91		if *is_anonymous {
92			// In this case we return an instance of our struct rather
93			// than defining the struct within the current scope
94			tokens.extend(quote! {
95				{
96					#result
97					#name
98				}
99			});
100		} else {
101			tokens.extend(result);
102		}
103	}
104}
105
106impl Parse for CompositionPolyItem {
107	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
108		let name = input.parse::<syn::Ident>();
109		let is_anonymous = name.is_err();
110		let name = name.unwrap_or_else(|_| parse_quote!(UnnamedCompositionPoly));
111		let vars = {
112			let content;
113			bracketed!(content in input);
114			let vars = content.parse_terminated(syn::Ident::parse, Token![,])?;
115			vars.into_iter().collect::<Vec<_>>()
116		};
117		input.parse::<Token![=]>()?;
118		let mut poly_packed = input.parse::<syn::Expr>()?;
119		let mut expr = poly_packed.clone();
120
121		let degree = poly_degree(&poly_packed)?;
122		rewrite_literals(&mut poly_packed, &replace_packed_literals)?;
123
124		subst_vars(&mut expr, &vars, &|i| parse_quote!(binius_math::ArithExpr::Var(#i)))?;
125		rewrite_literals(&mut expr, &replace_expr_literals)?;
126
127		let scalar_type = if input.is_empty() {
128			parse_quote!(binius_field::BinaryField1b)
129		} else {
130			input.parse::<Token![,]>()?;
131
132			input.parse()?
133		};
134
135		Ok(Self {
136			is_anonymous,
137			name,
138			vars,
139			poly_packed,
140			expr,
141			scalar_type,
142			degree,
143		})
144	}
145}
146
147/// Make sure to run this before rewrite_literals as it will rewrite Lit to Path,
148/// which will mess up the degree
149fn poly_degree(expr: &syn::Expr) -> Result<usize, syn::Error> {
150	Ok(match expr.clone() {
151		syn::Expr::Lit(_) => 0,
152		syn::Expr::Path(_) => 1,
153		syn::Expr::Paren(paren) => poly_degree(&paren.expr)?,
154		syn::Expr::Binary(binary) => {
155			let op = binary.op;
156			let left = poly_degree(&binary.left)?;
157			let right = poly_degree(&binary.right)?;
158			match op {
159				syn::BinOp::Add(_) | syn::BinOp::Sub(_) => std::cmp::max(left, right),
160				syn::BinOp::Mul(_) => left + right,
161				expr => {
162					return Err(syn::Error::new(expr.span(), "Unsupported binop"));
163				}
164			}
165		}
166		expr => return Err(syn::Error::new(expr.span(), "Unsupported expression")),
167	})
168}
169
170/// Replace literals to P::zero() and P::one() to be used in `evaluate` and `batch_evaluate`.
171fn replace_packed_literals(literal: &syn::LitInt) -> Result<syn::Expr, syn::Error> {
172	Ok(match &*literal.to_string() {
173		"0" => parse_quote!(P::zero()),
174		"1" => parse_quote!(P::one()),
175		_ => return Err(syn::Error::new(literal.span(), "Unsupported integer")),
176	})
177}
178
179/// Replace literals to Expr::zero() and Expr::one() to be used in `expression` method.
180fn replace_expr_literals(literal: &syn::LitInt) -> Result<syn::Expr, syn::Error> {
181	Ok(match &*literal.to_string() {
182		"0" => parse_quote!(binius_math::ArithExpr::zero()),
183		"1" => parse_quote!(binius_math::ArithExpr::one()),
184		_ => return Err(syn::Error::new(literal.span(), "Unsupported integer")),
185	})
186}
187
188/// Replace literals in an expression
189fn rewrite_literals(
190	expr: &mut syn::Expr,
191	f: &impl Fn(&syn::LitInt) -> Result<syn::Expr, syn::Error>,
192) -> Result<(), syn::Error> {
193	match expr {
194		syn::Expr::Lit(exprlit) => {
195			if let syn::Lit::Int(int) = &exprlit.lit {
196				*expr = f(int)?;
197			}
198		}
199		syn::Expr::Paren(paren) => {
200			rewrite_literals(&mut paren.expr, f)?;
201		}
202		syn::Expr::Binary(binary) => {
203			rewrite_literals(&mut binary.left, f)?;
204			rewrite_literals(&mut binary.right, f)?;
205		}
206		_ => {}
207	}
208	Ok(())
209}
210
211/// Substitutes variables in an expression with a slice access
212fn subst_vars(
213	expr: &mut syn::Expr,
214	vars: &[syn::Ident],
215	f: &impl Fn(usize) -> syn::Expr,
216) -> Result<(), syn::Error> {
217	match expr {
218		syn::Expr::Path(p) => {
219			for (i, var) in vars.iter().enumerate() {
220				if p.path.is_ident(var) {
221					*expr = f(i);
222					return Ok(());
223				}
224			}
225			Err(syn::Error::new(p.span(), "unknown variable"))
226		}
227		syn::Expr::Paren(paren) => subst_vars(&mut paren.expr, vars, f),
228		syn::Expr::Binary(binary) => {
229			subst_vars(&mut binary.left, vars, f)?;
230			subst_vars(&mut binary.right, vars, f)
231		}
232		_ => Ok(()),
233	}
234}