1use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use auto_impl::auto_impl;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::{
8 ArithCircuit, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
9};
10use binius_utils::{SerializationError, SerializationMode, SerializeBytes as _, bail};
11use bytes::BufMut;
12use itertools::Itertools;
13use rand::{SeedableRng, rngs::StdRng};
14
15use super::error::Error;
16
17#[auto_impl(Arc)]
22pub trait MultivariatePoly<P>: Debug + Send + Sync {
23 fn n_vars(&self) -> usize;
25
26 fn degree(&self) -> usize;
28
29 fn evaluate(&self, query: &[P]) -> Result<P, Error>;
31
32 fn binary_tower_level(&self) -> usize;
34
35 fn erased_serialize(
39 &self,
40 write_buf: &mut dyn BufMut,
41 mode: SerializationMode,
42 ) -> Result<(), SerializationError> {
43 let _ = (write_buf, mode);
44 Err(SerializationError::SerializationNotImplemented)
45 }
46}
47
/// The one-variable identity composition `g(X) = X`.
///
/// As a field-less unit struct it is trivially copyable and has an obvious default value,
/// so it derives `Copy` and `Default` in addition to `Clone` and `Debug`.
#[derive(Clone, Copy, Debug, Default)]
pub struct IdentityCompositionPoly;
51
52impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
53 fn n_vars(&self) -> usize {
54 1
55 }
56
57 fn degree(&self) -> usize {
58 1
59 }
60
61 fn expression(&self) -> ArithCircuit<P::Scalar> {
62 binius_math::ArithCircuit::var(0)
63 }
64
65 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
66 if query.len() != 1 {
67 bail!(binius_math::Error::IncorrectQuerySize {
68 expected: 1,
69 actual: query.len()
70 });
71 }
72 Ok(query[0])
73 }
74
75 fn binary_tower_level(&self) -> usize {
76 0
77 }
78}
79
80impl<F: TowerField> MultivariatePoly<F> for ArithCircuit<F> {
81 fn n_vars(&self) -> usize {
82 self.n_vars()
83 }
84
85 fn degree(&self) -> usize {
86 self.degree()
87 }
88
89 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
90 self.evaluate(query).map_err(Error::from)
91 }
92
93 fn binary_tower_level(&self) -> usize {
94 self.binary_tower_level()
95 }
96
97 fn erased_serialize(
98 &self,
99 write_buf: &mut dyn BufMut,
100 mode: SerializationMode,
101 ) -> Result<(), SerializationError> {
102 self.serialize(write_buf, mode)
103 }
104}
105
/// Adapter that lets a composition over the packed field `P` be evaluated on scalar inputs.
///
/// The wrapped `Composition` implements `CompositionPoly<P>`; the adapter implements
/// `CompositionPoly<P::Scalar>`.
#[derive(Clone, Debug)]
pub struct CompositionScalarAdapter<P, Composition> {
    /// The wrapped packed-field composition.
    composition: Composition,
    /// Records the packed type `P` without storing a value of it.
    _marker: PhantomData<P>,
}
115
116impl<P, Composition> CompositionScalarAdapter<P, Composition>
117where
118 P: PackedField,
119 Composition: CompositionPoly<P>,
120{
121 pub const fn new(composition: Composition) -> Self {
122 Self {
123 composition,
124 _marker: PhantomData,
125 }
126 }
127}
128
129impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
130where
131 F: Field,
132 P: PackedField<Scalar = F>,
133 Composition: CompositionPoly<P>,
134{
135 fn n_vars(&self) -> usize {
136 self.composition.n_vars()
137 }
138
139 fn degree(&self) -> usize {
140 self.composition.degree()
141 }
142
143 fn expression(&self) -> ArithCircuit<F> {
144 self.composition.expression()
145 }
146
147 fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
148 let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
149 let packed_result = self.composition.evaluate(&packed_query)?;
150 Ok(packed_result.get(0))
151 }
152
153 fn binary_tower_level(&self) -> usize {
154 self.composition.binary_tower_level()
155 }
156}
157
158#[derive(Debug, Clone)]
173pub struct MultilinearComposite<P, C, M>
174where
175 P: PackedField,
176 M: MultilinearPoly<P>,
177{
178 pub composition: C,
179 n_vars: usize,
180 pub multilinears: Vec<M>,
182 pub _marker: PhantomData<P>,
183}
184
185impl<P, C, M> MultilinearComposite<P, C, M>
186where
187 P: PackedField,
188 C: CompositionPoly<P>,
189 M: MultilinearPoly<P>,
190{
191 pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
192 if composition.n_vars() != multilinears.len() {
193 let err_str = format!(
194 "Number of variables in composition {} does not match length of multilinears {}",
195 composition.n_vars(),
196 multilinears.len()
197 );
198 bail!(Error::MultilinearCompositeValidation(err_str));
199 }
200 for multilin in multilinears.iter().map(Borrow::borrow) {
201 if multilin.n_vars() != n_vars {
202 let err_str = format!(
203 "Number of variables in multilinear {} does not match n_vars {}",
204 multilin.n_vars(),
205 n_vars
206 );
207 bail!(Error::MultilinearCompositeValidation(err_str));
208 }
209 }
210 Ok(Self {
211 n_vars,
212 composition,
213 multilinears,
214 _marker: PhantomData,
215 })
216 }
217
218 pub fn evaluate<'a>(
219 &self,
220 query: impl Into<MultilinearQueryRef<'a, P>>,
221 ) -> Result<P::Scalar, Error> {
222 let query = query.into();
223 let evals = self
224 .multilinears
225 .iter()
226 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
227 .collect::<Result<Vec<_>, _>>()?;
228 Ok(self.composition.evaluate(&evals)?.get(0))
229 }
230
231 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
232 let evals = self
233 .multilinears
234 .iter()
235 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
236 .collect::<Result<Vec<_>, _>>()?;
237
238 Ok(self.composition.evaluate(&evals)?.get(0))
239 }
240
241 pub fn max_individual_degree(&self) -> usize {
242 self.composition.degree()
244 }
245
246 pub fn n_multilinears(&self) -> usize {
247 self.composition.n_vars()
248 }
249}
250
251impl<P, C, M> MultilinearComposite<P, C, M>
252where
253 P: PackedField,
254 C: CompositionPoly<P> + 'static,
255 M: MultilinearPoly<P>,
256{
257 pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
258 MultilinearComposite {
259 n_vars: self.n_vars,
260 composition: Arc::new(self.composition),
261 multilinears: self.multilinears,
262 _marker: PhantomData,
263 }
264 }
265}
266
267impl<P, C, M> MultilinearComposite<P, C, M>
268where
269 P: PackedField,
270 M: MultilinearPoly<P>,
271{
272 pub const fn n_vars(&self) -> usize {
273 self.n_vars
274 }
275}
276
277impl<P, C, M> MultilinearComposite<P, C, M>
278where
279 P: PackedField,
280 C: Clone,
281 M: MultilinearPoly<P>,
282{
283 pub fn evaluate_partial_low(
284 &self,
285 query: MultilinearQueryRef<P>,
286 ) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
287 let new_multilinears = self
288 .multilinears
289 .iter()
290 .map(|multilin| {
291 multilin
292 .evaluate_partial_low(query)
293 .map(MLEDirectAdapter::from)
294 })
295 .collect::<Result<Vec<_>, _>>()?;
296 Ok(MultilinearComposite {
297 composition: self.composition.clone(),
298 n_vars: self.n_vars - query.n_vars(),
299 multilinears: new_multilinears,
300 _marker: PhantomData,
301 })
302 }
303}
304
305pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
314 let mut rng = StdRng::from_seed([0; 32]);
315
316 let random_point = repeat_with(|| P::random(&mut rng))
317 .take(composition.n_vars())
318 .collect_vec();
319
320 composition
321 .evaluate(&random_point)
322 .expect("Failed to evaluate composition")
323}
324
#[cfg(test)]
mod tests {
    use binius_fast_compute::arith_circuit::ArithCircuitPoly;
    use binius_field::{
        BinaryField1b, PackedBinaryField2x128b, PackedBinaryField4x64b, PackedBinaryField8x32b,
    };
    use binius_math::{ArithExpr, CompositionPoly};

