binius_core/polynomial/
multivariate.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{borrow::Borrow, fmt::Debug, iter::repeat_with, marker::PhantomData, sync::Arc};
4
5use auto_impl::auto_impl;
6use binius_field::{Field, PackedField, TowerField};
7use binius_math::{
8	ArithCircuit, CompositionPoly, MLEDirectAdapter, MultilinearPoly, MultilinearQueryRef,
9};
10use binius_utils::{bail, SerializationError, SerializationMode};
11use bytes::BufMut;
12use itertools::Itertools;
13use rand::{rngs::StdRng, SeedableRng};
14
15use super::error::Error;
16
17/// A multivariate polynomial over a binary tower field.
18///
19/// The definition `MultivariatePoly` is nearly identical to that of [`CompositionPoly`], except that
20/// `MultivariatePoly` is _object safe_, whereas `CompositionPoly` is not.
21#[auto_impl(Arc)]
22pub trait MultivariatePoly<P>: Debug + Send + Sync {
23	/// The number of variables.
24	fn n_vars(&self) -> usize;
25
26	/// Total degree of the polynomial.
27	fn degree(&self) -> usize;
28
29	/// Evaluate the polynomial at a point in the extension field.
30	fn evaluate(&self, query: &[P]) -> Result<P, Error>;
31
32	/// Returns the maximum binary tower level of all constants in the arithmetic expression.
33	fn binary_tower_level(&self) -> usize;
34
35	/// Serialize a type erased MultivariatePoly.
36	/// Since not every MultivariatePoly implements serialization, this defaults to returning an error.
37	fn erased_serialize(
38		&self,
39		write_buf: &mut dyn BufMut,
40		mode: SerializationMode,
41	) -> Result<(), SerializationError> {
42		let _ = (write_buf, mode);
43		Err(SerializationError::SerializationNotImplemented)
44	}
45}
46
47/// Identity composition function $g(X) = X$.
48#[derive(Clone, Debug)]
49pub struct IdentityCompositionPoly;
50
51impl<P: PackedField> CompositionPoly<P> for IdentityCompositionPoly {
52	fn n_vars(&self) -> usize {
53		1
54	}
55
56	fn degree(&self) -> usize {
57		1
58	}
59
60	fn expression(&self) -> ArithCircuit<P::Scalar> {
61		binius_math::ArithCircuit::var(0)
62	}
63
64	fn evaluate(&self, query: &[P]) -> Result<P, binius_math::Error> {
65		if query.len() != 1 {
66			bail!(binius_math::Error::IncorrectQuerySize { expected: 1 });
67		}
68		Ok(query[0])
69	}
70
71	fn binary_tower_level(&self) -> usize {
72		0
73	}
74}
75
76impl<F: TowerField> MultivariatePoly<F> for ArithCircuit<F> {
77	fn n_vars(&self) -> usize {
78		self.n_vars()
79	}
80
81	fn degree(&self) -> usize {
82		self.degree()
83	}
84
85	fn evaluate(&self, query: &[F]) -> Result<F, Error> {
86		self.evaluate(query).map_err(Error::from)
87	}
88
89	fn binary_tower_level(&self) -> usize {
90		self.binary_tower_level()
91	}
92}
93
94/// An adapter that constructs a [`CompositionPoly`] for a field from a [`CompositionPoly`] for a
95/// packing of that field.
96///
97/// This is not intended for use in performance-critical code sections.
98#[derive(Debug, Clone)]
99pub struct CompositionScalarAdapter<P, Composition> {
100	composition: Composition,
101	_marker: PhantomData<P>,
102}
103
104impl<P, Composition> CompositionScalarAdapter<P, Composition>
105where
106	P: PackedField,
107	Composition: CompositionPoly<P>,
108{
109	pub const fn new(composition: Composition) -> Self {
110		Self {
111			composition,
112			_marker: PhantomData,
113		}
114	}
115}
116
117impl<F, P, Composition> CompositionPoly<F> for CompositionScalarAdapter<P, Composition>
118where
119	F: Field,
120	P: PackedField<Scalar = F>,
121	Composition: CompositionPoly<P>,
122{
123	fn n_vars(&self) -> usize {
124		self.composition.n_vars()
125	}
126
127	fn degree(&self) -> usize {
128		self.composition.degree()
129	}
130
131	fn expression(&self) -> ArithCircuit<F> {
132		self.composition.expression()
133	}
134
135	fn evaluate(&self, query: &[F]) -> Result<F, binius_math::Error> {
136		let packed_query = query.iter().copied().map(P::set_single).collect::<Vec<_>>();
137		let packed_result = self.composition.evaluate(&packed_query)?;
138		Ok(packed_result.get(0))
139	}
140
141	fn binary_tower_level(&self) -> usize {
142		self.composition.binary_tower_level()
143	}
144}
145
146/// A polynomial defined as the composition of several multilinear polynomials.
147///
148/// A $\mu$-variate multilinear composite polynomial $p(X_0, ..., X_{\mu})$ is defined as
149///
150/// $$
151/// g(f_0(X_0, ..., X_{\mu}), ..., f_{k-1}(X_0, ..., X_{\mu}))
152/// $$
153///
154/// where $g(Y_0, ..., Y_{k-1})$ is a $k$-variate polynomial and $f_0, ..., f_k$ are all multilinear
155/// in $\mu$ variables.
156///
157/// The `BM` type parameter is necessary so that we can handle the case of a `MultilinearComposite`
158/// that contains boxed trait objects, as well as the case where it directly holds some
159/// implementation of `MultilinearPoly`.
160#[derive(Debug, Clone)]
161pub struct MultilinearComposite<P, C, M>
162where
163	P: PackedField,
164	M: MultilinearPoly<P>,
165{
166	pub composition: C,
167	n_vars: usize,
168	// The multilinear polynomials. The length of the vector matches `composition.n_vars()`.
169	pub multilinears: Vec<M>,
170	pub _marker: PhantomData<P>,
171}
172
173impl<P, C, M> MultilinearComposite<P, C, M>
174where
175	P: PackedField,
176	C: CompositionPoly<P>,
177	M: MultilinearPoly<P>,
178{
179	pub fn new(n_vars: usize, composition: C, multilinears: Vec<M>) -> Result<Self, Error> {
180		if composition.n_vars() != multilinears.len() {
181			let err_str = format!(
182				"Number of variables in composition {} does not match length of multilinears {}",
183				composition.n_vars(),
184				multilinears.len()
185			);
186			bail!(Error::MultilinearCompositeValidation(err_str));
187		}
188		for multilin in multilinears.iter().map(Borrow::borrow) {
189			if multilin.n_vars() != n_vars {
190				let err_str = format!(
191					"Number of variables in multilinear {} does not match n_vars {}",
192					multilin.n_vars(),
193					n_vars
194				);
195				bail!(Error::MultilinearCompositeValidation(err_str));
196			}
197		}
198		Ok(Self {
199			n_vars,
200			composition,
201			multilinears,
202			_marker: PhantomData,
203		})
204	}
205
206	pub fn evaluate<'a>(
207		&self,
208		query: impl Into<MultilinearQueryRef<'a, P>>,
209	) -> Result<P::Scalar, Error> {
210		let query = query.into();
211		let evals = self
212			.multilinears
213			.iter()
214			.map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate(query)?)))
215			.collect::<Result<Vec<_>, _>>()?;
216		Ok(self.composition.evaluate(&evals)?.get(0))
217	}
218
219	pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
220		let evals = self
221			.multilinears
222			.iter()
223			.map(|multilin| Ok::<P, Error>(P::set_single(multilin.evaluate_on_hypercube(index)?)))
224			.collect::<Result<Vec<_>, _>>()?;
225
226		Ok(self.composition.evaluate(&evals)?.get(0))
227	}
228
229	pub fn max_individual_degree(&self) -> usize {
230		// Maximum individual degree of the multilinear composite equals composition degree
231		self.composition.degree()
232	}
233
234	pub fn n_multilinears(&self) -> usize {
235		self.composition.n_vars()
236	}
237}
238
239impl<P, C, M> MultilinearComposite<P, C, M>
240where
241	P: PackedField,
242	C: CompositionPoly<P> + 'static,
243	M: MultilinearPoly<P>,
244{
245	pub fn to_arc_dyn_composition(self) -> MultilinearComposite<P, Arc<dyn CompositionPoly<P>>, M> {
246		MultilinearComposite {
247			n_vars: self.n_vars,
248			composition: Arc::new(self.composition),
249			multilinears: self.multilinears,
250			_marker: PhantomData,
251		}
252	}
253}
254
255impl<P, C, M> MultilinearComposite<P, C, M>
256where
257	P: PackedField,
258	M: MultilinearPoly<P>,
259{
260	pub const fn n_vars(&self) -> usize {
261		self.n_vars
262	}
263}
264
265impl<P, C, M> MultilinearComposite<P, C, M>
266where
267	P: PackedField,
268	C: Clone,
269	M: MultilinearPoly<P>,
270{
271	pub fn evaluate_partial_low(
272		&self,
273		query: MultilinearQueryRef<P>,
274	) -> Result<MultilinearComposite<P, C, impl MultilinearPoly<P>>, Error> {
275		let new_multilinears = self
276			.multilinears
277			.iter()
278			.map(|multilin| {
279				multilin
280					.evaluate_partial_low(query)
281					.map(MLEDirectAdapter::from)
282			})
283			.collect::<Result<Vec<_>, _>>()?;
284		Ok(MultilinearComposite {
285			composition: self.composition.clone(),
286			n_vars: self.n_vars - query.n_vars(),
287			multilinears: new_multilinears,
288			_marker: PhantomData,
289		})
290	}
291}
292
293/// Fingerprinting for composition polynomials done by evaluation at a deterministic random point.
294/// Outputs f(r_0,...,r_n-1) where f is a composite and the r_i are the components of the random point.
295///
296/// Probabilistic collision resistance comes from Schwartz-Zippel on the equation f(x_0,...,x_n-1) = g(x_0,...,x_n-1)
297/// for two distinct multivariate polynomials f and g.
298///
299/// NOTE: THIS IS NOT ADVERSARIALLY COLLISION RESISTANT, COLLISIONS CAN BE MANUFACTURED EASILY
300pub fn composition_hash<P: PackedField, C: CompositionPoly<P>>(composition: &C) -> P {
301	let mut rng = StdRng::from_seed([0; 32]);
302
303	let random_point = repeat_with(|| P::random(&mut rng))
304		.take(composition.n_vars())
305		.collect_vec();
306
307	composition
308		.evaluate(&random_point)
309		.expect("Failed to evaluate composition")
310}
311
#[cfg(test)]
mod tests {
	use binius_field::{
		BinaryField1b, PackedBinaryField2x128b, PackedBinaryField4x64b, PackedBinaryField8x32b,
	};
	use binius_math::{ArithExpr, CompositionPoly};

