1use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use auto_impl::auto_impl;
6use binius_field::{Field, PackedField};
7use binius_math::{
8 ArithExpr, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
9};
10use binius_utils::{bail, SerializationError, SerializationMode};
11use bytes::BufMut;
12use itertools::Itertools;
13use rand::{rngs::StdRng, SeedableRng};
14
15use super::error::Error;
16
17#[auto_impl(Arc)]
22pub trait MultivariatePoly<P>: Debug + Send + Sync {
23 fn n_vars(&self) -> usize;
25
26 fn degree(&self) -> usize;
28
29 fn evaluate(&self, query: &[P]) -> Result<P, Error>;
31
32 fn binary_tower_level(&self) -> usize;
34
35 fn erased_serialize(
38 &self,
39 write_buf: &mut dyn BufMut,
40 mode: SerializationMode,
41 ) -> Result<(), SerializationError> {
42 let _ = (write_buf, mode);
43 Err(SerializationError::SerializationNotImplemented)
44 }
45}
46
47#[derive(Clone, Debug)]
49pub struct IdentityCompositionPoly;
50
51impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
52 fn n_vars(&self) -> usize {
53 1
54 }
55
56 fn degree(&self) -> usize {
57 1
58 }
59
60 fn expression(&self) -> ArithExpr<P::Scalar> {
61 binius_math::ArithExpr::Var(0)
62 }
63
64 fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
65 if query.len() != 1 {
66 bail!(binius_math::Error::IncorrectQuerySize { expected: 1 });
67 }
68 Ok(query[0])
69 }
70
71 fn binary_tower_level(&self) -> usize {
72 0
73 }
74}
75
76#[derive(Debug, Clone)]
81pub struct CompositionScalarAdapter<P, Composition> {
82 composition: Composition,
83 _marker: PhantomData<P>,
84}
85
86impl<P, Composition> CompositionScalarAdapter<P, Composition>
87where
88 P: PackedField,
89 Composition: CompositionPoly<P>,
90{
91 pub const fn new(composition: Composition) -> Self {
92 Self {
93 composition,
94 _marker: PhantomData,
95 }
96 }
97}
98
99impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
100where
101 F: Field,
102 P: PackedField<Scalar = F>,
103 Composition: CompositionPoly<P>,
104{
105 fn n_vars(&self) -> usize {
106 self.composition.n_vars()
107 }
108
109 fn degree(&self) -> usize {
110 self.composition.degree()
111 }
112
113 fn expression(&self) -> ArithExpr<F> {
114 self.composition.expression()
115 }
116
117 fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
118 let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
119 let packed_result = self.composition.evaluate(&packed_query)?;
120 Ok(packed_result.get(0))
121 }
122
123 fn binary_tower_level(&self) -> usize {
124 self.composition.binary_tower_level()
125 }
126}
127
128#[derive(Debug, Clone)]
143pub struct MultilinearComposite<P, C, M>
144where
145 P: PackedField,
146 M: MultilinearPoly<P>,
147{
148 pub composition: C,
149 n_vars: usize,
150 pub multilinears: Vec<M>,
152 pub _marker: PhantomData<P>,
153}
154
155impl<P, C, M> MultilinearComposite<P, C, M>
156where
157 P: PackedField,
158 C: CompositionPoly<P>,
159 M: MultilinearPoly<P>,
160{
161 pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
162 if composition.n_vars() != multilinears.len() {
163 let err_str = format!(
164 "Number of variables in composition {} does not match length of multilinears {}",
165 composition.n_vars(),
166 multilinears.len()
167 );
168 bail!(Error::MultilinearCompositeValidation(err_str));
169 }
170 for multilin in multilinears.iter().map(Borrow::borrow) {
171 if multilin.n_vars() != n_vars {
172 let err_str = format!(
173 "Number of variables in multilinear {} does not match n_vars {}",
174 multilin.n_vars(),
175 n_vars
176 );
177 bail!(Error::MultilinearCompositeValidation(err_str));
178 }
179 }
180 Ok(Self {
181 n_vars,
182 composition,
183 multilinears,
184 _marker: PhantomData,
185 })
186 }
187
188 pub fn evaluate<'a>(
189 &self,
190 query: impl Into<MultilinearQueryRef<'a, P>>,
191 ) -> Result<P::Scalar, Error> {
192 let query = query.into();
193 let evals = self
194 .multilinears
195 .iter()
196 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
197 .collect::<Result<Vec<_>, _>>()?;
198 Ok(self.composition.evaluate(&evals)?.get(0))
199 }
200
201 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
202 let evals = self
203 .multilinears
204 .iter()
205 .map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
206 .collect::<Result<Vec<_>, _>>()?;
207
208 Ok(self.composition.evaluate(&evals)?.get(0))
209 }
210
211 pub fn max_individual_degree(&self) -> usize {
212 self.composition.degree()
214 }
215
216 pub fn n_multilinears(&self) -> usize {
217 self.composition.n_vars()
218 }
219}
220
221impl<P, C, M> MultilinearComposite<P, C, M>
222where
223 P: PackedField,
224 C: CompositionPoly<P> + 'static,
225 M: MultilinearPoly<P>,
226{
227 pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
228 MultilinearComposite {
229 n_vars: self.n_vars,
230 composition: Arc::new(self.composition),
231 multilinears: self.multilinears,
232 _marker: PhantomData,
233 }
234 }
235}
236
237impl<P, C, M> MultilinearComposite<P, C, M>
238where
239 P: PackedField,
240 M: MultilinearPoly<P>,
241{
242 pub const fn n_vars(&self) -> usize {
243 self.n_vars
244 }
245}
246
247impl<P, C, M> MultilinearComposite<P, C, M>
248where
249 P: PackedField,
250 C: Clone,
251 M: MultilinearPoly<P>,
252{
253 pub fn evaluate_partial_low(
254 &self,
255 query: MultilinearQueryRef<P>,
256 ) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
257 let new_multilinears = self
258 .multilinears
259 .iter()
260 .map(|multilin| {
261 multilin
262 .evaluate_partial_low(query)
263 .map(MLEDirectAdapter::from)
264 })
265 .collect::<Result<Vec<_>, _>>()?;
266 Ok(MultilinearComposite {
267 composition: self.composition.clone(),
268 n_vars: self.n_vars - query.n_vars(),
269 multilinears: new_multilinears,
270 _marker: PhantomData,
271 })
272 }
273}
274
275pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
283 let mut rng = StdRng::from_seed([0; 32]);
284
285 let random_point = repeat_with(|| P::random(&mut rng))
286 .take(composition.n_vars())
287 .collect_vec();
288
289 composition
290 .evaluate(&random_point)
291 .expect("Failed to evaluate composition")
292}
293
#[cfg(test)]
mod tests {
    use binius_field::{
        BinaryField1b, PackedBinaryField2x128b, PackedBinaryField4x64b, PackedBinaryField8x32b,
    };
    use binius_math::{ArithExpr, CompositionPoly};