    /// Generates two fingerprint tests for one packed field type.
    ///
    /// - `$same`: `(x0 + x1) * x0 + x0^2` equals `x0 * x1` in characteristic 2, so it must
    ///   fingerprint like `ProductComposition::<2>`.
    /// - `$diff`: `x0 + x1` must fingerprint differently from `ProductComposition::<2>`.
    macro_rules! fingerprint_tests {
        ($packed:ty, $same:ident, $diff:ident) => {
            #[test]
            fn $same() {
                let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0)
                    + ArithExpr::Var(0).pow(2);
                let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
                    as &dyn CompositionPoly<$packed>;
                let product_composition = crate::composition::ProductComposition::<2> {};

                assert_eq!(
                    crate::polynomial::composition_hash(&circuit_poly),
                    crate::polynomial::composition_hash(&product_composition)
                );
            }

            #[test]
            fn $diff() {
                let expr = ArithExpr::Var(0) + ArithExpr::Var(1);
                let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
                    as &dyn CompositionPoly<$packed>;
                let product_composition = crate::composition::ProductComposition::<2> {};

                assert_ne!(
                    crate::polynomial::composition_hash(&circuit_poly),
                    crate::polynomial::composition_hash(&product_composition)
                );
            }
        };
    }

    fingerprint_tests!(
        PackedBinaryField8x32b,
        test_fingerprint_same_32b,
        test_fingerprint_diff_32b
    );
    fingerprint_tests!(
        PackedBinaryField4x64b,
        test_fingerprint_same_64b,
        test_fingerprint_diff_64b
    );
    fingerprint_tests!(
        PackedBinaryField2x128b,
        test_fingerprint_same_128b,
        test_fingerprint_diff_128b
    );
}