	use crate::{
		composition::ProductComposition,
		polynomial::{composition_hash, ArithCircuitPoly},
	};

	/// Complicated circuit for (x0 + x1) * x0 + x0^2.
	///
	/// In characteristic 2 the two x0^2 terms cancel, leaving x0 * x1, i.e. the same polynomial
	/// as `ProductComposition::<2>`.
	macro_rules! disguised_product_expr {
		() => {
			(ArithExpr::Var(0) + ArithExpr::Var(1)) * ArithExpr::Var(0) + ArithExpr::Var(0).pow(2)
		};
	}

	/// The circuit x0 + x1, which differs from the product x0 * x1.
	macro_rules! sum_expr {
		() => {
			ArithExpr::Var(0) + ArithExpr::Var(1)
		};
	}

	/// Generates a `#[test]` named `$test_name` that builds an `ArithCircuitPoly` over
	/// `BinaryField1b` from `$expr`, views it as a composition over the packed field `$packed`,
	/// and compares its fingerprint with that of `ProductComposition::<2>` using `$assertion`
	/// (`assert_eq` or `assert_ne`).
	macro_rules! fingerprint_test {
		($test_name:ident, $packed:ty, $assertion:ident, $expr:expr) => {
			#[test]
			fn $test_name() {
				let expr = $expr;
				let circuit_poly = &ArithCircuitPoly::<BinaryField1b>::new(expr.into())
					as &dyn CompositionPoly<$packed>;

				let product_composition = ProductComposition::<2> {};

				$assertion!(
					composition_hash(&circuit_poly),
					composition_hash(&product_composition)
				);
			}
		};
	}

	fingerprint_test!(
		test_fingerprint_same_32b,
		PackedBinaryField8x32b,
		assert_eq,
		disguised_product_expr!()
	);
	fingerprint_test!(test_fingerprint_diff_32b, PackedBinaryField8x32b, assert_ne, sum_expr!());

	fingerprint_test!(
		test_fingerprint_same_64b,
		PackedBinaryField4x64b,
		assert_eq,
		disguised_product_expr!()
	);
	fingerprint_test!(test_fingerprint_diff_64b, PackedBinaryField4x64b, assert_ne, sum_expr!());

	fingerprint_test!(
		test_fingerprint_same_128b,
		PackedBinaryField2x128b,
		assert_eq,
		disguised_product_expr!()
	);
	fingerprint_test!(test_fingerprint_diff_128b, PackedBinaryField2x128b, assert_ne, sum_expr!());
}