binius_macros/
lib.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3extern crate proc_macro;
4mod deserialize_bytes;
5
6use deserialize_bytes::{GenericsSplit, parse_container_attributes, split_for_impl};
7use proc_macro::TokenStream;
8use quote::{ToTokens, quote};
9use syn::{Data, DeriveInput, Fields, ItemImpl, parse_macro_input, parse_quote, spanned::Spanned};
10
11/// Derives the trait binius_utils::DeserializeBytes for a struct or enum
12///
13/// See the DeserializeBytes derive macro docs for examples/tests
14#[proc_macro_derive(SerializeBytes)]
15pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream {
16	let input: DeriveInput = parse_macro_input!(input);
17	let span = input.span();
18	let name = input.ident;
19	let mut generics = input.generics.clone();
20	generics.type_params_mut().for_each(|type_param| {
21		type_param
22			.bounds
23			.push(parse_quote!(binius_utils::SerializeBytes))
24	});
25	let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
26	let body = match input.data {
27		Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
28		Data::Struct(data) => {
29			let fields = field_names(data.fields, None);
30			quote! {
31				#(binius_utils::SerializeBytes::serialize(&self.#fields, &mut write_buf, mode)?;)*
32			}
33		}
34		Data::Enum(data) => {
35			let variants = data
36				.variants
37				.into_iter()
38				.enumerate()
39				.map(|(i, variant)| {
40					let variant_ident = &variant.ident;
41					let variant_index = i as u8;
42					let fields = field_names(variant.fields.clone(), Some("field_"));
43					let serialize_variant = quote! {
44						binius_utils::SerializeBytes::serialize(&#variant_index, &mut write_buf, mode)?;
45						#(binius_utils::SerializeBytes::serialize(#fields, &mut write_buf, mode)?;)*
46					};
47					match variant.fields {
48						Fields::Named(_) => quote! {
49							Self::#variant_ident { #(#fields),* } => {
50								#serialize_variant
51							}
52						},
53						Fields::Unnamed(_) => quote! {
54							Self::#variant_ident(#(#fields),*) => {
55								#serialize_variant
56							}
57						},
58						Fields::Unit => quote! {
59							Self::#variant_ident => {
60								#serialize_variant
61							}
62						},
63					}
64				})
65				.collect::<Vec<_>>();
66
67			quote! {
68				match self {
69					#(#variants)*
70				}
71			}
72		}
73	};
74	quote! {
75		impl #impl_generics binius_utils::SerializeBytes for #name #ty_generics #where_clause {
76			fn serialize(&self, mut write_buf: impl binius_utils::bytes::BufMut, mode: binius_utils::SerializationMode) -> Result<(), binius_utils::SerializationError> {
77				#body
78				Ok(())
79			}
80		}
81	}.into()
82}
83
84/// Derives the trait binius_utils::DeserializeBytes for a struct or enum
85///
86/// ```
87/// use binius_field::BinaryField128b;
88/// use binius_utils::{SerializeBytes, DeserializeBytes, SerializationMode};
89/// use binius_macros::{SerializeBytes, DeserializeBytes};
90///
91/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)]
92/// enum MyEnum {
93///     A(usize),
94///     B { x: u32, y: u32 },
95///     C
96/// }
97///
98///
99/// let mut buf = vec![];
100/// let value = MyEnum::B { x: 42, y: 1337 };
101/// MyEnum::serialize(&value, &mut buf, SerializationMode::Native).unwrap();
102/// assert_eq!(
103///     MyEnum::deserialize(buf.as_slice(), SerializationMode::Native).unwrap(),
104///     value
105/// );
106///
107///
108/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)]
109/// struct MyStruct<F> {
110///     data: Vec<F>
111/// }
112///
113/// let mut buf = vec![];
114/// let value = MyStruct {
115///    data: vec![BinaryField128b::new(1234), BinaryField128b::new(5678)]
116/// };
117/// MyStruct::serialize(&value, &mut buf, SerializationMode::CanonicalTower).unwrap();
118/// assert_eq!(
119///     MyStruct::<BinaryField128b>::deserialize(buf.as_slice(), SerializationMode::CanonicalTower).unwrap(),
120///     value
121/// );
122/// ```
123///
124/// ## Eval generics
125///
126/// Sometimes it is convenient to limit the implementation of `DeserializeBytes` only to
127/// specific types. For example:
128///
129/// ```ignore
130/// impl DeserializeBytes for MultilinearOracleSet<BinaryField128b> {...}
131/// ```
132///
133/// To do that use `eval_generics` attribute:
134///
135/// ```
136/// use binius_field::BinaryField128b;
137/// use binius_utils::{SerializeBytes, DeserializeBytes, SerializationMode};
138/// use binius_macros::{SerializeBytes, DeserializeBytes};
139///
140///
141/// #[derive(Debug, PartialEq, SerializeBytes, DeserializeBytes)]
142/// #[deserialize_bytes(eval_generics(F = BinaryField128b))]
143/// struct MyStruct<F> {
144///     data: Vec<F>
145/// }
146///
147/// let mut buf = vec![];
148/// let value = MyStruct {
149///    data: vec![BinaryField128b::new(1234), BinaryField128b::new(5678)]
150/// };
151/// MyStruct::serialize(&value, &mut buf, SerializationMode::CanonicalTower).unwrap();
152/// assert_eq!(
153///     MyStruct::<BinaryField128b>::deserialize(buf.as_slice(), SerializationMode::CanonicalTower).unwrap(),
154///     value
155/// );
156/// ```
157///
158/// Additionally, `eval_generics` can be used to fix multiple params:
159/// `eval_generics(F = BinaryField128b, G = binius_field::BinaryField64b)`
160#[proc_macro_derive(DeserializeBytes, attributes(deserialize_bytes))]
161pub fn derive_deserialize_bytes(input: TokenStream) -> TokenStream {
162	let input: DeriveInput = parse_macro_input!(input);
163	let span = input.span();
164	let container_attributes = match parse_container_attributes(&input) {
165		Ok(x) => x,
166		Err(e) => return e.into_compile_error().into(),
167	};
168	let name = input.ident;
169	let mut generics = input.generics.clone();
170	generics.type_params_mut().for_each(|type_param| {
171		type_param
172			.bounds
173			.push(parse_quote!(binius_utils::DeserializeBytes))
174	});
175	let GenericsSplit {
176		impl_generics,
177		type_generics: ty_generics,
178		where_clause,
179	} = split_for_impl(&generics, &container_attributes);
180	let deserialize_value = quote! {
181		binius_utils::DeserializeBytes::deserialize(&mut read_buf, mode)?
182	};
183	let body = match input.data {
184		Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
185		Data::Struct(data) => {
186			let fields = field_names(data.fields, None);
187			quote! {
188				Ok(Self {
189					#(#fields: #deserialize_value,)*
190				})
191			}
192		}
193		Data::Enum(data) => {
194			let variants = data
195				.variants
196				.into_iter()
197				.enumerate()
198				.map(|(i, variant)| {
199					let variant_ident = &variant.ident;
200					let variant_index: u8 = i as u8;
201					match variant.fields {
202						Fields::Named(fields) => {
203							let fields = fields
204								.named
205								.into_iter()
206								.map(|field| field.ident)
207								.map(|field_name| quote!(#field_name: #deserialize_value))
208								.collect::<Vec<_>>();
209
210							quote! {
211								#variant_index => Self::#variant_ident { #(#fields,)* }
212							}
213						}
214						Fields::Unnamed(fields) => {
215							let fields = fields
216								.unnamed
217								.into_iter()
218								.map(|_| quote!(#deserialize_value))
219								.collect::<Vec<_>>();
220
221							quote! {
222								#variant_index => Self::#variant_ident(#(#fields,)*)
223							}
224						}
225						Fields::Unit => quote! {
226							#variant_index => Self::#variant_ident
227						},
228					}
229				})
230				.collect::<Vec<_>>();
231
232			let name = name.to_string();
233			quote! {
234				let variant_index: u8 = #deserialize_value;
235				Ok(match variant_index {
236					#(#variants,)*
237					_ => {
238						return Err(binius_utils::SerializationError::UnknownEnumVariant {
239							name: #name,
240							index: variant_index
241						})
242					}
243				})
244			}
245		}
246	};
247
248	quote! {
249		impl #impl_generics binius_utils::DeserializeBytes for #name #ty_generics #where_clause {
250			fn deserialize(mut read_buf: impl binius_utils::bytes::Buf, mode: binius_utils::SerializationMode) -> Result<Self, binius_utils::SerializationError>
251			where
252				Self: Sized
253			{
254				#body
255			}
256		}
257	}
258	.into()
259}
260
261/// Use on an impl block for MultivariatePoly, to automatically implement erased_serialize_bytes.
262///
263/// Importantly, this will serialize the concrete instance, prefixed by the identifier of the data
264/// type.
265///
266/// This prefix can be used to figure out which concrete data type it should use for deserialization
267/// later.
268#[proc_macro_attribute]
269pub fn erased_serialize_bytes(_attr: TokenStream, item: TokenStream) -> TokenStream {
270	let mut item_impl: ItemImpl = parse_macro_input!(item);
271	let syn::Type::Path(p) = &*item_impl.self_ty else {
272		return syn::Error::new(
273			item_impl.span(),
274			"#[erased_serialize_bytes] can only be used on an impl for a concrete type",
275		)
276		.into_compile_error()
277		.into();
278	};
279	let name = p.path.segments.last().unwrap().ident.to_string();
280	item_impl.items.push(syn::ImplItem::Fn(parse_quote! {
281		fn erased_serialize(
282			&self,
283			write_buf: &mut dyn binius_utils::bytes::BufMut,
284			mode: binius_utils::SerializationMode,
285		) -> Result<(), binius_utils::SerializationError> {
286			binius_utils::SerializeBytes::serialize(&#name, &mut *write_buf, mode)?;
287			binius_utils::SerializeBytes::serialize(self, &mut *write_buf, mode)
288		}
289	}));
290	quote! {
291		#item_impl
292	}
293	.into()
294}
295
296fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec<proc_macro2::TokenStream> {
297	match fields {
298		Fields::Named(fields) => fields
299			.named
300			.into_iter()
301			.map(|field| field.ident.into_token_stream())
302			.collect(),
303		Fields::Unnamed(fields) => fields
304			.unnamed
305			.into_iter()
306			.enumerate()
307			.map(|(i, _)| match positional_prefix {
308				Some(prefix) => {
309					quote::format_ident!("{}{}", prefix, syn::Index::from(i)).into_token_stream()
310				}
311				None => syn::Index::from(i).into_token_stream(),
312			})
313			.collect(),
314		Fields::Unit => vec![],
315	}
316}