1extern crate proc_macro;
4mod deserialize_bytes;
5
6use deserialize_bytes::{GenericsSplit, parse_container_attributes, split_for_impl};
7use proc_macro::TokenStream;
8use quote::{ToTokens, quote};
9use syn::{Data, DeriveInput, Fields, ItemImpl, parse_macro_input, parse_quote, spanned::Spanned};
10
11#[proc_macro_derive(SerializeBytes)]
15pub fn derive_serialize_bytes(input: TokenStream) -> TokenStream {
16 let input: DeriveInput = parse_macro_input!(input);
17 let span = input.span();
18 let name = input.ident;
19 let mut generics = input.generics.clone();
20 generics.type_params_mut().for_each(|type_param| {
21 type_param
22 .bounds
23 .push(parse_quote!(binius_utils::SerializeBytes))
24 });
25 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
26 let body = match input.data {
27 Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
28 Data::Struct(data) => {
29 let fields = field_names(data.fields, None);
30 quote! {
31 #(binius_utils::SerializeBytes::serialize(&self.#fields, &mut write_buf, mode)?;)*
32 }
33 }
34 Data::Enum(data) => {
35 let variants = data
36 .variants
37 .into_iter()
38 .enumerate()
39 .map(|(i, variant)| {
40 let variant_ident = &variant.ident;
41 let variant_index = i as u8;
42 let fields = field_names(variant.fields.clone(), Some("field_"));
43 let serialize_variant = quote! {
44 binius_utils::SerializeBytes::serialize(&#variant_index, &mut write_buf, mode)?;
45 #(binius_utils::SerializeBytes::serialize(#fields, &mut write_buf, mode)?;)*
46 };
47 match variant.fields {
48 Fields::Named(_) => quote! {
49 Self::#variant_ident { #(#fields),* } => {
50 #serialize_variant
51 }
52 },
53 Fields::Unnamed(_) => quote! {
54 Self::#variant_ident(#(#fields),*) => {
55 #serialize_variant
56 }
57 },
58 Fields::Unit => quote! {
59 Self::#variant_ident => {
60 #serialize_variant
61 }
62 },
63 }
64 })
65 .collect::<Vec<_>>();
66
67 quote! {
68 match self {
69 #(#variants)*
70 }
71 }
72 }
73 };
74 quote! {
75 impl #impl_generics binius_utils::SerializeBytes for #name #ty_generics #where_clause {
76 fn serialize(&self, mut write_buf: impl binius_utils::bytes::BufMut, mode: binius_utils::SerializationMode) -> Result<(), binius_utils::SerializationError> {
77 #body
78 Ok(())
79 }
80 }
81 }.into()
82}
83
84#[proc_macro_derive(DeserializeBytes, attributes(deserialize_bytes))]
161pub fn derive_deserialize_bytes(input: TokenStream) -> TokenStream {
162 let input: DeriveInput = parse_macro_input!(input);
163 let span = input.span();
164 let container_attributes = match parse_container_attributes(&input) {
165 Ok(x) => x,
166 Err(e) => return e.into_compile_error().into(),
167 };
168 let name = input.ident;
169 let mut generics = input.generics.clone();
170 generics.type_params_mut().for_each(|type_param| {
171 type_param
172 .bounds
173 .push(parse_quote!(binius_utils::DeserializeBytes))
174 });
175 let GenericsSplit {
176 impl_generics,
177 type_generics: ty_generics,
178 where_clause,
179 } = split_for_impl(&generics, &container_attributes);
180 let deserialize_value = quote! {
181 binius_utils::DeserializeBytes::deserialize(&mut read_buf, mode)?
182 };
183 let body = match input.data {
184 Data::Union(_) => syn::Error::new(span, "Unions are not supported").into_compile_error(),
185 Data::Struct(data) => {
186 let fields = field_names(data.fields, None);
187 quote! {
188 Ok(Self {
189 #(#fields: #deserialize_value,)*
190 })
191 }
192 }
193 Data::Enum(data) => {
194 let variants = data
195 .variants
196 .into_iter()
197 .enumerate()
198 .map(|(i, variant)| {
199 let variant_ident = &variant.ident;
200 let variant_index: u8 = i as u8;
201 match variant.fields {
202 Fields::Named(fields) => {
203 let fields = fields
204 .named
205 .into_iter()
206 .map(|field| field.ident)
207 .map(|field_name| quote!(#field_name: #deserialize_value))
208 .collect::<Vec<_>>();
209
210 quote! {
211 #variant_index => Self::#variant_ident { #(#fields,)* }
212 }
213 }
214 Fields::Unnamed(fields) => {
215 let fields = fields
216 .unnamed
217 .into_iter()
218 .map(|_| quote!(#deserialize_value))
219 .collect::<Vec<_>>();
220
221 quote! {
222 #variant_index => Self::#variant_ident(#(#fields,)*)
223 }
224 }
225 Fields::Unit => quote! {
226 #variant_index => Self::#variant_ident
227 },
228 }
229 })
230 .collect::<Vec<_>>();
231
232 let name = name.to_string();
233 quote! {
234 let variant_index: u8 = #deserialize_value;
235 Ok(match variant_index {
236 #(#variants,)*
237 _ => {
238 return Err(binius_utils::SerializationError::UnknownEnumVariant {
239 name: #name,
240 index: variant_index
241 })
242 }
243 })
244 }
245 }
246 };
247
248 quote! {
249 impl #impl_generics binius_utils::DeserializeBytes for #name #ty_generics #where_clause {
250 fn deserialize(mut read_buf: impl binius_utils::bytes::Buf, mode: binius_utils::SerializationMode) -> Result<Self, binius_utils::SerializationError>
251 where
252 Self: Sized
253 {
254 #body
255 }
256 }
257 }
258 .into()
259}
260
261#[proc_macro_attribute]
269pub fn erased_serialize_bytes(_attr: TokenStream, item: TokenStream) -> TokenStream {
270 let mut item_impl: ItemImpl = parse_macro_input!(item);
271 let syn::Type::Path(p) = &*item_impl.self_ty else {
272 return syn::Error::new(
273 item_impl.span(),
274 "#[erased_serialize_bytes] can only be used on an impl for a concrete type",
275 )
276 .into_compile_error()
277 .into();
278 };
279 let name = p.path.segments.last().unwrap().ident.to_string();
280 item_impl.items.push(syn::ImplItem::Fn(parse_quote! {
281 fn erased_serialize(
282 &self,
283 write_buf: &mut dyn binius_utils::bytes::BufMut,
284 mode: binius_utils::SerializationMode,
285 ) -> Result<(), binius_utils::SerializationError> {
286 binius_utils::SerializeBytes::serialize(&#name, &mut *write_buf, mode)?;
287 binius_utils::SerializeBytes::serialize(self, &mut *write_buf, mode)
288 }
289 }));
290 quote! {
291 #item_impl
292 }
293 .into()
294}
295
296fn field_names(fields: Fields, positional_prefix: Option<&str>) -> Vec<proc_macro2::TokenStream> {
297 match fields {
298 Fields::Named(fields) => fields
299 .named
300 .into_iter()
301 .map(|field| field.ident.into_token_stream())
302 .collect(),
303 Fields::Unnamed(fields) => fields
304 .unnamed
305 .into_iter()
306 .enumerate()
307 .map(|(i, _)| match positional_prefix {
308 Some(prefix) => {
309 quote::format_ident!("{}{}", prefix, syn::Index::from(i)).into_token_stream()
310 }
311 None => syn::Index::from(i).into_token_stream(),
312 })
313 .collect(),
314 Fields::Unit => vec![],
315 }
316}