    use crate::{
        composition::ProductComposition,
        polynomial::{composition_hash, ArithCircuitPoly},
    };

    /// The expression `(x0 + x1) * x0 + x0^2`.
    ///
    /// In a characteristic-2 field the two `x0^2` terms cancel, leaving `x0 * x1`; the
    /// "same" tests expect its fingerprint to equal that of `ProductComposition::<2>`.
    fn product_equivalent_expr() -> ArithExpr<BinaryField1b> {
        (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
    }

    /// The expression `x0 + x1`; the "diff" tests expect its fingerprint to differ from
    /// that of `ProductComposition::<2>`.
    fn sum_expr() -> ArithExpr<BinaryField1b> {
        ArithExpr::Var(0) + ArithExpr::Var(1)
    }

    /// Generates the "same fingerprint" / "different fingerprint" test pair for one packed
    /// field type, comparing an `ArithCircuitPoly` against `ProductComposition::<2>`.
    macro_rules! fingerprint_tests {
        ($packed:ty, $same_test:ident, $diff_test:ident) => {
            #[test]
            fn $same_test() {
                let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(product_equivalent_expr())
                    as &dyn CompositionPoly<$packed>;

                let product_composition = ProductComposition::<2> {};

                assert_eq!(
                    composition_hash(&circuit_poly),
                    composition_hash(&product_composition)
                );
            }

            #[test]
            fn $diff_test() {
                let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(sum_expr())
                    as &dyn CompositionPoly<$packed>;

                let product_composition = ProductComposition::<2> {};

                assert_ne!(
                    composition_hash(&circuit_poly),
                    composition_hash(&product_composition)
                );
            }
        };
    }

    fingerprint_tests!(
        PackedBinaryField8x32b,
        test_fingerprint_same_32b,
        test_fingerprint_diff_32b
    );
    fingerprint_tests!(
        PackedBinaryField4x64b,
        test_fingerprint_same_64b,
        test_fingerprint_diff_64b
    );
    fingerprint_tests!(
        PackedBinaryField2x128b,
        test_fingerprint_same_128b,
        test_fingerprint_diff_128b
    );
}