1use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use auto_impl::auto_impl;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::{
8 ArithCircuit, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
9};
10use binius_utils::{SerializationError, SerializationMode, bail};
11use bytes::BufMut;
12use itertools::Itertools;
13use rand::{SeedableRng, rngs::StdRng};
14
15use super::error::Error;
16
/// A multivariate polynomial over values of type `P`.
///
/// The `#[auto_impl(Arc)]` attribute additionally implements this trait for `Arc<T>` whenever
/// `T` implements it, so shared (`Arc`-wrapped) polynomials can be used directly.
#[auto_impl(Arc)]
pub trait MultivariatePoly<P>: Debug + Send + Sync {
	/// Number of variables the polynomial is defined over.
	fn n_vars(&self) -> usize;

	/// Degree of the polynomial.
	// NOTE(review): presumably the total degree — confirm against implementors.
	fn degree(&self) -> usize;

	/// Evaluates the polynomial at `query`.
	///
	/// # Errors
	///
	/// Implementation-defined. Presumably an error is returned when
	/// `query.len() != self.n_vars()` — TODO confirm with implementors.
	fn evaluate(&self, query: &[P]) -> Result<P, Error>;

	/// Binary tower level associated with the polynomial's coefficients.
	// NOTE(review): likely the smallest tower level containing every coefficient — confirm.
	fn binary_tower_level(&self) -> usize;

	/// Serializes the polynomial into the type-erased writer `write_buf` using `mode`.
	///
	/// # Errors
	///
	/// The default implementation writes nothing and always returns
	/// [`SerializationError::SerializationNotImplemented`]; implementors that support
	/// serialization must override it.
	fn erased_serialize(
		&self,
		write_buf: &mut dyn BufMut,
		mode: SerializationMode,
	) -> Result<(), SerializationError> {
		// Consume the parameters so the default body compiles without unused-variable warnings.
		let _ = (write_buf, mode);
		Err(SerializationError::SerializationNotImplemented)
	}
}
47
48#[derive(Clone, Debug)]
50pub struct IdentityCompositionPoly;
51
52impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
53 fn n_vars(&self) -> usize {
54 1
55 }
56
57 fn degree(&self) -> usize {
58 1
59 }
60
61 fn expression(&self) -> ArithCircuit<P::Scalar> {
62 binius_math::ArithCircuit::var(0)
63 }
64
65 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
66 if query.len() != 1 {
67 bail!(binius_math::Error::IncorrectQuerySize {
68 expected: 1,
69 actual: query.len()
70 });
71 }
72 Ok(query[0])
73 }
74
75 fn binary_tower_level(&self) -> usize {
76 0
77 }
78}
79
80impl<F: TowerField> MultivariatePoly<F> for ArithCircuit<F> {
81 fn n_vars(&self) -> usize {
82 self.n_vars()
83 }
84
85 fn degree(&self) -> usize {
86 self.degree()
87 }
88
89 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
90 self.evaluate(query).map_err(Error::from)
91 }
92
93 fn binary_tower_level(&self) -> usize {
94 self.binary_tower_level()
95 }
96}
97
98#[derive(Debug, Clone)]
103pub struct CompositionScalarAdapter<P, Composition> {
104 composition: Composition,
105 _marker: PhantomData<P>,
106}
107
108impl<P, Composition> CompositionScalarAdapter<P, Composition>
109where
110 P: PackedField,
111 Composition: CompositionPoly<P>,
112{
113 pub const fn new(composition: Composition) -> Self {
114 Self {
115 composition,
116 _marker: PhantomData,
117 }
118 }
119}
120
121impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
122where
123 F: Field,
124 P: PackedField<Scalar = F>,
125 Composition: CompositionPoly<P>,
126{
127 fn n_vars(&self) -> usize {
128 self.composition.n_vars()
129 }
130
131 fn degree(&self) -> usize {
132 self.composition.degree()
133 }
134
135 fn expression(&self) -> ArithCircuit<F> {
136 self.composition.expression()
137 }
138
139 fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
140 let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
141 let packed_result = self.composition.evaluate(&packed_query)?;
142 Ok(packed_result.get(0))
143 }
144
145 fn binary_tower_level(&self) -> usize {
146 self.composition.binary_tower_level()
147 }
148}
149
150#[derive(Debug, Clone)]
165pub struct MultilinearComposite<P, C, M>
166where
167 P: PackedField,
168 M: MultilinearPoly<P>,
169{
170 pub composition: C,
171 n_vars: usize,
172 pub multilinears: Vec<M>,
174 pub _marker: PhantomData<P>,
175}
176
177impl<P, C, M> MultilinearComposite<P, C, M>
178where
179 P: PackedField,
180 C: CompositionPoly<P>,
181 M: MultilinearPoly<P>,
182{
183 pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
184 if composition.n_vars() != multilinears.len() {
185 let err_str = format!(
186 "Number of variables in composition {} does not match length of multilinears {}",
187 composition.n_vars(),
188 multilinears.len()
189 );
190 bail!(Error::MultilinearCompositeValidation(err_str));
191 }
192 for multilin in multilinears.iter().map(Borrow::borrow) {
193 if multilin.n_vars() != n_vars {
194 let err_str = format!(
195 "Number of variables in multilinear {} does not match n_vars {}",
196 multilin.n_vars(),
197 n_vars
198 );
199 bail!(Error::MultilinearCompositeValidation(err_str));
200 }
201 }
202 Ok(Self {
203 n_vars,
204 composition,
205 multilinears,
206 _marker: PhantomData,
207 })
208 }
209
210 pub fn evaluate<'a>(
211 &self,
212 query: impl Into<MultilinearQueryRef<'a, P>>,
213 ) -> Result<P::Scalar, Error> {
214 let query = query.into();
215 let evals = self
216 .multilinears
217 .iter()
218 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
219 .collect::<Result<Vec<_>, _>>()?;
220 Ok(self.composition.evaluate(&evals)?.get(0))
221 }
222
223 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
224 let evals = self
225 .multilinears
226 .iter()
227 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
228 .collect::<Result<Vec<_>, _>>()?;
229
230 Ok(self.composition.evaluate(&evals)?.get(0))
231 }
232
233 pub fn max_individual_degree(&self) -> usize {
234 self.composition.degree()
236 }
237
238 pub fn n_multilinears(&self) -> usize {
239 self.composition.n_vars()
240 }
241}
242
243impl<P, C, M> MultilinearComposite<P, C, M>
244where
245 P: PackedField,
246 C: CompositionPoly<P> + 'static,
247 M: MultilinearPoly<P>,
248{
249 pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
250 MultilinearComposite {
251 n_vars: self.n_vars,
252 composition: Arc::new(self.composition),
253 multilinears: self.multilinears,
254 _marker: PhantomData,
255 }
256 }
257}
258
259impl<P, C, M> MultilinearComposite<P, C, M>
260where
261 P: PackedField,
262 M: MultilinearPoly<P>,
263{
264 pub const fn n_vars(&self) -> usize {
265 self.n_vars
266 }
267}
268
269impl<P, C, M> MultilinearComposite<P, C, M>
270where
271 P: PackedField,
272 C: Clone,
273 M: MultilinearPoly<P>,
274{
275 pub fn evaluate_partial_low(
276 &self,
277 query: MultilinearQueryRef<P>,
278 ) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
279 let new_multilinears = self
280 .multilinears
281 .iter()
282 .map(|multilin| {
283 multilin
284 .evaluate_partial_low(query)
285 .map(MLEDirectAdapter::from)
286 })
287 .collect::<Result<Vec<_>, _>>()?;
288 Ok(MultilinearComposite {
289 composition: self.composition.clone(),
290 n_vars: self.n_vars - query.n_vars(),
291 multilinears: new_multilinears,
292 _marker: PhantomData,
293 })
294 }
295}
296
297pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
306 let mut rng = StdRng::from_seed([0; 32]);
307
308 let random_point = repeat_with(|| P::random(&mut rng))
309 .take(composition.n_vars())
310 .collect_vec();
311
312 composition
313 .evaluate(&random_point)
314 .expect("Failed to evaluate composition")
315}
316
#[cfg(test)]
mod tests {
	use binius_fast_compute::arith_circuit::ArithCircuitPoly;
	use binius_math::{ArithExpr, CompositionPoly};

