binius_macros/
arith_expr.rs1use quote::{quote, ToTokens};
4use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token};
5
6#[derive(Debug)]
7pub(crate) struct ArithExprItem(syn::Expr);
8
9impl ToTokens for ArithExprItem {
10 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
11 let Self(expr) = self;
12 tokens.extend(quote!(#expr));
13 }
14}
15
16impl Parse for ArithExprItem {
17 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
18 let prefixed_field = input.parse::<syn::Path>().ok();
19 let vars: Vec<syn::Ident> = {
20 let content;
21 bracketed!(content in input);
22 let vars = content.parse_terminated(syn::Ident::parse, Token![,])?;
23 vars.into_iter().collect()
24 };
25 input.parse::<Token![=]>()?;
26 let mut expr = input.parse::<syn::Expr>()?;
27 rewrite_expr(&mut expr, &vars, &prefixed_field)?;
28 Ok(Self(expr))
29 }
30}
31
32fn rewrite_expr(
33 expr: &mut syn::Expr,
34 vars: &[syn::Ident],
35 prefixed_field: &Option<syn::Path>,
36) -> Result<(), syn::Error> {
37 let default_field = parse_quote!(binius_field::BinaryField1b);
38 let field = prefixed_field.as_ref().unwrap_or(&default_field);
39 match expr {
40 syn::Expr::Path(path) => {
41 let mut var_index = None;
42 for (i, var) in vars.iter().enumerate() {
43 if path.path.is_ident(var) {
44 var_index = Some(i);
45 }
46 }
47 if let Some(i) = var_index {
48 *expr = parse_quote!(binius_math::ArithExpr::<#field>::Var(#i));
49 } else {
50 return Err(syn::Error::new(path.span(), "Unknown variable"));
51 }
52 }
53 syn::Expr::Lit(exprlit) => {
54 if let syn::Lit::Int(int) = &exprlit.lit {
55 let value: syn::Expr = match &*int.to_string() {
56 "0" => parse_quote!(binius_field::Field::ZERO),
57 "1" => parse_quote!(binius_field::Field::ONE),
58 _ => match prefixed_field {
59 Some(field) => parse_quote!(#field::new(#int)),
60 _ => return Err(syn::Error::new(expr.span(), "You need to specify an explicit field to use constants other than 0 or 1"))
61 }
62 };
63 *expr = parse_quote!(binius_math::ArithExpr::<#field>::Const(#value));
64 }
65 }
66 syn::Expr::Paren(paren) => {
67 rewrite_expr(&mut paren.expr, vars, prefixed_field)?;
68 }
69 syn::Expr::Binary(binary) => {
70 rewrite_expr(&mut binary.left, vars, prefixed_field)?;
71 rewrite_expr(&mut binary.right, vars, prefixed_field)?;
72 }
73 _ => {}
74 }
75 Ok(())
76}