binius_macros/
arith_circuit_poly.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use quote::{quote, ToTokens};
4use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token};
5
6#[derive(Debug)]
7pub(crate) struct ArithCircuitPolyItem {
8	poly: syn::Expr,
9	field_name: syn::Ident,
10}
11
12impl ToTokens for ArithCircuitPolyItem {
13	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
14		let Self { poly, field_name } = self;
15
16		tokens.extend(quote! {
17			{
18				use binius_field::Field;
19				use binius_math::ArithExpr as Expr;
20
21				binius_core::polynomial::ArithCircuitPoly::<binius_field::#field_name>::new(#poly)
22			}
23		});
24	}
25}
26
27impl Parse for ArithCircuitPolyItem {
28	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
29		let vars: Vec<syn::Ident> = {
30			let content;
31			bracketed!(content in input);
32			let vars = content.parse_terminated(syn::Ident::parse, Token![,])?;
33			vars.into_iter().collect()
34		};
35		input.parse::<Token![=]>()?;
36		let poly_packed = input.parse::<syn::Expr>()?;
37		let poly = flatten_expr(&poly_packed, &vars)?;
38
39		input.parse::<Token![,]>()?;
40
41		let field_name = input.parse()?;
42
43		Ok(Self { poly, field_name })
44	}
45}
46
47fn flatten_expr(expr: &syn::Expr, vars: &[syn::Ident]) -> Result<syn::Expr, syn::Error> {
48	match expr.clone() {
49		syn::Expr::Lit(exprlit) => {
50			if let syn::Lit::Int(int) = &exprlit.lit {
51				match &*int.to_string() {
52					"0" => Ok(parse_quote!(Expr::Const(Field::ZERO))),
53					"1" => Ok(parse_quote!(Expr::Const(Field::ONE))),
54					_ => Err(syn::Error::new(expr.span(), "Unsupported integer")),
55				}
56			} else {
57				Err(syn::Error::new(expr.span(), "Unsupported literal"))
58			}
59		}
60		syn::Expr::Path(p) => {
61			for (i, var) in vars.iter().enumerate() {
62				if p.path.is_ident(var) {
63					return Ok(parse_quote!(Expr::Var(#i)));
64				}
65			}
66			Err(syn::Error::new(expr.span(), "Unknown variable"))
67		}
68		syn::Expr::Paren(paren) => flatten_expr(&paren.expr, vars),
69		syn::Expr::Binary(binary) => {
70			let left = flatten_expr(&binary.left, vars)?;
71			let right = flatten_expr(&binary.right, vars)?;
72			match binary.op {
73				syn::BinOp::Add(_) | syn::BinOp::Sub(_) => Ok(parse_quote!((#left + #right))),
74				syn::BinOp::Mul(_) => Ok(parse_quote!((#left * #right))),
75				expr => Err(syn::Error::new(expr.span(), "Unsupported binop")),
76			}
77		}
78		_ => {
79			todo!()
80		}
81	}
82}