1use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6 arch::OptimalUnderlier, as_packed_field::PackedType, packed::pack_slice, BinaryField128b,
7 BinaryField16b, BinaryField1b, BinaryField2b, BinaryField32b, BinaryField4b, BinaryField64b,
8 BinaryField8b, ExtensionField, PackedField, RepackedExtension, TowerField,
9};
10use binius_hal::{make_portable_backend, ComputationBackendExt};
11use binius_macros::erased_serialize_bytes;
12use binius_math::{MLEEmbeddingAdapter, MultilinearExtension, MultilinearPoly};
13use binius_utils::{DeserializeBytes, SerializationError, SerializationMode, SerializeBytes};
14
15use crate::polynomial::{Error, MultivariatePoly};
16
/// A polynomial defined by a multilinear extension over a container of packed field values.
///
/// Type parameters:
/// * `P` - packed field type of the stored values.
/// * `PE` - packed field type whose scalar is an extension of `P`'s scalar.
/// * `Data` - container that dereferences to a slice of `P`; defaults to `Vec<P>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultilinearExtensionTransparent<P, PE, Data = Vec<P>>
where
    P: PackedField,
    PE: PackedField,
    PE::Scalar: ExtensionField<P::Scalar>,
    Data: Deref<Target = [P]>,
{
    // Adapter over the multilinear extension; the constructors in this file obtain it
    // through `MultilinearExtension::specialize`.
    data: MLEEmbeddingAdapter<P, PE, Data>,
}
34
35impl<P, PE, Data> SerializeBytes for MultilinearExtensionTransparent<P, PE, Data>
36where
37 P: PackedField,
38 PE: RepackedExtension<P>,
39 PE::Scalar: TowerField + ExtensionField<P::Scalar>,
40 Data: Deref<Target = [P]> + Debug + Send + Sync,
41{
42 fn serialize(
43 &self,
44 write_buf: impl bytes::BufMut,
45 mode: SerializationMode,
46 ) -> Result<(), SerializationError> {
47 let elems = PE::iter_slice(
48 self.data
49 .packed_evals()
50 .expect("Evals should always be available here"),
51 )
52 .collect::<Vec<_>>();
53 SerializeBytes::serialize(&elems, write_buf, mode)
54 }
55}
56
// Registers a deserializer for `dyn MultivariatePoly<BinaryField128b>` under the name
// "MultilinearExtensionTransparent".
//
// The closure reads a `Vec<BinaryField128b>` from the buffer and then tries to pack those
// values into progressively larger subfields (1b, 2b, 4b, 8b, 16b, 32b, 64b). The first
// subfield for which `try_pack_slice` returns `Some` is used; if none does, the values are
// packed as 128-bit elements.
//
// NOTE(review): each branch calls `.unwrap()` on `from_values`, so an `Err` from that
// constructor panics instead of being returned as a `SerializationError`. Since the input
// is deserialized data, confirm whether this can be triggered by malformed input and whether
// it should be mapped to an error instead.
inventory::submit! {
    <dyn MultivariatePoly<BinaryField128b>>::register_deserializer(
        "MultilinearExtensionTransparent",
        |buf, mode| {
            type U = OptimalUnderlier;
            type F = BinaryField128b;
            type P = PackedType<U, F>;
            let hypercube_evals = Vec::<F>::deserialize(&mut *buf, mode)?;
            // The target subfield of every `try_pack_slice` call below is inferred from the
            // `from_values` turbofish inside the same branch, so the branches differ even
            // though the conditions look identical. Order matters: smallest subfield first.
            let result: Box<dyn MultivariatePoly<F>> = if let Some(packed_evals) = try_pack_slice(&hypercube_evals) {
                Box::new(MultilinearExtensionTransparent::<PackedType<U, BinaryField1b>, P, _>::from_values(packed_evals).unwrap())
            } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) {
                Box::new(MultilinearExtensionTransparent::<PackedType<U, BinaryField2b>, P, _>::from_values(packed_evals).unwrap())
            } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) {
                Box::new(MultilinearExtensionTransparent::<PackedType<U, BinaryField4b>, P, _>::from_values(packed_evals).unwrap())
            } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) {
                Box::new(MultilinearExtensionTransparent::<PackedType<U, BinaryField8b>, P, _>::from_values(packed_evals).unwrap())
            } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) {
                Box::new(MultilinearExtensionTransparent::<PackedType<U, BinaryField16b>, P, _>::from_values(packed_evals).unwrap())
            } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) {
                Box::new(MultilinearExtensionTransparent::<PackedType<U, BinaryField32b>, P, _>::from_values(packed_evals).unwrap())
            } else if let Some(packed_evals) = try_pack_slice(&hypercube_evals) {
                Box::new(MultilinearExtensionTransparent::<PackedType<U, BinaryField64b>, P, _>::from_values(packed_evals).unwrap())
            } else {
                // No proper subfield holds every value: keep them as full 128-bit elements.
                Box::new(MultilinearExtensionTransparent::<P, P, _>::from_values(pack_slice(&hypercube_evals)).unwrap())
            };
            Ok(result)
        }
    )
}
86
87fn try_pack_slice<PS, F>(xs: &[F]) -> Option<Vec<PS>>
88where
89 PS: PackedField,
90 F: ExtensionField<PS::Scalar>,
91{
92 Some(pack_slice(
93 &xs.iter()
94 .copied()
95 .map(TryInto::try_into)
96 .collect::<Result<Vec<_>, _>>()
97 .ok()?,
98 ))
99}
100
101impl<P, PE, Data> MultilinearExtensionTransparent<P, PE, Data>
102where
103 P: PackedField,
104 PE: PackedField,
105 PE::Scalar: ExtensionField<P::Scalar>,
106 Data: Deref<Target = [P]>,
107{
108 pub fn from_values(values: Data) -> Result<Self, Error> {
109 let mle = MultilinearExtension::from_values_generic(values)?;
110 Ok(Self {
111 data: mle.specialize(),
112 })
113 }
114
115 pub fn from_values_and_mu(values: Data, n_vars: usize) -> Result<Self, Error> {
117 let mle = MultilinearExtension::new(n_vars, values)?;
118 Ok(Self {
119 data: mle.specialize(),
120 })
121 }
122}
123
124#[erased_serialize_bytes]
125impl<F, P, PE, Data> MultivariatePoly<F> for MultilinearExtensionTransparent<P, PE, Data>
126where
127 F: TowerField + ExtensionField<P::Scalar>,
128 P: PackedField,
129 PE: PackedField<Scalar = F> + RepackedExtension<P>,
130 Data: Deref<Target = [P]> + Send + Sync + Debug,
131{
132 fn n_vars(&self) -> usize {
133 self.data.n_vars()
134 }
135
136 fn degree(&self) -> usize {
137 self.data.n_vars()
138 }
139
140 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
141 let backend = make_portable_backend();
144 let query = backend.multilinear_query(query)?;
145 Ok(self.data.evaluate(query.to_ref())?)
146 }
147
148 fn binary_tower_level(&self) -> usize {
149 F::TOWER_LEVEL - self.data.log_extension_degree()
150 }
151}