	/// Generates a test comparing the fingerprint of `$expr` against the fingerprint of the
	/// two-variable product composition.
	///
	/// The expression is compiled as an `ArithCircuitPoly` over `BinaryField1b` and viewed as a
	/// composition over the packed field `$packed`. `$assertion` selects `assert_eq` (the
	/// polynomials should match) or `assert_ne` (they should differ).
	macro_rules! fingerprint_test {
		($name:ident, $packed:ident, $assertion:ident, $expr:expr) => {
			#[test]
			fn $name() {
				use binius_field::{BinaryField1b, $packed};

				let expr = $expr;
				let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
					as &dyn CompositionPoly<$packed>;

				let product_composition = crate::composition::ProductComposition::<2> {};

				$assertion!(
					crate::polynomial::composition_hash(&circuit_poly),
					crate::polynomial::composition_hash(&product_composition)
				);
			}
		};
	}

	// In characteristic 2, (x0 + x1) * x0 + x0^2 = 2 * x0^2 + x0 * x1 = x0 * x1,
	// so this expression matches the product composition.
	fingerprint_test!(
		test_fingerprint_same_32b,
		PackedBinaryField8x32b,
		assert_eq,
		(ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
	);

	// x0 + x1 is a different polynomial from x0 * x1.
	fingerprint_test!(
		test_fingerprint_diff_32b,
		PackedBinaryField8x32b,
		assert_ne,
		ArithExpr::Var(0) + ArithExpr::Var(1)
	);

	fingerprint_test!(
		test_fingerprint_same_64b,
		PackedBinaryField4x64b,
		assert_eq,
		(ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
	);

	fingerprint_test!(
		test_fingerprint_diff_64b,
		PackedBinaryField4x64b,
		assert_ne,
		ArithExpr::Var(0) + ArithExpr::Var(1)
	);

	fingerprint_test!(
		test_fingerprint_same_128b,
		PackedBinaryField2x128b,
		assert_eq,
		(ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
	);

	fingerprint_test!(
		test_fingerprint_diff_128b,
		PackedBinaryField2x128b,
		assert_ne,
		ArithExpr::Var(0) + ArithExpr::Var(1)
	);
}