binius_core/polynomial/multivariate.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use auto_impl::auto_impl;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::{
8	ArithCircuit, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
9};
10use binius_utils::{SerializationError, SerializationMode, SerializeBytes as _, bail};
11use bytes::BufMut;
12use itertools::Itertools;
13use rand::{SeedableRng, rngs::StdRng};
14
15use super::error::Error;
16
17/// A multivariate polynomial over a binary tower field.
18///
19/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPoly`], except
20/// that `MultivariatePoly` is _object safe_, whereas `CompositionPoly` is not.
21#[auto_impl(Arc)]
22pub trait MultivariatePoly<P>: Debug + Send + Sync {
23	/// The number of variables.
24	fn n_vars(&self) -> usize;
25
26	/// Total degree of the polynomial.
27	fn degree(&self) -> usize;
28
29	/// Evaluate the polynomial at a point in the extension field.
30	fn evaluate(&self, query: &[P]) -> Result<P, Error>;
31
32	/// Returns the maximum binary tower level of all constants in the arithmetic expression.
33	fn binary_tower_level(&self) -> usize;
34
35	/// Serialize a type erased MultivariatePoly.
36	/// Since not every MultivariatePoly implements serialization, this defaults to returning an
37	/// error.
38	fn erased_serialize(
39		&self,
40		write_buf: &mut dyn BufMut,
41		mode: SerializationMode,
42	) -> Result<(), SerializationError> {
43		let _ = (write_buf, mode);
44		Err(SerializationError::SerializationNotImplemented)
45	}
46}
47
48/// Identity composition function $g(X) = X$.
49#[derive(Clone, Debug)]
50pub struct IdentityCompositionPoly;
51
52impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
53	fn n_vars(&self) -> usize {
54		1
55	}
56
57	fn degree(&self) -> usize {
58		1
59	}
60
61	fn expression(&self) -> ArithCircuit<P::Scalar> {
62		binius_math::ArithCircuit::var(0)
63	}
64
65	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
66		if query.len() != 1 {
67			bail!(binius_math::Error::IncorrectQuerySize {
68				expected: 1,
69				actual: query.len()
70			});
71		}
72		Ok(query[0])
73	}
74
75	fn binary_tower_level(&self) -> usize {
76		0
77	}
78}
79
80impl<F: TowerField> MultivariatePoly<F> for ArithCircuit<F> {
81	fn n_vars(&self) -> usize {
82		self.n_vars()
83	}
84
85	fn degree(&self) -> usize {
86		self.degree()
87	}
88
89	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
90		self.evaluate(query).map_err(Error::from)
91	}
92
93	fn binary_tower_level(&self) -> usize {
94		self.binary_tower_level()
95	}
96
97	fn erased_serialize(
98		&self,
99		write_buf: &mut dyn BufMut,
100		mode: SerializationMode,
101	) -> Result<(), SerializationError> {
102		self.serialize(write_buf, mode)
103	}
104}
105
106/// An adapter that constructs a [`CompositionPoly`] for a field from a [`CompositionPoly`] for a
107/// packing of that field.
108///
109/// This is not intended for use in performance-critical code sections.
110#[derive(Debug, Clone)]
111pub struct CompositionScalarAdapter<P, Composition> {
112	composition: Composition,
113	_marker: PhantomData<P>,
114}
115
116impl<P, Composition> CompositionScalarAdapter<P, Composition>
117where
118	P: PackedField,
119	Composition: CompositionPoly<P>,
120{
121	pub const fn new(composition: Composition) -> Self {
122		Self {
123			composition,
124			_marker: PhantomData,
125		}
126	}
127}
128
129impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
130where
131	F: Field,
132	P: PackedField<Scalar = F>,
133	Composition: CompositionPoly<P>,
134{
135	fn n_vars(&self) -> usize {
136		self.composition.n_vars()
137	}
138
139	fn degree(&self) -> usize {
140		self.composition.degree()
141	}
142
143	fn expression(&self) -> ArithCircuit<F> {
144		self.composition.expression()
145	}
146
147	fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
148		let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
149		let packed_result = self.composition.evaluate(&packed_query)?;
150		Ok(packed_result.get(0))
151	}
152
153	fn binary_tower_level(&self) -> usize {
154		self.composition.binary_tower_level()
155	}
156}
157
158/// A polynomial defined as the composition of several multilinear polynomials.
159///
160/// A $\mu$-variate multilinear composite polynomial $p(X_0, ..., X_{\mu})$ is defined as
161///
162/// $$
163/// g(f_0(X_0, ..., X_{\mu}), ..., f_{k-1}(X_0, ..., X_{\mu}))
164/// $$
165///
166/// where $g(Y_0, ..., Y_{k-1})$ is a $k$-variate polynomial and $f_0, ..., f_k$ are all multilinear
167/// in $\mu$ variables.
168///
169/// The `BM` type parameter is necessary so that we can handle the case of a `MultilinearComposite`
170/// that contains boxed trait objects, as well as the case where it directly holds some
171/// implementation of `MultilinearPoly`.
172#[derive(Debug, Clone)]
173pub struct MultilinearComposite<P, C, M>
174where
175	P: PackedField,
176	M: MultilinearPoly<P>,
177{
178	pub composition: C,
179	n_vars: usize,
180	// The multilinear polynomials. The length of the vector matches `composition.n_vars()`.
181	pub multilinears: Vec<M>,
182	pub _marker: PhantomData<P>,
183}
184
185impl<P, C, M> MultilinearComposite<P, C, M>
186where
187	P: PackedField,
188	C: CompositionPoly<P>,
189	M: MultilinearPoly<P>,
190{
191	pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
192		if composition.n_vars() != multilinears.len() {
193			let err_str = format!(
194				"Number of variables in composition {} does not match length of multilinears {}",
195				composition.n_vars(),
196				multilinears.len()
197			);
198			bail!(Error::MultilinearCompositeValidation(err_str));
199		}
200		for multilin in multilinears.iter().map(Borrow::borrow) {
201			if multilin.n_vars() != n_vars {
202				let err_str = format!(
203					"Number of variables in multilinear {} does not match n_vars {}",
204					multilin.n_vars(),
205					n_vars
206				);
207				bail!(Error::MultilinearCompositeValidation(err_str));
208			}
209		}
210		Ok(Self {
211			n_vars,
212			composition,
213			multilinears,
214			_marker: PhantomData,
215		})
216	}
217
218	pub fn evaluate<'a>(
219		&self,
220		query: impl Into<MultilinearQueryRef<'a, P>>,
221	) -> Result<P::Scalar, Error> {
222		let query = query.into();
223		let evals = self
224			.multilinears
225			.iter()
226			.map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
227			.collect::<Result<Vec<_>, _>>()?;
228		Ok(self.composition.evaluate(&evals)?.get(0))
229	}
230
231	pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
232		let evals = self
233			.multilinears
234			.iter()
235			.map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
236			.collect::<Result<Vec<_>, _>>()?;
237
238		Ok(self.composition.evaluate(&evals)?.get(0))
239	}
240
241	pub fn max_individual_degree(&self) -> usize {
242		// Maximum individual degree of the multilinear composite equals composition degree
243		self.composition.degree()
244	}
245
246	pub fn n_multilinears(&self) -> usize {
247		self.composition.n_vars()
248	}
249}
250
251impl<P, C, M> MultilinearComposite<P, C, M>
252where
253	P: PackedField,
254	C: CompositionPoly<P> + 'static,
255	M: MultilinearPoly<P>,
256{
257	pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
258		MultilinearComposite {
259			n_vars: self.n_vars,
260			composition: Arc::new(self.composition),
261			multilinears: self.multilinears,
262			_marker: PhantomData,
263		}
264	}
265}
266
267impl<P, C, M> MultilinearComposite<P, C, M>
268where
269	P: PackedField,
270	M: MultilinearPoly<P>,
271{
272	pub const fn n_vars(&self) -> usize {
273		self.n_vars
274	}
275}
276
277impl<P, C, M> MultilinearComposite<P, C, M>
278where
279	P: PackedField,
280	C: Clone,
281	M: MultilinearPoly<P>,
282{
283	pub fn evaluate_partial_low(
284		&self,
285		query: MultilinearQueryRef<P>,
286	) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
287		let new_multilinears = self
288			.multilinears
289			.iter()
290			.map(|multilin| {
291				multilin
292					.evaluate_partial_low(query)
293					.map(MLEDirectAdapter::from)
294			})
295			.collect::<Result<Vec<_>, _>>()?;
296		Ok(MultilinearComposite {
297			composition: self.composition.clone(),
298			n_vars: self.n_vars - query.n_vars(),
299			multilinears: new_multilinears,
300			_marker: PhantomData,
301		})
302	}
303}
304
305/// Fingerprinting for composition polynomials done by evaluation at a deterministic random point.
306/// Outputs f(r_0,...,r_n-1) where f is a composite and the r_i are the components of the random
307/// point.
308///
309/// Probabilistic collision resistance comes from Schwartz-Zippel on the equation f(x_0,...,x_n-1) =
310/// g(x_0,...,x_n-1) for two distinct multivariate polynomials f and g.
311///
312/// NOTE: THIS IS NOT ADVERSARIALLY COLLISION RESISTANT, COLLISIONS CAN BE MANUFACTURED EASILY
313pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
314	let mut rng = StdRng::from_seed([0; 32]);
315
316	let random_point = repeat_with(|| P::random(&mut rng))
317		.take(composition.n_vars())
318		.collect_vec();
319
320	composition
321		.evaluate(&random_point)
322		.expect("Failed to evaluate composition")
323}
324
#[cfg(test)]
mod tests {
	use binius_fast_compute::arith_circuit::ArithCircuitPoly;
	use binius_field::{
		BinaryField1b, PackedBinaryField2x128b, PackedBinaryField4x64b, PackedBinaryField8x32b,
	};
	use binius_math::{ArithExpr, CompositionPoly};

	/// Generates a pair of fingerprint tests for the packed field type `$packed`.
	///
	/// - `$same`: a convoluted circuit equal to $x_0 x_1$ must hash like `ProductComposition<2>`.
	/// - `$diff`: $x_0 + x_1$ must hash differently from `ProductComposition<2>`.
	macro_rules! fingerprint_tests {
		($same:ident, $diff:ident, $packed:ty) => {
			#[test]
			fn $same() {
				// (x0 + x1) * x0 + x0^2 expands to x0 * x1, because 2 * x0^2 = 0 in
				// characteristic 2.
				let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0)
					+ ArithExpr::Var(0).pow(2);
				let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
					as &dyn CompositionPoly<$packed>;

				let product_composition = crate::composition::ProductComposition::<2> {};

				assert_eq!(
					crate::polynomial::composition_hash(&circuit_poly),
					crate::polynomial::composition_hash(&product_composition)
				);
			}

			#[test]
			fn $diff() {
				let expr = ArithExpr::Var(0) + ArithExpr::Var(1);
				let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
					as &dyn CompositionPoly<$packed>;

				let product_composition = crate::composition::ProductComposition::<2> {};

				assert_ne!(
					crate::polynomial::composition_hash(&circuit_poly),
					crate::polynomial::composition_hash(&product_composition)
				);
			}
		};
	}

	fingerprint_tests!(
		test_fingerprint_same_32b,
		test_fingerprint_diff_32b,
		PackedBinaryField8x32b
	);
	fingerprint_tests!(
		test_fingerprint_same_64b,
		test_fingerprint_diff_64b,
		PackedBinaryField4x64b
	);
	fingerprint_tests!(
		test_fingerprint_same_128b,
		test_fingerprint_diff_128b,
		PackedBinaryField2x128b
	);
}