1use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use binius_field::{Field, PackedField};
6use binius_math::{
7 ArithExpr, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
8};
9use binius_utils::{bail, SerializationError, SerializationMode};
10use bytes::BufMut;
11use itertools::Itertools;
12use rand::{rngs::StdRng, SeedableRng};
13
14use super::error::Error;
15
16pub trait MultivariatePoly<P>: Debug + Send + Sync {
21 fn n_vars(&self) -> usize;
23
24 fn degree(&self) -> usize;
26
27 fn evaluate(&self, query: &[P]) -> Result<P, Error>;
29
30 fn binary_tower_level(&self) -> usize;
32
33 fn erased_serialize(
36 &self,
37 write_buf: &mut dyn BufMut,
38 mode: SerializationMode,
39 ) -> Result<(), SerializationError> {
40 let _ = (write_buf, mode);
41 Err(SerializationError::SerializationNotImplemented)
42 }
43}
44
45#[derive(Clone, Debug)]
47pub struct IdentityCompositionPoly;
48
49impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
50 fn n_vars(&self) -> usize {
51 1
52 }
53
54 fn degree(&self) -> usize {
55 1
56 }
57
58 fn expression(&self) -> ArithExpr<P::Scalar> {
59 binius_math::ArithExpr::Var(0)
60 }
61
62 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
63 if query.len() != 1 {
64 bail!(binius_math::Error::IncorrectQuerySize { expected: 1 });
65 }
66 Ok(query[0])
67 }
68
69 fn binary_tower_level(&self) -> usize {
70 0
71 }
72}
73
74#[derive(Debug, Clone)]
79pub struct CompositionScalarAdapter<P, Composition> {
80 composition: Composition,
81 _marker: PhantomData<P>,
82}
83
84impl<P, Composition> CompositionScalarAdapter<P, Composition>
85where
86 P: PackedField,
87 Composition: CompositionPoly<P>,
88{
89 pub const fn new(composition: Composition) -> Self {
90 Self {
91 composition,
92 _marker: PhantomData,
93 }
94 }
95}
96
97impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
98where
99 F: Field,
100 P: PackedField<Scalar = F>,
101 Composition: CompositionPoly<P>,
102{
103 fn n_vars(&self) -> usize {
104 self.composition.n_vars()
105 }
106
107 fn degree(&self) -> usize {
108 self.composition.degree()
109 }
110
111 fn expression(&self) -> ArithExpr<F> {
112 self.composition.expression()
113 }
114
115 fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
116 let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
117 let packed_result = self.composition.evaluate(&packed_query)?;
118 Ok(packed_result.get(0))
119 }
120
121 fn binary_tower_level(&self) -> usize {
122 self.composition.binary_tower_level()
123 }
124}
125
126#[derive(Debug, Clone)]
141pub struct MultilinearComposite<P, C, M>
142where
143 P: PackedField,
144 M: MultilinearPoly<P>,
145{
146 pub composition: C,
147 n_vars: usize,
148 pub multilinears: Vec<M>,
150 pub _marker: PhantomData<P>,
151}
152
153impl<P, C, M> MultilinearComposite<P, C, M>
154where
155 P: PackedField,
156 C: CompositionPoly<P>,
157 M: MultilinearPoly<P>,
158{
159 pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
160 if composition.n_vars() != multilinears.len() {
161 let err_str = format!(
162 "Number of variables in composition {} does not match length of multilinears {}",
163 composition.n_vars(),
164 multilinears.len()
165 );
166 bail!(Error::MultilinearCompositeValidation(err_str));
167 }
168 for multilin in multilinears.iter().map(Borrow::borrow) {
169 if multilin.n_vars() != n_vars {
170 let err_str = format!(
171 "Number of variables in multilinear {} does not match n_vars {}",
172 multilin.n_vars(),
173 n_vars
174 );
175 bail!(Error::MultilinearCompositeValidation(err_str));
176 }
177 }
178 Ok(Self {
179 n_vars,
180 composition,
181 multilinears,
182 _marker: PhantomData,
183 })
184 }
185
186 pub fn evaluate<'a>(
187 &self,
188 query: impl Into<MultilinearQueryRef<'a, P>>,
189 ) -> Result<P::Scalar, Error> {
190 let query = query.into();
191 let evals = self
192 .multilinears
193 .iter()
194 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
195 .collect::<Result<Vec<_>, _>>()?;
196 Ok(self.composition.evaluate(&evals)?.get(0))
197 }
198
199 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
200 let evals = self
201 .multilinears
202 .iter()
203 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
204 .collect::<Result<Vec<_>, _>>()?;
205
206 Ok(self.composition.evaluate(&evals)?.get(0))
207 }
208
209 pub fn max_individual_degree(&self) -> usize {
210 self.composition.degree()
212 }
213
214 pub fn n_multilinears(&self) -> usize {
215 self.composition.n_vars()
216 }
217}
218
219impl<P, C, M> MultilinearComposite<P, C, M>
220where
221 P: PackedField,
222 C: CompositionPoly<P> + 'static,
223 M: MultilinearPoly<P>,
224{
225 pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
226 MultilinearComposite {
227 n_vars: self.n_vars,
228 composition: Arc::new(self.composition),
229 multilinears: self.multilinears,
230 _marker: PhantomData,
231 }
232 }
233}
234
235impl<P, C, M> MultilinearComposite<P, C, M>
236where
237 P: PackedField,
238 M: MultilinearPoly<P>,
239{
240 pub const fn n_vars(&self) -> usize {
241 self.n_vars
242 }
243}
244
245impl<P, C, M> MultilinearComposite<P, C, M>
246where
247 P: PackedField,
248 C: Clone,
249 M: MultilinearPoly<P>,
250{
251 pub fn evaluate_partial_low(
252 &self,
253 query: MultilinearQueryRef<P>,
254 ) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
255 let new_multilinears = self
256 .multilinears
257 .iter()
258 .map(|multilin| {
259 multilin
260 .evaluate_partial_low(query)
261 .map(MLEDirectAdapter::from)
262 })
263 .collect::<Result<Vec<_>, _>>()?;
264 Ok(MultilinearComposite {
265 composition: self.composition.clone(),
266 n_vars: self.n_vars - query.n_vars(),
267 multilinears: new_multilinears,
268 _marker: PhantomData,
269 })
270 }
271}
272
273pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
281 let mut rng = StdRng::from_seed([0; 32]);
282
283 let random_point = repeat_with(|| P::random(&mut rng))
284 .take(composition.n_vars())
285 .collect_vec();
286
287 composition
288 .evaluate(&random_point)
289 .expect("Failed to evaluate composition")
290}
291
#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField1b, PackedBinaryField2x128b, PackedBinaryField4x64b, PackedBinaryField8x32b,
    };
    use binius_math::{ArithExpr, CompositionPoly};

    use crate::{
        composition::ProductComposition,
        polynomial::{composition_hash, ArithCircuitPoly},
    };

    /// Hashes `ArithCircuitPoly::<BinaryField1b>::new($expr)` and `ProductComposition::<2>`,
    /// both over the packed type `$packed`. The circuit is viewed through a
    /// `&dyn CompositionPoly<$packed>`. Yields the pair `(circuit_hash, product_hash)`.
    macro_rules! circuit_and_product_hashes {
        ($packed:ty, $expr:expr) => {{
            let circuit_poly =
                &ArithCircuitPoly::<BinaryField1b>::new($expr) as &dyn CompositionPoly<$packed>;
            let product_composition = ProductComposition::<2> {};
            (
                composition_hash::<$packed, _>(&circuit_poly),
                composition_hash::<$packed, _>(&product_composition),
            )
        }};
    }

    // In characteristic 2, (x0 + x1) * x0 + x0^2 = x0 * x1, so the "same" tests expect this
    // expression to hash like `ProductComposition::<2>`; the "diff" tests use x0 + x1.

    #[test]
    fn test_fingerprint_same_32b() {
        let (circuit_hash, product_hash) = circuit_and_product_hashes!(
            PackedBinaryField8x32b,
            (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
        );
        assert_eq!(circuit_hash, product_hash);
    }

    #[test]
    fn test_fingerprint_diff_32b() {
        let (circuit_hash, product_hash) = circuit_and_product_hashes!(
            PackedBinaryField8x32b,
            ArithExpr::Var(0) + ArithExpr::Var(1)
        );
        assert_ne!(circuit_hash, product_hash);
    }

    #[test]
    fn test_fingerprint_same_64b() {
        let (circuit_hash, product_hash) = circuit_and_product_hashes!(
            PackedBinaryField4x64b,
            (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
        );
        assert_eq!(circuit_hash, product_hash);
    }

    #[test]
    fn test_fingerprint_diff_64b() {
        let (circuit_hash, product_hash) = circuit_and_product_hashes!(
            PackedBinaryField4x64b,
            ArithExpr::Var(0) + ArithExpr::Var(1)
        );
        assert_ne!(circuit_hash, product_hash);
    }

    #[test]
    fn test_fingerprint_same_128b() {
        let (circuit_hash, product_hash) = circuit_and_product_hashes!(
            PackedBinaryField2x128b,
            (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
        );
        assert_eq!(circuit_hash, product_hash);
    }

    #[test]
    fn test_fingerprint_diff_128b() {
        let (circuit_hash, product_hash) = circuit_and_product_hashes!(
            PackedBinaryField2x128b,
            ArithExpr::Var(0) + ArithExpr::Var(1)
        );
        assert_ne!(circuit_hash, product_hash);
    }
}