binius_math/
multilinear_extension.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6	as_packed_field::{AsSinglePacked, PackScalar, PackedType},
7	underlier::UnderlierType,
8	util::inner_product_par,
9	ExtensionField, Field, PackedField,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{fold::fold_left, fold_right, Error, MultilinearQueryRef, PackingDeref};
16
17/// A multilinear polynomial represented by its evaluations over the boolean hypercube.
18///
19/// This polynomial can also be viewed as the multilinear extension of the slice of hypercube
20/// evaluations. The evaluation data may be either a borrowed or owned slice.
21///
22/// The packed field width must be a power of two.
23#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
25	// The number of variables
26	mu: usize,
27	// The evaluations of the polynomial over the boolean hypercube, in lexicographic order
28	evals: Data,
29}
30
31impl<P: PackedField> MultilinearExtension<P> {
32	pub fn zeros(n_vars: usize) -> Result<Self, Error> {
33		if n_vars < P::LOG_WIDTH {
34			bail!(Error::ArgumentRangeError {
35				arg: "n_vars".into(),
36				range: P::LOG_WIDTH..32,
37			});
38		}
39		Ok(Self {
40			mu: n_vars,
41			evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
42		})
43	}
44
45	pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
46		Self::from_values_generic(v)
47	}
48
49	pub fn into_evals(self) -> Vec<P> {
50		self.evals
51	}
52}
53
54impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
55	pub fn from_values_generic(v: Data) -> Result<Self, Error> {
56		if !v.len().is_power_of_two() {
57			bail!(Error::PowerOfTwoLengthRequired);
58		}
59		let mu = log2(v.len()) + P::LOG_WIDTH;
60
61		Ok(Self { mu, evals: v })
62	}
63
64	pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
65		if !v.len().is_power_of_two() {
66			bail!(Error::PowerOfTwoLengthRequired);
67		}
68
69		if n_vars < P::LOG_WIDTH {
70			if v.len() != 1 {
71				bail!(Error::IncorrectNumberOfVariables {
72					expected: n_vars,
73					actual: P::LOG_WIDTH + log2(v.len())
74				});
75			}
76		} else if P::LOG_WIDTH + log2(v.len()) != n_vars {
77			bail!(Error::IncorrectNumberOfVariables {
78				expected: n_vars,
79				actual: P::LOG_WIDTH + log2(v.len())
80			});
81		}
82
83		Ok(Self {
84			mu: n_vars,
85			evals: v,
86		})
87	}
88}
89
90impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
91where
92	// TODO: Add U: Divisible<u8>.
93	U: UnderlierType + PackScalar<F>,
94	F: Field,
95	Data: Deref<Target = [U]>,
96{
97	pub fn from_underliers(v: Data) -> Result<Self, Error> {
98		Self::from_values_generic(PackingDeref::new(v))
99	}
100}
101
102impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
103	pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
104		if !v.len().is_power_of_two() {
105			bail!(Error::PowerOfTwoLengthRequired);
106		}
107		let mu = log2(v.len() * P::WIDTH);
108		Ok(Self { mu, evals: v })
109	}
110}
111
112impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
113	pub const fn n_vars(&self) -> usize {
114		self.mu
115	}
116
117	pub const fn size(&self) -> usize {
118		1 << self.mu
119	}
120
121	pub fn evals(&self) -> &[P] {
122		&self.evals
123	}
124
125	pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
126		MultilinearExtension {
127			mu: self.mu,
128			evals: self.evals(),
129		}
130	}
131
132	/// Get the evaluations of the polynomial on a subcube of the hypercube of size equal to the
133	/// packing width.
134	///
135	/// # Arguments
136	///
137	/// * `index` - The index of the subcube
138	pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
139		self.evals()
140			.get(index)
141			.ok_or(Error::HypercubeIndexOutOfRange { index })
142			.copied()
143	}
144
145	pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
146		if self.size() <= index {
147			bail!(Error::HypercubeIndexOutOfRange { index })
148		}
149
150		let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
151		Ok(subcube_eval.get(index % P::WIDTH))
152	}
153}
154
155impl<P, Data> MultilinearExtension<P, Data>
156where
157	P: PackedField,
158	Data: Deref<Target = [P]> + Send + Sync,
159{
160	pub fn evaluate<'a, FE, PE>(
161		&self,
162		query: impl Into<MultilinearQueryRef<'a, PE>>,
163	) -> Result<FE, Error>
164	where
165		FE: ExtensionField<P::Scalar>,
166		PE: PackedField<Scalar = FE>,
167	{
168		let query = query.into();
169		if self.mu != query.n_vars() {
170			bail!(Error::IncorrectQuerySize { expected: self.mu });
171		}
172
173		if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
174			let evals = PackedField::iter_slice(self.evals())
175				.take(self.size())
176				.collect::<Vec<P::Scalar>>();
177			let querys = PackedField::iter_slice(query.expansion())
178				.take(1 << query.n_vars())
179				.collect::<Vec<PE::Scalar>>();
180			Ok(inner_product_par(&querys, &evals))
181		} else {
182			Ok(inner_product_par(query.expansion(), &self.evals))
183		}
184	}
185
186	/// Partially evaluate the polynomial with assignment to the high-indexed variables.
187	///
188	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu - 1})$. Given a query
189	/// vector of length $k$ representing $(z_{\mu - k + 1}, ..., z_{\mu - 1})$, this returns the
190	/// multilinear polynomial with $\mu - k$ variables,
191	/// $p(X_0, ..., X_{\mu - k}, z_{\mu - k + 1}, ..., z_{\mu - 1})$.
192	///
193	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
194	/// PE::WIDTH, i.e. 2^(\mu - k) \geq PE::WIDTH, since WIDTH is power of two
195	#[instrument(
196		"MultilinearExtension::evaluate_partial_high",
197		skip_all,
198		level = "debug"
199	)]
200	pub fn evaluate_partial_high<'a, PE>(
201		&self,
202		query: impl Into<MultilinearQueryRef<'a, PE>>,
203	) -> Result<MultilinearExtension<PE>, Error>
204	where
205		PE: PackedField,
206		PE::Scalar: ExtensionField<P::Scalar>,
207	{
208		let query = query.into();
209
210		let new_n_vars = self.mu.saturating_sub(query.n_vars());
211		let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
212		let mut result_evals = Vec::with_capacity(result_evals_len);
213
214		fold_left(
215			self.evals(),
216			self.mu,
217			query.expansion(),
218			query.n_vars(),
219			result_evals.spare_capacity_mut(),
220		)?;
221		unsafe {
222			result_evals.set_len(result_evals_len);
223		}
224
225		MultilinearExtension::new(new_n_vars, result_evals)
226	}
227
228	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
229	///
230	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
231	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
232	/// multilinear polynomial with $\mu - k$ variables,
233	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
234	///
235	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
236	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
237	#[instrument(
238		"MultilinearExtension::evaluate_partial_low",
239		skip_all,
240		level = "trace"
241	)]
242	pub fn evaluate_partial_low<'a, PE>(
243		&self,
244		query: impl Into<MultilinearQueryRef<'a, PE>>,
245	) -> Result<MultilinearExtension<PE>, Error>
246	where
247		PE: PackedField,
248		PE::Scalar: ExtensionField<P::Scalar>,
249	{
250		let query = query.into();
251
252		if self.mu < query.n_vars() {
253			bail!(Error::IncorrectQuerySize { expected: self.mu });
254		}
255
256		let new_n_vars = self.mu - query.n_vars();
257
258		let mut result =
259			zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
260		self.evaluate_partial_low_into(query, &mut result)?;
261		MultilinearExtension::new(new_n_vars, result)
262	}
263
264	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
265	///
266	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
267	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
268	/// multilinear polynomial with $\mu - k$ variables,
269	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
270	///
271	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
272	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
273	pub fn evaluate_partial_low_into<PE>(
274		&self,
275		query: MultilinearQueryRef<PE>,
276		out: &mut [PE],
277	) -> Result<(), Error>
278	where
279		PE: PackedField,
280		PE::Scalar: ExtensionField<P::Scalar>,
281	{
282		// This operation is a matrix-vector product of the matrix of multilinear coefficients with
283		// the vector of tensor product-expanded query coefficients.
284		fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
285	}
286}
287
288impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
289	/// Convert MultilinearExtension over a scalar to a MultilinearExtension over a packed field with single element.
290	pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
291		let packed_evals = self
292			.evals
293			.iter()
294			.map(|eval| eval.to_single_packed())
295			.collect();
296		MultilinearExtension {
297			mu: self.mu,
298			evals: packed_evals,
299		}
300	}
301}
302
/// Returns $\lfloor \log_2 v \rfloor$.
///
/// Callers in this module only pass non-zero powers of two, for which this is the exact
/// base-2 logarithm.
///
/// # Panics
///
/// Panics if `v == 0` in every build profile. The previous `63 - leading_zeros` formulation
/// silently wrapped to `usize::MAX` in release builds.
const fn log2(v: usize) -> usize {
	v.ilog2() as usize
}
306
307/// Type alias for the common pattern of a [`MultilinearExtension`] backed by borrowed data.
308pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
309
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		arch::OptimalUnderlier256b, BinaryField128b, BinaryField16b as F, BinaryField32b,
		BinaryField8b, PackedBinaryField16x8b, PackedBinaryField4x32b, PackedBinaryField8x16b as P,
	};
	use itertools::Itertools;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::{tensor_prod_eq_ind, MultilinearQuery};

	/// Expands the tensor product of the query values naively.
	///
	/// `query` is a sequence of field elements $z_0, ..., z_{k-1}$. Entry `i` of the result is
	/// the Lagrange basis polynomial of hypercube vertex `i` evaluated at the query point.
	///
	/// This naive implementation runs in $O(k 2^k)$ time and allocates $O(2^k)$ space for the
	/// output.
	fn expand_query_naive<F: Field>(query: &[F]) -> Vec<F> {
		(0..1 << query.len())
			.map(|i| eval_basis(query, i))
			.collect()
	}

	/// Evaluates the Lagrange basis polynomial over the boolean hypercube at a queried point.
	fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
		query
			.iter()
			.enumerate()
			.map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
			.product()
	}

	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	#[test]
	fn test_expand_query_impls_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let result1 = multilinear_query::<P>(&q);
		let result2 = expand_query_naive(&q);
		assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
	}

	#[test]
	fn test_new_from_values_correspondence() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly2 = MultilinearExtension::new(8, evals).unwrap();

		assert_eq!(poly1, poly2)
	}

	#[test]
	fn test_zeros() {
		let poly = MultilinearExtension::<P>::zeros(5).unwrap();
		assert_eq!(poly.n_vars(), 5);
		assert_eq!(poly.evals(), vec![P::default(); 1 << (5 - P::LOG_WIDTH)].as_slice());

		// Fewer variables than the packing width is rejected.
		assert!(MultilinearExtension::<P>::zeros(P::LOG_WIDTH - 1).is_err());
	}

	#[test]
	fn test_from_values_slice_matches_from_values() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(1 << 4)
			.collect::<Vec<_>>();
		let owned = MultilinearExtension::from_values(evals.clone()).unwrap();
		let borrowed = MultilinearExtension::<P, &[P]>::from_values_slice(&evals).unwrap();

		assert_eq!(borrowed.n_vars(), 4 + P::LOG_WIDTH);
		assert_eq!(owned.to_ref(), borrowed);
	}

	#[test]
	fn test_evaluate_on_hypercube() {
		let mut values = vec![F::ZERO; 64];
		values
			.iter_mut()
			.enumerate()
			.for_each(|(i, val)| *val = F::new(i as u16));

		let poly = MultilinearExtension::from_values(values).unwrap();
		for i in 0..64 {
			let q = (0..6)
				.map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
				.collect::<Vec<_>>();
			let multilin_query = multilinear_query::<P>(&q);
			let result = poly.evaluate(multilin_query.to_ref()).unwrap();
			assert_eq!(result, F::new(i));
		}
	}

	fn evaluate_split<P>(
		poly: MultilinearExtension<P>,
		q: &[P::Scalar],
		splits: &[usize],
	) -> P::Scalar
	where
		P: PackedField + 'static,
	{
		assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

		let mut partial_result = poly;
		let mut index = q.len();
		for split_vars in &splits[0..splits.len() - 1] {
			let split_vars = *split_vars;
			let query = multilinear_query(&q[index - split_vars..index]);
			partial_result = partial_result
				.evaluate_partial_high(query.to_ref())
				.unwrap();
			index -= split_vars;
		}
		let multilin_query = multilinear_query::<P>(&q[..index]);
		partial_result.evaluate(multilin_query.to_ref()).unwrap()
	}

	#[test]
	fn test_evaluate_split_is_correct() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let multilin_query = multilinear_query::<P>(&q);
		let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
		let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
		assert_eq!(result1, result2);
	}

	#[test]
	fn test_evaluate_partial_high_packed() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(256 >> P::LOG_WIDTH)
			.collect::<Vec<_>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<BinaryField128b>>();
		let multilin_query = multilinear_query::<BinaryField128b>(&q);

		let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

		// The final split has a number of coefficients less than the packing width
		let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
		let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
		assert!(partial_eval.n_vars() < P::LOG_WIDTH);

		let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
		let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
		assert_eq!(eval, expected);
	}

	#[test]
	fn test_evaluate_partial_low_high_smaller_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let q = repeat_with(|| Field::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<F>>();

		let query = multilinear_query::<P>(&q);

		let packed = P::from_scalars(values);
		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		let eval = me.evaluate(&query).unwrap();

		let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
		let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

		let eval_l_h = me
			.evaluate_partial_high(&query_high)
			.unwrap()
			.evaluate_partial_low(&query_low)
			.unwrap()
			.evals()[0]
			.get(0);

		assert_eq!(eval, eval_l_h);
	}

	#[test]
	fn test_evaluate_on_hypercube_smaller_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let packed = P::from_scalars(values.clone());

		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

		assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
	}

	#[test]
	fn test_evaluate_partial_high_low_evaluate_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let me = MultilinearExtension::from_values(values).unwrap();

		let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
			.take(me.n_vars())
			.collect::<Vec<_>>();

		let query = multilinear_query(&q);

		let eval = me
			.evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
				query.to_ref(),
			)
			.unwrap();

		assert_eq!(
			me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
		assert_eq!(
			me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
	}

	#[test]
	fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let r1 = <BinaryField32b as PackedField>::random(&mut rng);
		let r2 = <BinaryField32b as PackedField>::random(&mut rng);

		let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
			.evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
			.unwrap()
			.evaluate_partial_low(multilinear_query(&[r2]).to_ref())
			.unwrap();
		let eval_2 = mle
			.evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
			.unwrap();
		assert_eq!(eval_1, eval_2);
	}

	#[test]
	fn test_new_mle_with_tiny_nvars() {
		MultilinearExtension::new(
			1,
			vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
		)
		.unwrap();
	}
}