binius_macros/
lib.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3extern crate proc_macro;
4mod arith_circuit_poly;
5mod arith_expr;
6mod composition_poly;
7
8use proc_macro::TokenStream;
9use quote::{quote, ToTokens};
10use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, ItemImpl};
11
12use crate::{
13	arith_circuit_poly::ArithCircuitPolyItem, arith_expr::ArithExprItem,
14	composition_poly::CompositionPolyItem,
15};
16
17/// Useful for concisely creating structs that implement CompositionPoly.
18/// This currently only supports creating composition polynomials of tower level 0.
19///
20/// ```
21/// use binius_macros::composition_poly;
22/// use binius_math::CompositionPoly;
23/// use binius_field::{Field, BinaryField1b as F};
24///
25/// // Defines named struct without any fields that implements CompositionPoly
26/// composition_poly!(MyComposition[x, y, z] = x + y * z);
27/// assert_eq!(
28///     MyComposition.evaluate(&[F::ONE, F::ONE, F::ONE]).unwrap(),
29///     F::ZERO
30/// );
31///
32/// // If you omit the name you get an anonymous instance instead, which can be used inline
33/// assert_eq!(
34///     composition_poly!([x, y, z] = x + y * z)
35///         .evaluate(&[F::ONE, F::ONE, F::ONE]).unwrap(),
36///     F::ZERO
37/// );
38/// ```
39#[proc_macro]
40pub fn composition_poly(input: TokenStream) -> TokenStream {
41	parse_macro_input!(input as CompositionPolyItem)
42		.into_token_stream()
43		.into()
44}
45
46/// Define polynomial expressions compactly using named positional arguments
47///
48/// ```
49/// use binius_macros::arith_expr;
50/// use binius_field::{Field, BinaryField1b, BinaryField8b};
51/// use binius_math::ArithExpr as Expr;
52///
53/// assert_eq!(
54///     arith_expr!([x, y] = x + y + 1),
55///     Expr::Var(0) + Expr::Var(1) + Expr::Const(BinaryField1b::ONE)
56/// );
57///
58/// assert_eq!(
59///     arith_expr!(BinaryField8b[x] = 3*x + 15),
60///     Expr::Const(BinaryField8b::new(3)) * Expr::Var(0) + Expr::Const(BinaryField8b::new(15))
61/// );
62/// ```
63#[proc_macro]
64pub fn arith_expr(input: TokenStream) -> TokenStream {
65	parse_macro_input!(input as ArithExprItem)
66		.into_token_stream()
67		.into()
68}
69
70#[proc_macro]
71pub fn arith_circuit_poly(input: TokenStream) -> TokenStream {
72	parse_macro_input!(input as ArithCircuitPolyItem)
73		.into_token_stream()
74		.into()
75}
76
77/// Derives the trait binius_utils::DeserializeBytes for a struct or enum
78///
79/// See the DeserializeBytes derive macro docs for examples/tests
80#[proc_macro_derive(SerializeBytes)]
81pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream {
82	let input: DeriveInput = parse_macro_input!(input);
83	let span = input.span();
84	let name = input.ident;
85	let mut generics = input.generics.clone();
86	generics.type_params_mut().for_each(|type_param| {
87		type_param
88			.bounds
89			.push(parse_quote!(binius_utils::SerializeBytes))
90	});
91	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
92	let body = match input.data {
93		Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
94		Data::Struct(data) => {
95			let fields = field_names(data.fields, None);
96			quote! {
97				#(binius_utils::SerializeBytes::serialize(&self.#fields, &mut write_buf, mode)?;)*
98			}
99		}
100		Data::Enum(data) => {
101			let variants = data
102				.variants
103				.into_iter()
104				.enumerate()
105				.map(|(i, variant)| {
106					let variant_ident = &variant.ident;
107					let variant_index = i as u8;
108					let fields = field_names(variant.fields.clone(), Some("field_"));
109					let serialize_variant = quote! {
110						binius_utils::SerializeBytes::serialize(&#variant_index, &mut write_buf, mode)?;
111						#(binius_utils::SerializeBytes::serialize(#fields, &mut write_buf, mode)?;)*
112					};
113					match variant.fields {
114						Fields::Named(_) => quote! {
115							Self::#variant_ident { #(#fields),* } => {
116								#serialize_variant
117							}
118						},
119						Fields::Unnamed(_) => quote! {
120							Self::#variant_ident(#(#fields),*) => {
121								#serialize_variant
122							}
123						},
124						Fields::Unit => quote! {
125							Self::#variant_ident => {
126								#serialize_variant
127							}
128						},
129					}
130				})
131				.collect::<Vec<_>>();
132
133			quote! {
134				match self {
135					#(#variants)*
136				}
137			}
138		}
139	};
140	quote! {
141		impl #impl_generics binius_utils::SerializeBytes for #name #ty_generics #where_clause {
142			fn serialize(&self, mut write_buf: impl binius_utils::bytes::BufMut, mode: binius_utils::SerializationMode) -> Result<(), binius_utils::SerializationError> {
143				#body
144				Ok(())
145			}
146		}
147	}.into()
148}
149
150/// Derives the trait binius_utils::DeserializeBytes for a struct or enum
151///
152/// ```
153/// use binius_field::BinaryField128b;
154/// use binius_utils::{SerializeBytes, DeserializeBytes, SerializationMode};
155/// use binius_macros::{SerializeBytes, DeserializeBytes};
156///
157/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)]
158/// enum MyEnum {
159///     A(usize),
160///     B { x: u32, y: u32 },
161///     C
162/// }
163///
164///
165/// let mut buf = vec![];
166/// let value = MyEnum::B { x: 42, y: 1337 };
167/// MyEnum::serialize(&value, &mut buf, SerializationMode::Native).unwrap();
168/// assert_eq!(
169///     MyEnum::deserialize(buf.as_slice(), SerializationMode::Native).unwrap(),
170///     value
171/// );
172///
173///
174/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)]
175/// struct MyStruct<F> {
176///     data: Vec<F>
177/// }
178///
179/// let mut buf = vec![];
180/// let value = MyStruct {
181///    data: vec![BinaryField128b::new(1234), BinaryField128b::new(5678)]
182/// };
183/// MyStruct::serialize(&value, &mut buf, SerializationMode::CanonicalTower).unwrap();
184/// assert_eq!(
185///     MyStruct::<BinaryField128b>::deserialize(buf.as_slice(), SerializationMode::CanonicalTower).unwrap(),
186///     value
187/// );
188/// ```
189#[proc_macro_derive(DeserializeBytes)]
190pub fn derive_deserialize_bytes(input: TokenStream) -> TokenStream {
191	let input: DeriveInput = parse_macro_input!(input);
192	let span = input.span();
193	let name = input.ident;
194	let mut generics = input.generics.clone();
195	generics.type_params_mut().for_each(|type_param| {
196		type_param
197			.bounds
198			.push(parse_quote!(binius_utils::DeserializeBytes))
199	});
200	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
201	let deserialize_value = quote! {
202		binius_utils::DeserializeBytes::deserialize(&mut read_buf, mode)?
203	};
204	let body = match input.data {
205		Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
206		Data::Struct(data) => {
207			let fields = field_names(data.fields, None);
208			quote! {
209				Ok(Self {
210					#(#fields: #deserialize_value,)*
211				})
212			}
213		}
214		Data::Enum(data) => {
215			let variants = data
216				.variants
217				.into_iter()
218				.enumerate()
219				.map(|(i, variant)| {
220					let variant_ident = &variant.ident;
221					let variant_index: u8 = i as u8;
222					match variant.fields {
223						Fields::Named(fields) => {
224							let fields = fields
225								.named
226								.into_iter()
227								.map(|field| field.ident)
228								.map(|field_name| quote!(#field_name: #deserialize_value))
229								.collect::<Vec<_>>();
230
231							quote! {
232								#variant_index => Self::#variant_ident { #(#fields,)* }
233							}
234						}
235						Fields::Unnamed(fields) => {
236							let fields = fields
237								.unnamed
238								.into_iter()
239								.map(|_| quote!(#deserialize_value))
240								.collect::<Vec<_>>();
241
242							quote! {
243								#variant_index => Self::#variant_ident(#(#fields,)*)
244							}
245						}
246						Fields::Unit => quote! {
247							#variant_index => Self::#variant_ident
248						},
249					}
250				})
251				.collect::<Vec<_>>();
252
253			let name = name.to_string();
254			quote! {
255				let variant_index: u8 = #deserialize_value;
256				Ok(match variant_index {
257					#(#variants,)*
258					_ => {
259						return Err(binius_utils::SerializationError::UnknownEnumVariant {
260							name: #name,
261							index: variant_index
262						})
263					}
264				})
265			}
266		}
267	};
268	quote! {
269		impl #impl_generics binius_utils::DeserializeBytes for #name #ty_generics #where_clause {
270			fn deserialize(mut read_buf: impl binius_utils::bytes::Buf, mode: binius_utils::SerializationMode) -> Result<Self, binius_utils::SerializationError>
271			where
272				Self: Sized
273			{
274				#body
275			}
276		}
277	}
278	.into()
279}
280
281/// Use on an impl block for MultivariatePoly, to automatically implement erased_serialize_bytes.
282///
283/// Importantly, this will serialize the concrete instance, prefixed by the identifier of the data type.
284///
285/// This prefix can be used to figure out which concrete data type it should use for deserialization later.
286#[proc_macro_attribute]
287pub fn erased_serialize_bytes(_attr: TokenStream, item: TokenStream) -> TokenStream {
288	let mut item_impl: ItemImpl = parse_macro_input!(item);
289	let syn::Type::Path(p) = &*item_impl.self_ty else {
290		return syn::Error::new(
291			item_impl.span(),
292			"#[erased_serialize_bytes] can only be used on an impl for a concrete type",
293		)
294		.into_compile_error()
295		.into();
296	};
297	let name = p.path.segments.last().unwrap().ident.to_string();
298	item_impl.items.push(syn::ImplItem::Fn(parse_quote! {
299		fn erased_serialize(
300			&self,
301			write_buf: &mut dyn binius_utils::bytes::BufMut,
302			mode: binius_utils::SerializationMode,
303		) -> Result<(), binius_utils::SerializationError> {
304			binius_utils::SerializeBytes::serialize(&#name, &mut *write_buf, mode)?;
305			binius_utils::SerializeBytes::serialize(self, &mut *write_buf, mode)
306		}
307	}));
308	quote! {
309		#item_impl
310	}
311	.into()
312}
313
314fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec<proc_macro2::TokenStream> {
315	match fields {
316		Fields::Named(fields) => fields
317			.named
318			.into_iter()
319			.map(|field| field.ident.into_token_stream())
320			.collect(),
321		Fields::Unnamed(fields) => fields
322			.unnamed
323			.into_iter()
324			.enumerate()
325			.map(|(i, _)| match positional_prefix {
326				Some(prefix) => {
327					quote::format_ident!("{}{}", prefix, syn::Index::from(i)).into_token_stream()
328				}
329				None => syn::Index::from(i).into_token_stream(),
330			})
331			.collect(),
332		Fields::Unit => vec![],
333	}
334}