binius_macros/
arith_expr.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use quote::{quote, ToTokens};
4use syn::{bracketed, parse::Parse, parse_quote, spanned::Spanned, Token};
5
6#[derive(Debug)]
7pub(crate) struct ArithExprItem(syn::Expr);
8
9impl ToTokens for ArithExprItem {
10	fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
11		let Self(expr) = self;
12		tokens.extend(quote!(#expr));
13	}
14}
15
16impl Parse for ArithExprItem {
17	fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
18		let prefixed_field = input.parse::<syn::Path>().ok();
19		let vars: Vec<syn::Ident> = {
20			let content;
21			bracketed!(content in input);
22			let vars = content.parse_terminated(syn::Ident::parse, Token![,])?;
23			vars.into_iter().collect()
24		};
25		input.parse::<Token![=]>()?;
26		let mut expr = input.parse::<syn::Expr>()?;
27		rewrite_expr(&mut expr, &vars, &prefixed_field)?;
28		Ok(Self(expr))
29	}
30}
31
32fn rewrite_expr(
33	expr: &mut syn::Expr,
34	vars: &[syn::Ident],
35	prefixed_field: &Option<syn::Path>,
36) -> Result<(), syn::Error> {
37	let default_field = parse_quote!(binius_field::BinaryField1b);
38	let field = prefixed_field.as_ref().unwrap_or(&default_field);
39	match expr {
40		syn::Expr::Path(path) => {
41			let mut var_index = None;
42			for (i, var) in vars.iter().enumerate() {
43				if path.path.is_ident(var) {
44					var_index = Some(i);
45				}
46			}
47			if let Some(i) = var_index {
48				*expr = parse_quote!(binius_math::ArithExpr::<#field>::Var(#i));
49			} else {
50				return Err(syn::Error::new(path.span(), "Unknown variable"));
51			}
52		}
53		syn::Expr::Lit(exprlit) => {
54			if let syn::Lit::Int(int) = &exprlit.lit {
55				let value: syn::Expr = match &*int.to_string() {
56					"0" => parse_quote!(binius_field::Field::ZERO),
57					"1" => parse_quote!(binius_field::Field::ONE),
58					_ => match prefixed_field {
59						Some(field) => parse_quote!(#field::new(#int)),
60						_ => return Err(syn::Error::new(expr.span(), "You need to specify an explicit field to use constants other than 0 or 1"))
61					}
62				};
63				*expr = parse_quote!(binius_math::ArithExpr::<#field>::Const(#value));
64			}
65		}
66		syn::Expr::Paren(paren) => {
67			rewrite_expr(&mut paren.expr, vars, prefixed_field)?;
68		}
69		syn::Expr::Binary(binary) => {
70			rewrite_expr(&mut binary.left, vars, prefixed_field)?;
71			rewrite_expr(&mut binary.right, vars, prefixed_field)?;
72		}
73		_ => {}
74	}
75	Ok(())
76}