1extern crate proc_macro;
4mod arith_circuit_poly;
5mod arith_expr;
6mod composition_poly;
7
8use proc_macro::TokenStream;
9use quote::{quote, ToTokens};
10use syn::{parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, ItemImpl};
11
12use crate::{
13 arith_circuit_poly::ArithCircuitPolyItem, arith_expr::ArithExprItem,
14 composition_poly::CompositionPolyItem,
15};
16
17#[proc_macro]
40pub fn composition_poly(input: TokenStream) -> TokenStream {
41 parse_macro_input!(input as CompositionPolyItem)
42 .into_token_stream()
43 .into()
44}
45
46#[proc_macro]
64pub fn arith_expr(input: TokenStream) -> TokenStream {
65 parse_macro_input!(input as ArithExprItem)
66 .into_token_stream()
67 .into()
68}
69
70#[proc_macro]
71pub fn arith_circuit_poly(input: TokenStream) -> TokenStream {
72 parse_macro_input!(input as ArithCircuitPolyItem)
73 .into_token_stream()
74 .into()
75}
76
77#[proc_macro_derive(SerializeBytes)]
81pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream {
82 let input: DeriveInput = parse_macro_input!(input);
83 let span = input.span();
84 let name = input.ident;
85 let mut generics = input.generics.clone();
86 generics.type_params_mut().for_each(|type_param| {
87 type_param
88 .bounds
89 .push(parse_quote!(binius_utils::SerializeBytes))
90 });
91 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
92 let body = match input.data {
93 Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
94 Data::Struct(data) => {
95 let fields = field_names(data.fields, None);
96 quote! {
97 #(binius_utils::SerializeBytes::serialize(&self.#fields, &mut write_buf, mode)?;)*
98 }
99 }
100 Data::Enum(data) => {
101 let variants = data
102 .variants
103 .into_iter()
104 .enumerate()
105 .map(|(i, variant)| {
106 let variant_ident = &variant.ident;
107 let variant_index = i as u8;
108 let fields = field_names(variant.fields.clone(), Some("field_"));
109 let serialize_variant = quote! {
110 binius_utils::SerializeBytes::serialize(&#variant_index, &mut write_buf, mode)?;
111 #(binius_utils::SerializeBytes::serialize(#fields, &mut write_buf, mode)?;)*
112 };
113 match variant.fields {
114 Fields::Named(_) => quote! {
115 Self::#variant_ident { #(#fields),* } => {
116 #serialize_variant
117 }
118 },
119 Fields::Unnamed(_) => quote! {
120 Self::#variant_ident(#(#fields),*) => {
121 #serialize_variant
122 }
123 },
124 Fields::Unit => quote! {
125 Self::#variant_ident => {
126 #serialize_variant
127 }
128 },
129 }
130 })
131 .collect::<Vec<_>>();
132
133 quote! {
134 match self {
135 #(#variants)*
136 }
137 }
138 }
139 };
140 quote! {
141 impl #impl_generics binius_utils::SerializeBytes for #name #ty_generics #where_clause {
142 fn serialize(&self, mut write_buf: impl binius_utils::bytes::BufMut, mode: binius_utils::SerializationMode) -> Result<(), binius_utils::SerializationError> {
143 #body
144 Ok(())
145 }
146 }
147 }.into()
148}
149
150#[proc_macro_derive(DeserializeBytes)]
190pub fn derive_deserialize_bytes(input: TokenStream) -> TokenStream {
191 let input: DeriveInput = parse_macro_input!(input);
192 let span = input.span();
193 let name = input.ident;
194 let mut generics = input.generics.clone();
195 generics.type_params_mut().for_each(|type_param| {
196 type_param
197 .bounds
198 .push(parse_quote!(binius_utils::DeserializeBytes))
199 });
200 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
201 let deserialize_value = quote! {
202 binius_utils::DeserializeBytes::deserialize(&mut read_buf, mode)?
203 };
204 let body = match input.data {
205 Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
206 Data::Struct(data) => {
207 let fields = field_names(data.fields, None);
208 quote! {
209 Ok(Self {
210 #(#fields: #deserialize_value,)*
211 })
212 }
213 }
214 Data::Enum(data) => {
215 let variants = data
216 .variants
217 .into_iter()
218 .enumerate()
219 .map(|(i, variant)| {
220 let variant_ident = &variant.ident;
221 let variant_index: u8 = i as u8;
222 match variant.fields {
223 Fields::Named(fields) => {
224 let fields = fields
225 .named
226 .into_iter()
227 .map(|field| field.ident)
228 .map(|field_name| quote!(#field_name: #deserialize_value))
229 .collect::<Vec<_>>();
230
231 quote! {
232 #variant_index => Self::#variant_ident { #(#fields,)* }
233 }
234 }
235 Fields::Unnamed(fields) => {
236 let fields = fields
237 .unnamed
238 .into_iter()
239 .map(|_| quote!(#deserialize_value))
240 .collect::<Vec<_>>();
241
242 quote! {
243 #variant_index => Self::#variant_ident(#(#fields,)*)
244 }
245 }
246 Fields::Unit => quote! {
247 #variant_index => Self::#variant_ident
248 },
249 }
250 })
251 .collect::<Vec<_>>();
252
253 let name = name.to_string();
254 quote! {
255 let variant_index: u8 = #deserialize_value;
256 Ok(match variant_index {
257 #(#variants,)*
258 _ => {
259 return Err(binius_utils::SerializationError::UnknownEnumVariant {
260 name: #name,
261 index: variant_index
262 })
263 }
264 })
265 }
266 }
267 };
268 quote! {
269 impl #impl_generics binius_utils::DeserializeBytes for #name #ty_generics #where_clause {
270 fn deserialize(mut read_buf: impl binius_utils::bytes::Buf, mode: binius_utils::SerializationMode) -> Result<Self, binius_utils::SerializationError>
271 where
272 Self: Sized
273 {
274 #body
275 }
276 }
277 }
278 .into()
279}
280
281#[proc_macro_attribute]
287pub fn erased_serialize_bytes(_attr: TokenStream, item: TokenStream) -> TokenStream {
288 let mut item_impl: ItemImpl = parse_macro_input!(item);
289 let syn::Type::Path(p) = &*item_impl.self_ty else {
290 return syn::Error::new(
291 item_impl.span(),
292 "#[erased_serialize_bytes] can only be used on an impl for a concrete type",
293 )
294 .into_compile_error()
295 .into();
296 };
297 let name = p.path.segments.last().unwrap().ident.to_string();
298 item_impl.items.push(syn::ImplItem::Fn(parse_quote! {
299 fn erased_serialize(
300 &self,
301 write_buf: &mut dyn binius_utils::bytes::BufMut,
302 mode: binius_utils::SerializationMode,
303 ) -> Result<(), binius_utils::SerializationError> {
304 binius_utils::SerializeBytes::serialize(&#name, &mut *write_buf, mode)?;
305 binius_utils::SerializeBytes::serialize(self, &mut *write_buf, mode)
306 }
307 }));
308 quote! {
309 #item_impl
310 }
311 .into()
312}
313
314fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec<proc_macro2::TokenStream> {
315 match fields {
316 Fields::Named(fields) => fields
317 .named
318 .into_iter()
319 .map(|field| field.ident.into_token_stream())
320 .collect(),
321 Fields::Unnamed(fields) => fields
322 .unnamed
323 .into_iter()
324 .enumerate()
325 .map(|(i, _)| match positional_prefix {
326 Some(prefix) => {
327 quote::format_ident!("{}{}", prefix, syn::Index::from(i)).into_token_stream()
328 }
329 None => syn::Index::from(i).into_token_stream(),
330 })
331 .collect(),
332 Fields::Unit => vec![],
333 }
334}