binius_core/polynomial/
multivariate.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use auto_impl::auto_impl;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::{
8	ArithCircuit, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
9};
10use binius_utils::{SerializationError, SerializationMode, bail};
11use bytes::BufMut;
12use itertools::Itertools;
13use rand::{SeedableRng, rngs::StdRng};
14
15use super::error::Error;
16
17/// A multivariate polynomial over a binary tower field.
18///
19/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPoly`], except
20/// that `MultivariatePoly` is _object safe_, whereas `CompositionPoly` is not.
21#[auto_impl(Arc)]
22pub trait MultivariatePoly<P>: Debug + Send + Sync {
23	/// The number of variables.
24	fn n_vars(&self) -> usize;
25
26	/// Total degree of the polynomial.
27	fn degree(&self) -> usize;
28
29	/// Evaluate the polynomial at a point in the extension field.
30	fn evaluate(&self, query: &[P]) -> Result<P, Error>;
31
32	/// Returns the maximum binary tower level of all constants in the arithmetic expression.
33	fn binary_tower_level(&self) -> usize;
34
35	/// Serialize a type erased MultivariatePoly.
36	/// Since not every MultivariatePoly implements serialization, this defaults to returning an
37	/// error.
38	fn erased_serialize(
39		&self,
40		write_buf: &mut dyn BufMut,
41		mode: SerializationMode,
42	) -> Result<(), SerializationError> {
43		let _ = (write_buf, mode);
44		Err(SerializationError::SerializationNotImplemented)
45	}
46}
47
48/// Identity composition function $g(X) = X$.
49#[derive(Clone, Debug)]
50pub struct IdentityCompositionPoly;
51
52impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
53	fn n_vars(&self) -> usize {
54		1
55	}
56
57	fn degree(&self) -> usize {
58		1
59	}
60
61	fn expression(&self) -> ArithCircuit<P::Scalar> {
62		binius_math::ArithCircuit::var(0)
63	}
64
65	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
66		if query.len() != 1 {
67			bail!(binius_math::Error::IncorrectQuerySize {
68				expected: 1,
69				actual: query.len()
70			});
71		}
72		Ok(query[0])
73	}
74
75	fn binary_tower_level(&self) -> usize {
76		0
77	}
78}
79
80impl<F: TowerField> MultivariatePoly<F> for ArithCircuit<F> {
81	fn n_vars(&self) -> usize {
82		self.n_vars()
83	}
84
85	fn degree(&self) -> usize {
86		self.degree()
87	}
88
89	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
90		self.evaluate(query).map_err(Error::from)
91	}
92
93	fn binary_tower_level(&self) -> usize {
94		self.binary_tower_level()
95	}
96}
97
98/// An adapter that constructs a [`CompositionPoly`] for a field from a [`CompositionPoly`] for a
99/// packing of that field.
100///
101/// This is not intended for use in performance-critical code sections.
102#[derive(Debug, Clone)]
103pub struct CompositionScalarAdapter<P, Composition> {
104	composition: Composition,
105	_marker: PhantomData<P>,
106}
107
108impl<P, Composition> CompositionScalarAdapter<P, Composition>
109where
110	P: PackedField,
111	Composition: CompositionPoly<P>,
112{
113	pub const fn new(composition: Composition) -> Self {
114		Self {
115			composition,
116			_marker: PhantomData,
117		}
118	}
119}
120
121impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
122where
123	F: Field,
124	P: PackedField<Scalar = F>,
125	Composition: CompositionPoly<P>,
126{
127	fn n_vars(&self) -> usize {
128		self.composition.n_vars()
129	}
130
131	fn degree(&self) -> usize {
132		self.composition.degree()
133	}
134
135	fn expression(&self) -> ArithCircuit<F> {
136		self.composition.expression()
137	}
138
139	fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
140		let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
141		let packed_result = self.composition.evaluate(&packed_query)?;
142		Ok(packed_result.get(0))
143	}
144
145	fn binary_tower_level(&self) -> usize {
146		self.composition.binary_tower_level()
147	}
148}
149
150/// A polynomial defined as the composition of several multilinear polynomials.
151///
152/// A $\mu$-variate multilinear composite polynomial $p(X_0, ..., X_{\mu})$ is defined as
153///
154/// $$
155/// g(f_0(X_0, ..., X_{\mu}), ..., f_{k-1}(X_0, ..., X_{\mu}))
156/// $$
157///
158/// where $g(Y_0, ..., Y_{k-1})$ is a $k$-variate polynomial and $f_0, ..., f_k$ are all multilinear
159/// in $\mu$ variables.
160///
161/// The `BM` type parameter is necessary so that we can handle the case of a `MultilinearComposite`
162/// that contains boxed trait objects, as well as the case where it directly holds some
163/// implementation of `MultilinearPoly`.
164#[derive(Debug, Clone)]
165pub struct MultilinearComposite<P, C, M>
166where
167	P: PackedField,
168	M: MultilinearPoly<P>,
169{
170	pub composition: C,
171	n_vars: usize,
172	// The multilinear polynomials. The length of the vector matches `composition.n_vars()`.
173	pub multilinears: Vec<M>,
174	pub _marker: PhantomData<P>,
175}
176
177impl<P, C, M> MultilinearComposite<P, C, M>
178where
179	P: PackedField,
180	C: CompositionPoly<P>,
181	M: MultilinearPoly<P>,
182{
183	pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
184		if composition.n_vars() != multilinears.len() {
185			let err_str = format!(
186				"Number of variables in composition {} does not match length of multilinears {}",
187				composition.n_vars(),
188				multilinears.len()
189			);
190			bail!(Error::MultilinearCompositeValidation(err_str));
191		}
192		for multilin in multilinears.iter().map(Borrow::borrow) {
193			if multilin.n_vars() != n_vars {
194				let err_str = format!(
195					"Number of variables in multilinear {} does not match n_vars {}",
196					multilin.n_vars(),
197					n_vars
198				);
199				bail!(Error::MultilinearCompositeValidation(err_str));
200			}
201		}
202		Ok(Self {
203			n_vars,
204			composition,
205			multilinears,
206			_marker: PhantomData,
207		})
208	}
209
210	pub fn evaluate<'a>(
211		&self,
212		query: impl Into<MultilinearQueryRef<'a, P>>,
213	) -> Result<P::Scalar, Error> {
214		let query = query.into();
215		let evals = self
216			.multilinears
217			.iter()
218			.map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
219			.collect::<Result<Vec<_>, _>>()?;
220		Ok(self.composition.evaluate(&evals)?.get(0))
221	}
222
223	pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
224		let evals = self
225			.multilinears
226			.iter()
227			.map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
228			.collect::<Result<Vec<_>, _>>()?;
229
230		Ok(self.composition.evaluate(&evals)?.get(0))
231	}
232
233	pub fn max_individual_degree(&self) -> usize {
234		// Maximum individual degree of the multilinear composite equals composition degree
235		self.composition.degree()
236	}
237
238	pub fn n_multilinears(&self) -> usize {
239		self.composition.n_vars()
240	}
241}
242
243impl<P, C, M> MultilinearComposite<P, C, M>
244where
245	P: PackedField,
246	C: CompositionPoly<P> + 'static,
247	M: MultilinearPoly<P>,
248{
249	pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
250		MultilinearComposite {
251			n_vars: self.n_vars,
252			composition: Arc::new(self.composition),
253			multilinears: self.multilinears,
254			_marker: PhantomData,
255		}
256	}
257}
258
259impl<P, C, M> MultilinearComposite<P, C, M>
260where
261	P: PackedField,
262	M: MultilinearPoly<P>,
263{
264	pub const fn n_vars(&self) -> usize {
265		self.n_vars
266	}
267}
268
269impl<P, C, M> MultilinearComposite<P, C, M>
270where
271	P: PackedField,
272	C: Clone,
273	M: MultilinearPoly<P>,
274{
275	pub fn evaluate_partial_low(
276		&self,
277		query: MultilinearQueryRef<P>,
278	) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
279		let new_multilinears = self
280			.multilinears
281			.iter()
282			.map(|multilin| {
283				multilin
284					.evaluate_partial_low(query)
285					.map(MLEDirectAdapter::from)
286			})
287			.collect::<Result<Vec<_>, _>>()?;
288		Ok(MultilinearComposite {
289			composition: self.composition.clone(),
290			n_vars: self.n_vars - query.n_vars(),
291			multilinears: new_multilinears,
292			_marker: PhantomData,
293		})
294	}
295}
296
297/// Fingerprinting for composition polynomials done by evaluation at a deterministic random point.
298/// Outputs f(r_0,...,r_n-1) where f is a composite and the r_i are the components of the random
299/// point.
300///
301/// Probabilistic collision resistance comes from Schwartz-Zippel on the equation f(x_0,...,x_n-1) =
302/// g(x_0,...,x_n-1) for two distinct multivariate polynomials f and g.
303///
304/// NOTE: THIS IS NOT ADVERSARIALLY COLLISION RESISTANT, COLLISIONS CAN BE MANUFACTURED EASILY
305pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
306	let mut rng = StdRng::from_seed([0; 32]);
307
308	let random_point = repeat_with(|| P::random(&mut rng))
309		.take(composition.n_vars())
310		.collect_vec();
311
312	composition
313		.evaluate(&random_point)
314		.expect("Failed to evaluate composition")
315}
316
#[cfg(test)]
mod tests {
	use binius_fast_compute::arith_circuit::ArithCircuitPoly;
	use binius_field::{
		BinaryField1b, PackedBinaryField2x128b, PackedBinaryField4x64b, PackedBinaryField8x32b,
	};
	use binius_math::{ArithExpr, CompositionPoly};

	use crate::{composition::ProductComposition, polynomial::composition_hash};

	/// Generates two fingerprint tests for the packed field type `$packed`.
	///
	/// - `$same`: `(x0 + x1) * x0 + x0^2` reduces to `x0 * x1` in characteristic 2, so it must
	///   hash the same as the two-variable product composition.
	/// - `$diff`: `x0 + x1` must hash differently from the product composition.
	macro_rules! fingerprint_tests {
		($same:ident, $diff:ident, $packed:ty) => {
			#[test]
			fn $same() {
				// Complicated circuit for (x0 + x1) * x0 + x0^2
				let expr = (ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0)
					+ ArithExpr::Var(0).pow(2);
				let circuit = ArithCircuitPoly::<BinaryField1b>::new(expr.into());
				let circuit_poly = &circuit as &dyn CompositionPoly<$packed>;

				let circuit_hash: $packed = composition_hash(&circuit_poly);
				let product_hash: $packed = composition_hash(&ProductComposition::<2> {});
				assert_eq!(circuit_hash, product_hash);
			}

			#[test]
			fn $diff() {
				let expr = ArithExpr::Var(0) + ArithExpr::Var(1);
				let circuit = ArithCircuitPoly::<BinaryField1b>::new(expr.into());
				let circuit_poly = &circuit as &dyn CompositionPoly<$packed>;

				let circuit_hash: $packed = composition_hash(&circuit_poly);
				let product_hash: $packed = composition_hash(&ProductComposition::<2> {});
				assert_ne!(circuit_hash, product_hash);
			}
		};
	}

	fingerprint_tests!(
		test_fingerprint_same_32b,
		test_fingerprint_diff_32b,
		PackedBinaryField8x32b
	);

	fingerprint_tests!(
		test_fingerprint_same_64b,
		test_fingerprint_diff_64b,
		PackedBinaryField4x64b
	);

	fingerprint_tests!(
		test_fingerprint_same_128b,
		test_fingerprint_diff_128b,
		PackedBinaryField2x128b
	);
}