1use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use auto_impl::auto_impl;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::{
8 ArithCircuit, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
9};
10use binius_utils::{bail, SerializationError, SerializationMode};
11use bytes::BufMut;
12use itertools::Itertools;
13use rand::{rngs::StdRng, SeedableRng};
14
15use super::error::Error;
16
17#[auto_impl(Arc)]
22pub trait MultivariatePoly<P>: Debug + Send + Sync {
23 fn n_vars(&self) -> usize;
25
26 fn degree(&self) -> usize;
28
29 fn evaluate(&self, query: &[P]) -> Result<P, Error>;
31
32 fn binary_tower_level(&self) -> usize;
34
35 fn erased_serialize(
38 &self,
39 write_buf: &mut dyn BufMut,
40 mode: SerializationMode,
41 ) -> Result<(), SerializationError> {
42 let _ = (write_buf, mode);
43 Err(SerializationError::SerializationNotImplemented)
44 }
45}
46
47#[derive(Clone, Debug)]
49pub struct IdentityCompositionPoly;
50
51impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
52 fn n_vars(&self) -> usize {
53 1
54 }
55
56 fn degree(&self) -> usize {
57 1
58 }
59
60 fn expression(&self) -> ArithCircuit<P::Scalar> {
61 binius_math::ArithCircuit::var(0)
62 }
63
64 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
65 if query.len() != 1 {
66 bail!(binius_math::Error::IncorrectQuerySize { expected: 1 });
67 }
68 Ok(query[0])
69 }
70
71 fn binary_tower_level(&self) -> usize {
72 0
73 }
74}
75
76impl<F: TowerField> MultivariatePoly<F> for ArithCircuit<F> {
77 fn n_vars(&self) -> usize {
78 self.n_vars()
79 }
80
81 fn degree(&self) -> usize {
82 self.degree()
83 }
84
85 fn evaluate(&self, query: &[F]) -> Result<F, Error> {
86 self.evaluate(query).map_err(Error::from)
87 }
88
89 fn binary_tower_level(&self) -> usize {
90 self.binary_tower_level()
91 }
92}
93
94#[derive(Debug, Clone)]
99pub struct CompositionScalarAdapter<P, Composition> {
100 composition: Composition,
101 _marker: PhantomData<P>,
102}
103
104impl<P, Composition> CompositionScalarAdapter<P, Composition>
105where
106 P: PackedField,
107 Composition: CompositionPoly<P>,
108{
109 pub const fn new(composition: Composition) -> Self {
110 Self {
111 composition,
112 _marker: PhantomData,
113 }
114 }
115}
116
117impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
118where
119 F: Field,
120 P: PackedField<Scalar = F>,
121 Composition: CompositionPoly<P>,
122{
123 fn n_vars(&self) -> usize {
124 self.composition.n_vars()
125 }
126
127 fn degree(&self) -> usize {
128 self.composition.degree()
129 }
130
131 fn expression(&self) -> ArithCircuit<F> {
132 self.composition.expression()
133 }
134
135 fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
136 let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
137 let packed_result = self.composition.evaluate(&packed_query)?;
138 Ok(packed_result.get(0))
139 }
140
141 fn binary_tower_level(&self) -> usize {
142 self.composition.binary_tower_level()
143 }
144}
145
146#[derive(Debug, Clone)]
161pub struct MultilinearComposite<P, C, M>
162where
163 P: PackedField,
164 M: MultilinearPoly<P>,
165{
166 pub composition: C,
167 n_vars: usize,
168 pub multilinears: Vec<M>,
170 pub _marker: PhantomData<P>,
171}
172
173impl<P, C, M> MultilinearComposite<P, C, M>
174where
175 P: PackedField,
176 C: CompositionPoly<P>,
177 M: MultilinearPoly<P>,
178{
179 pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
180 if composition.n_vars() != multilinears.len() {
181 let err_str = format!(
182 "Number of variables in composition {} does not match length of multilinears {}",
183 composition.n_vars(),
184 multilinears.len()
185 );
186 bail!(Error::MultilinearCompositeValidation(err_str));
187 }
188 for multilin in multilinears.iter().map(Borrow::borrow) {
189 if multilin.n_vars() != n_vars {
190 let err_str = format!(
191 "Number of variables in multilinear {} does not match n_vars {}",
192 multilin.n_vars(),
193 n_vars
194 );
195 bail!(Error::MultilinearCompositeValidation(err_str));
196 }
197 }
198 Ok(Self {
199 n_vars,
200 composition,
201 multilinears,
202 _marker: PhantomData,
203 })
204 }
205
206 pub fn evaluate<'a>(
207 &self,
208 query: impl Into<MultilinearQueryRef<'a, P>>,
209 ) -> Result<P::Scalar, Error> {
210 let query = query.into();
211 let evals = self
212 .multilinears
213 .iter()
214 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
215 .collect::<Result<Vec<_>, _>>()?;
216 Ok(self.composition.evaluate(&evals)?.get(0))
217 }
218
219 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
220 let evals = self
221 .multilinears
222 .iter()
223 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
224 .collect::<Result<Vec<_>, _>>()?;
225
226 Ok(self.composition.evaluate(&evals)?.get(0))
227 }
228
229 pub fn max_individual_degree(&self) -> usize {
230 self.composition.degree()
232 }
233
234 pub fn n_multilinears(&self) -> usize {
235 self.composition.n_vars()
236 }
237}
238
239impl<P, C, M> MultilinearComposite<P, C, M>
240where
241 P: PackedField,
242 C: CompositionPoly<P> + 'static,
243 M: MultilinearPoly<P>,
244{
245 pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
246 MultilinearComposite {
247 n_vars: self.n_vars,
248 composition: Arc::new(self.composition),
249 multilinears: self.multilinears,
250 _marker: PhantomData,
251 }
252 }
253}
254
255impl<P, C, M> MultilinearComposite<P, C, M>
256where
257 P: PackedField,
258 M: MultilinearPoly<P>,
259{
260 pub const fn n_vars(&self) -> usize {
261 self.n_vars
262 }
263}
264
265impl<P, C, M> MultilinearComposite<P, C, M>
266where
267 P: PackedField,
268 C: Clone,
269 M: MultilinearPoly<P>,
270{
271 pub fn evaluate_partial_low(
272 &self,
273 query: MultilinearQueryRef<P>,
274 ) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
275 let new_multilinears = self
276 .multilinears
277 .iter()
278 .map(|multilin| {
279 multilin
280 .evaluate_partial_low(query)
281 .map(MLEDirectAdapter::from)
282 })
283 .collect::<Result<Vec<_>, _>>()?;
284 Ok(MultilinearComposite {
285 composition: self.composition.clone(),
286 n_vars: self.n_vars - query.n_vars(),
287 multilinears: new_multilinears,
288 _marker: PhantomData,
289 })
290 }
291}
292
293pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
301 let mut rng = StdRng::from_seed([0; 32]);
302
303 let random_point = repeat_with(|| P::random(&mut rng))
304 .take(composition.n_vars())
305 .collect_vec();
306
307 composition
308 .evaluate(&random_point)
309 .expect("Failed to evaluate composition")
310}
311
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField1b, PackedBinaryField2x128b, PackedBinaryField4x64b, PackedBinaryField8x32b,
	};
	use binius_math::{ArithExpr, CompositionPoly};

	use crate::{
		composition::ProductComposition,
		polynomial::{composition_hash, ArithCircuitPoly},
	};

	// Emits two tests for one packed field type. Each compares the fingerprint of an
	// `ArithCircuitPoly` against the 2-ary `ProductComposition` (x0 * x1):
	// - `$same`: (x0 + x1) * x0 + x0^2 = 2*x0^2 + x0*x1, which is x0*x1 in characteristic 2, so
	//   the fingerprints must be equal;
	// - `$diff`: x0 + x1 is a different polynomial, so the fingerprints must differ.
	// A macro is used because each case must be a separately named `#[test]` fn.
	macro_rules! fingerprint_tests {
		($same:ident, $diff:ident, $packed:ty) => {
			#[test]
			fn $same() {
				let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0)
					+ ArithExpr::Var(0).pow(2);
				let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
					as &dyn CompositionPoly<$packed>;
				let product_composition = ProductComposition::<2> {};

				assert_eq!(
					composition_hash(&circuit_poly),
					composition_hash(&product_composition)
				);
			}

			#[test]
			fn $diff() {
				let expr = ArithExpr::Var(0) + ArithExpr::Var(1);
				let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
					as &dyn CompositionPoly<$packed>;
				let product_composition = ProductComposition::<2> {};

				assert_ne!(
					composition_hash(&circuit_poly),
					composition_hash(&product_composition)
				);
			}
		};
	}

	fingerprint_tests!(
		test_fingerprint_same_32b,
		test_fingerprint_diff_32b,
		PackedBinaryField8x32b
	);

	fingerprint_tests!(
		test_fingerprint_same_64b,
		test_fingerprint_diff_64b,
		PackedBinaryField4x64b
	);

	fingerprint_tests!(
		test_fingerprint_same_128b,
		test_fingerprint_diff_128b,
		PackedBinaryField2x128b
	);
}