binius_math/
multilinear_extension.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6	as_packed_field::{AsSinglePacked, PackScalar, PackedType},
7	underlier::UnderlierType,
8	util::inner_product_par,
9	ExtensionField, Field, PackedField,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{
16	fold::fold_left, fold_middle, fold_right, zero_pad, Error, MultilinearQueryRef, PackingDeref,
17};
18
19/// A multilinear polynomial represented by its evaluations over the boolean hypercube.
20///
21/// This polynomial can also be viewed as the multilinear extension of the slice of hypercube
22/// evaluations. The evaluation data may be either a borrowed or owned slice.
23///
24/// The packed field width must be a power of two.
25#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
27	// The number of variables
28	mu: usize,
29	// The evaluations of the polynomial over the boolean hypercube, in lexicographic order
30	evals: Data,
31}
32
33impl<P: PackedField> MultilinearExtension<P> {
34	pub fn zeros(n_vars: usize) -> Result<Self, Error> {
35		if n_vars < P::LOG_WIDTH {
36			bail!(Error::ArgumentRangeError {
37				arg: "n_vars".into(),
38				range: P::LOG_WIDTH..32,
39			});
40		}
41		Ok(Self {
42			mu: n_vars,
43			evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
44		})
45	}
46
47	pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
48		Self::from_values_generic(v)
49	}
50
51	pub fn into_evals(self) -> Vec<P> {
52		self.evals
53	}
54}
55
56impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
57	pub fn from_values_generic(v: Data) -> Result<Self, Error> {
58		if !v.len().is_power_of_two() {
59			bail!(Error::PowerOfTwoLengthRequired);
60		}
61		let mu = log2(v.len()) + P::LOG_WIDTH;
62
63		Ok(Self { mu, evals: v })
64	}
65
66	pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
67		if !v.len().is_power_of_two() {
68			bail!(Error::PowerOfTwoLengthRequired);
69		}
70
71		if n_vars < P::LOG_WIDTH {
72			if v.len() != 1 {
73				bail!(Error::IncorrectNumberOfVariables {
74					expected: n_vars,
75					actual: P::LOG_WIDTH + log2(v.len())
76				});
77			}
78		} else if P::LOG_WIDTH + log2(v.len()) != n_vars {
79			bail!(Error::IncorrectNumberOfVariables {
80				expected: n_vars,
81				actual: P::LOG_WIDTH + log2(v.len())
82			});
83		}
84
85		Ok(Self {
86			mu: n_vars,
87			evals: v,
88		})
89	}
90}
91
92impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
93where
94	// TODO: Add U: Divisible<u8>.
95	U: UnderlierType + PackScalar<F>,
96	F: Field,
97	Data: Deref<Target = [U]>,
98{
99	pub fn from_underliers(v: Data) -> Result<Self, Error> {
100		Self::from_values_generic(PackingDeref::new(v))
101	}
102}
103
104impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
105	pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
106		if !v.len().is_power_of_two() {
107			bail!(Error::PowerOfTwoLengthRequired);
108		}
109		let mu = log2(v.len() * P::WIDTH);
110		Ok(Self { mu, evals: v })
111	}
112}
113
114impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
115	pub const fn n_vars(&self) -> usize {
116		self.mu
117	}
118
119	pub const fn size(&self) -> usize {
120		1 << self.mu
121	}
122
123	pub fn evals(&self) -> &[P] {
124		&self.evals
125	}
126
127	pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
128		MultilinearExtension {
129			mu: self.mu,
130			evals: self.evals(),
131		}
132	}
133
134	/// Get the evaluations of the polynomial on a subcube of the hypercube of size equal to the
135	/// packing width.
136	///
137	/// # Arguments
138	///
139	/// * `index` - The index of the subcube
140	pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
141		self.evals()
142			.get(index)
143			.ok_or(Error::HypercubeIndexOutOfRange { index })
144			.copied()
145	}
146
147	pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
148		if self.size() <= index {
149			bail!(Error::HypercubeIndexOutOfRange { index })
150		}
151
152		let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
153		Ok(subcube_eval.get(index % P::WIDTH))
154	}
155}
156
157impl<P, Data> MultilinearExtension<P, Data>
158where
159	P: PackedField,
160	Data: Deref<Target = [P]> + Send + Sync,
161{
162	pub fn evaluate<'a, FE, PE>(
163		&self,
164		query: impl Into<MultilinearQueryRef<'a, PE>>,
165	) -> Result<FE, Error>
166	where
167		FE: ExtensionField<P::Scalar>,
168		PE: PackedField<Scalar = FE>,
169	{
170		let query = query.into();
171		if self.mu != query.n_vars() {
172			bail!(Error::IncorrectQuerySize { expected: self.mu });
173		}
174
175		if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
176			let evals = PackedField::iter_slice(self.evals())
177				.take(self.size())
178				.collect::<Vec<P::Scalar>>();
179			let querys = PackedField::iter_slice(query.expansion())
180				.take(1 << query.n_vars())
181				.collect::<Vec<PE::Scalar>>();
182			Ok(inner_product_par(&querys, &evals))
183		} else {
184			Ok(inner_product_par(query.expansion(), &self.evals))
185		}
186	}
187
188	#[instrument("MultilinearExtension::evaluate_partial", skip_all, level = "debug")]
189	pub fn evaluate_partial<'a, PE>(
190		&self,
191		query: impl Into<MultilinearQueryRef<'a, PE>>,
192		start_index: usize,
193	) -> Result<MultilinearExtension<PE>, Error>
194	where
195		PE: PackedField,
196		PE::Scalar: ExtensionField<P::Scalar>,
197	{
198		let query = query.into();
199		if start_index + query.n_vars() > self.mu {
200			bail!(Error::IncorrectStartIndex { expected: self.mu })
201		}
202
203		if start_index == 0 {
204			return self.evaluate_partial_low(query);
205		} else if start_index + query.n_vars() == self.mu {
206			return self.evaluate_partial_high(query);
207		}
208
209		if self.mu < query.n_vars() {
210			bail!(Error::IncorrectQuerySize { expected: self.mu });
211		}
212
213		let new_n_vars = self.mu - query.n_vars();
214		let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
215		let mut result_evals = Vec::with_capacity(result_evals_len);
216
217		fold_middle(
218			self.evals(),
219			self.mu,
220			query.expansion(),
221			query.n_vars(),
222			start_index,
223			result_evals.spare_capacity_mut(),
224		)?;
225		unsafe {
226			result_evals.set_len(result_evals_len);
227		}
228
229		MultilinearExtension::new(new_n_vars, result_evals)
230	}
231
232	/// Partially evaluate the polynomial with assignment to the high-indexed variables.
233	///
234	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu - 1})$. Given a query
235	/// vector of length $k$ representing $(z_{\mu - k + 1}, ..., z_{\mu - 1})$, this returns the
236	/// multilinear polynomial with $\mu - k$ variables,
237	/// $p(X_0, ..., X_{\mu - k}, z_{\mu - k + 1}, ..., z_{\mu - 1})$.
238	///
239	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
240	/// PE::WIDTH, i.e. 2^(\mu - k) \geq PE::WIDTH, since WIDTH is power of two
241	#[instrument(
242		"MultilinearExtension::evaluate_partial_high",
243		skip_all,
244		level = "debug"
245	)]
246	pub fn evaluate_partial_high<'a, PE>(
247		&self,
248		query: impl Into<MultilinearQueryRef<'a, PE>>,
249	) -> Result<MultilinearExtension<PE>, Error>
250	where
251		PE: PackedField,
252		PE::Scalar: ExtensionField<P::Scalar>,
253	{
254		let query = query.into();
255
256		let new_n_vars = self.mu.saturating_sub(query.n_vars());
257		let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
258		let mut result_evals = Vec::with_capacity(result_evals_len);
259
260		fold_left(
261			self.evals(),
262			self.mu,
263			query.expansion(),
264			query.n_vars(),
265			result_evals.spare_capacity_mut(),
266		)?;
267		unsafe {
268			result_evals.set_len(result_evals_len);
269		}
270
271		MultilinearExtension::new(new_n_vars, result_evals)
272	}
273
274	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
275	///
276	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
277	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
278	/// multilinear polynomial with $\mu - k$ variables,
279	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
280	///
281	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
282	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
283	#[instrument(
284		"MultilinearExtension::evaluate_partial_low",
285		skip_all,
286		level = "trace"
287	)]
288	pub fn evaluate_partial_low<'a, PE>(
289		&self,
290		query: impl Into<MultilinearQueryRef<'a, PE>>,
291	) -> Result<MultilinearExtension<PE>, Error>
292	where
293		PE: PackedField,
294		PE::Scalar: ExtensionField<P::Scalar>,
295	{
296		let query = query.into();
297
298		if self.mu < query.n_vars() {
299			bail!(Error::IncorrectQuerySize { expected: self.mu });
300		}
301
302		let new_n_vars = self.mu - query.n_vars();
303
304		let mut result =
305			zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
306		self.evaluate_partial_low_into(query, &mut result)?;
307		MultilinearExtension::new(new_n_vars, result)
308	}
309
310	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
311	///
312	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
313	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
314	/// multilinear polynomial with $\mu - k$ variables,
315	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
316	///
317	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
318	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
319	pub fn evaluate_partial_low_into<PE>(
320		&self,
321		query: MultilinearQueryRef<PE>,
322		out: &mut [PE],
323	) -> Result<(), Error>
324	where
325		PE: PackedField,
326		PE::Scalar: ExtensionField<P::Scalar>,
327	{
328		// This operation is a matrix-vector product of the matrix of multilinear coefficients with
329		// the vector of tensor product-expanded query coefficients.
330		fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
331	}
332
333	pub fn zero_pad<PE>(
334		&self,
335		n_pad_vars: usize,
336		start_index: usize,
337		nonzero_index: usize,
338	) -> Result<MultilinearExtension<PE>, Error>
339	where
340		PE: PackedField,
341		PE::Scalar: ExtensionField<P::Scalar>,
342	{
343		let init_n_vars = self.mu;
344		if start_index > init_n_vars {
345			bail!(Error::IncorrectStartIndexZeroPad { expected: self.mu })
346		}
347		let new_n_vars = init_n_vars + n_pad_vars;
348		if nonzero_index >= 1 << n_pad_vars {
349			bail!(Error::IncorrectNonZeroIndex {
350				expected: 1 << n_pad_vars,
351			});
352		}
353
354		let mut result = zeroed_vec(1 << new_n_vars);
355
356		zero_pad(&self.evals, init_n_vars, new_n_vars, start_index, nonzero_index, &mut result)?;
357		MultilinearExtension::new(new_n_vars, result)
358	}
359}
360
361impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
362	/// Convert MultilinearExtension over a scalar to a MultilinearExtension over a packed field with single element.
363	pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
364		let packed_evals = self
365			.evals
366			.iter()
367			.map(|eval| eval.to_single_packed())
368			.collect();
369		MultilinearExtension {
370			mu: self.mu,
371			evals: packed_evals,
372		}
373	}
374}
375
/// Floor of the base-2 logarithm of `v`.
///
/// Every caller in this module first checks that `v` is a power of two, in which case this is
/// the exact logarithm.
///
/// # Panics
///
/// Panics if `v == 0`, which has no logarithm. The previous hand-rolled
/// `63 - leading_zeros` form underflowed here, and silently wrapped in release builds.
const fn log2(v: usize) -> usize {
	v.ilog2() as usize
}
379
/// Type alias for a [`MultilinearExtension`] whose evaluations are borrowed as `&'a [P]`, such
/// as the value returned by [`MultilinearExtension::to_ref`].
pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
382
// Unit tests for construction, full and partial evaluation, and zero-padding of
// `MultilinearExtension`.
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		arch::OptimalUnderlier256b, BinaryField128b, BinaryField16b as F, BinaryField1b,
		BinaryField32b, BinaryField8b, PackedBinaryField16x1b, PackedBinaryField16x8b,
		PackedBinaryField32x1b, PackedBinaryField4x32b, PackedBinaryField8x16b as P,
		PackedBinaryField8x1b,
	};
	use itertools::Itertools;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::{tensor_prod_eq_ind, MultilinearQuery};

	/// Expand the tensor product of the query values.
	///
	/// `query` is a sequence of field elements $z_0, ..., z_{k-1}$. Entry `i` of the result is
	/// `eval_basis(query, i)`.
	///
	/// This naive implementation runs in O(k 2^k) time and allocates the O(2^k) output vector.
	fn expand_query_naive<F: Field>(query: &[F]) -> Result<Vec<F>, Error> {
		let result = (0..1 << query.len())
			.map(|i| eval_basis(query, i))
			.collect();
		Ok(result)
	}

	/// Evaluates the `i`-th Lagrange basis polynomial of the boolean hypercube at the point
	/// `query`. This is the product over `j` of `query[j]` when bit `j` of `i` is set, and of
	/// `1 - query[j]` otherwise.
	fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
		query
			.iter()
			.enumerate()
			.map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
			.product()
	}

	/// Builds a `MultilinearQuery` for the point `p` by expanding its tensor product with
	/// `tensor_prod_eq_ind`, starting from a buffer whose first lane holds the scalar one.
	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		// A single packed element is allocated when the expansion is narrower than `P::WIDTH`.
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	// The packed tensor expansion must agree with the naive basis-by-basis expansion.
	#[test]
	fn test_expand_query_impls_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let result1 = multilinear_query::<P>(&q);
		let result2 = expand_query_naive(&q).unwrap();
		assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
	}

	// `from_values` and `new` with the implied variable count must build equal polynomials.
	#[test]
	fn test_new_from_values_correspondence() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly2 = MultilinearExtension::new(8, evals).unwrap();

		assert_eq!(poly1, poly2)
	}

	// Evaluating at the boolean point whose bits spell `i` must return the `i`-th stored value.
	#[test]
	fn test_evaluate_on_hypercube() {
		let mut values = vec![F::ZERO; 64];
		values
			.iter_mut()
			.enumerate()
			.for_each(|(i, val)| *val = F::new(i as u16));

		let poly = MultilinearExtension::from_values(values).unwrap();
		for i in 0..64 {
			let q = (0..6)
				.map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
				.collect::<Vec<_>>();
			let multilin_query = multilinear_query::<P>(&q);
			let result = poly.evaluate(multilin_query.to_ref()).unwrap();
			assert_eq!(result, F::new(i));
		}
	}

	/// Evaluates `poly` at `q` in stages. For each of `splits[0]`, `splits[1]`, ... (all but the
	/// last entry), it fixes that many of the highest remaining variables with
	/// `evaluate_partial_high`. The remaining `splits[last]` low variables are then fully
	/// evaluated.
	fn evaluate_split<P>(
		poly: MultilinearExtension<P>,
		q: &[P::Scalar],
		splits: &[usize],
	) -> P::Scalar
	where
		P: PackedField + 'static,
	{
		assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

		let mut partial_result = poly;
		let mut index = q.len();
		for split_vars in &splits[0..splits.len() - 1] {
			let split_vars = *split_vars;
			let query = multilinear_query(&q[index - split_vars..index]);
			partial_result = partial_result
				.evaluate_partial_high(query.to_ref())
				.unwrap();
			index -= split_vars;
		}
		let multilin_query = multilinear_query::<P>(&q[..index]);
		partial_result.evaluate(multilin_query.to_ref()).unwrap()
	}

	// Staged high-variable partial evaluation must agree with a single full evaluation.
	#[test]
	fn test_evaluate_split_is_correct() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let multilin_query = multilinear_query::<P>(&q);
		let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
		let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
		assert_eq!(result1, result2);
	}

	/// For every packed element of `values`, collects the `P::WIDTH` scalars in lanes
	/// `P::WIDTH * start_index .. P::WIDTH * (start_index + 1)`, concatenated in order.
	fn get_bits<P, PE>(values: &[PE], start_index: usize) -> Vec<PE::Scalar>
	where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		let new_vals = values
			.iter()
			.flat_map(|v| {
				(P::WIDTH * start_index..P::WIDTH * (start_index + 1))
					.map(|i| v.get(i))
					.collect::<Vec<_>>()
			})
			.collect::<Vec<_>>();

		new_vals
	}

	// Fixing variable 4 to 0 or 1 selects the lower or upper 16 lanes of each 32-lane element.
	#[test]
	fn test_evaluate_middle_32b_to_16b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_lower_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_higher_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 1);

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Get query to project on lower bits.
		let q_low = [<BinaryField1b as PackedField>::zero()];
		// Get query to project on higher bits.
		let q_hi = [<BinaryField1b as PackedField>::one()];

		// Get expanded query to project on lower bits.
		let query_low = multilinear_query::<BinaryField1b>(&q_low);
		// Get expanded query to project on higher bits.
		let query_hi = multilinear_query::<BinaryField1b>(&q_hi);

		// Get lower bits evaluations.
		let evals_low = poly.evaluate_partial(query_low.to_ref(), 4).unwrap();
		let evals_low = evals_low.evals();

		// Get higher bits evaluations.
		let evals_hi = poly.evaluate_partial(query_hi.to_ref(), 4).unwrap();
		let evals_hi = evals_hi.evals();

		assert_eq!(evals_low, expected_lower_bits);
		assert_eq!(evals_hi, expected_higher_bits);
	}

	// Fixing variables 3 and 4 to each boolean pair selects one 8-lane quarter of every
	// 32-lane element.
	#[test]
	fn test_evaluate_middle_32b_to_8b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_first_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_second_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 1);

		let expected_third_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 2);

		let expected_fourth_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 3);

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Get query to project on first quarter.
		let q_first = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::zero(),
		];
		// Get query to project on second quarter.
		let q_second = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::zero(),
		];
		// Get query to project on third quarter.
		let q_third = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::one(),
		];
		// Get query to project on last quarter.
		let q_fourth = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::one(),
		];

		// Get expanded query to project on first quarter.
		let query_first = multilinear_query::<BinaryField1b>(&q_first);
		// Get expanded query to project on second quarter.
		let query_second = multilinear_query::<BinaryField1b>(&q_second);
		// Get expanded query to project on third quarter.
		let query_third = multilinear_query::<BinaryField1b>(&q_third);
		// Get expanded query to project on last quarter.
		let query_fourth = multilinear_query::<BinaryField1b>(&q_fourth);

		// Get first quarter of bits evaluations.
		let evals_first_quarter = poly.evaluate_partial(query_first.to_ref(), 3).unwrap();
		let evals_first_quarter = evals_first_quarter.evals();
		// Get second quarter evaluations.
		let evals_second_quarter = poly.evaluate_partial(query_second.to_ref(), 3).unwrap();
		let evals_second_quarter = evals_second_quarter.evals();
		// Get third quarter evaluations.
		let evals_third_quarter = poly.evaluate_partial(query_third.to_ref(), 3).unwrap();
		let evals_third_quarter = evals_third_quarter.evals();
		// Get last quarter evaluations.
		let evals_fourth_quarter = poly.evaluate_partial(query_fourth.to_ref(), 3).unwrap();
		let evals_fourth_quarter = evals_fourth_quarter.evals();

		assert_eq!(evals_first_quarter, expected_first_quarter);
		assert_eq!(evals_second_quarter, expected_second_quarter);
		assert_eq!(evals_third_quarter, expected_third_quarter);
		assert_eq!(evals_fourth_quarter, expected_fourth_quarter);
	}

	// Zero-padding followed by projecting the padded variables back onto `nonzero_index`
	// must recover the original evaluations.
	#[test]
	fn test_zeropad_8b_to_32b_project() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField8x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();
		let expected_out = PackedBinaryField8x1b::iter_slice(&values).collect::<Vec<_>>();

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Quadruple the number of elements.
		let n_pad_vars = 2;
		// Pad each block of 8 consecutive bits to 32 bits.
		let start_index = 3;
		// Pad to the right.
		let nonzero_index = 0;
		// Pad the polynomial.
		let padded_poly = poly
			.zero_pad::<BinaryField1b>(n_pad_vars, start_index, nonzero_index)
			.unwrap();

		// Now, project to the first quarter of each block of 32 bits.
		let query_first = multilinear_query::<BinaryField1b>(&[
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::zero(),
		]);
		let projected_poly = padded_poly
			.evaluate_partial(query_first.to_ref(), 3)
			.unwrap();

		// Assert that the projection of the padded polynomials is equal to the original.
		assert_eq!(expected_out, projected_poly.evals());
	}

	// Two orderings must agree. The first fixes the low half of the query, then the next block.
	// The second fixes that next block in the middle first (`evaluate_partial` with a non-zero
	// start index), then the low half.
	#[test]
	fn test_evaluate_partial_match_evaluate_partial_low() {
		type P = PackedBinaryField16x8b;
		type F = BinaryField8b;

		let mut rng = StdRng::seed_from_u64(0);

		let n_vars: usize = 7;

		let values = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars.saturating_sub(P::LOG_WIDTH)))
			.collect();

		let query_n_minus_1 = repeat_with(|| <F as PackedField>::random(&mut rng))
			.take(n_vars - 1)
			.collect::<Vec<_>>();

		let (q_low, q_high) = query_n_minus_1.split_at(query_n_minus_1.len() / 2);

		// Get expanded query to project on lower bits.
		let query_low = multilinear_query::<P>(q_low);
		// Get expanded query to project on higher bits.
		let query_high = multilinear_query::<P>(q_high);

		let me = MultilinearExtension::from_values(values).unwrap();

		assert_eq!(
			me.evaluate_partial_low(query_low.to_ref())
				.unwrap()
				.evaluate_partial_low(query_high.to_ref())
				.unwrap(),
			me.evaluate_partial(query_high.to_ref(), q_low.len())
				.unwrap()
				.evaluate_partial_low(query_low.to_ref())
				.unwrap()
		)
	}

	// `evaluate_partial_high` may produce a polynomial narrower than the packing width. It must
	// still evaluate to the same value as the full evaluation.
	#[test]
	fn test_evaluate_partial_high_packed() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(256 >> P::LOG_WIDTH)
			.collect::<Vec<_>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<BinaryField128b>>();
		let multilin_query = multilinear_query::<BinaryField128b>(&q);

		let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

		// The final split has a number of coefficients less than the packing width
		let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
		let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
		assert!(partial_eval.n_vars() < P::LOG_WIDTH);

		let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
		let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
		assert_eq!(eval, expected);
	}

	// A 3-variable polynomial stored in one 16-lane packed element: fixing the top variable and
	// then the rest must match the full evaluation.
	#[test]
	fn test_evaluate_partial_low_high_smaller_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let q = repeat_with(|| Field::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<F>>();

		let query = multilinear_query::<P>(&q);

		let packed = P::from_scalars(values);
		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		let eval = me.evaluate(&query).unwrap();

		let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
		let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

		let eval_l_h = me
			.evaluate_partial_high(&query_high)
			.unwrap()
			.evaluate_partial_low(&query_low)
			.unwrap()
			.evals()[0]
			.get(0);

		assert_eq!(eval, eval_l_h);
	}

	// Hypercube lookups on a polynomial narrower than the packing width must respect `size()`,
	// not the number of packed lanes.
	#[test]
	fn test_evaluate_on_hypercube_small_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let packed = P::from_scalars(values.clone());

		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

		assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
	}

	// Fixing all variables via either `evaluate_partial_low` or `evaluate_partial_high` must
	// match `evaluate`.
	#[test]
	fn test_evaluate_partial_high_low_evaluate_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let me = MultilinearExtension::from_values(values).unwrap();

		let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
			.take(me.n_vars())
			.collect::<Vec<_>>();

		let query = multilinear_query(&q);

		let eval = me
			.evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
				query.to_ref(),
			)
			.unwrap();

		assert_eq!(
			me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
		assert_eq!(
			me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
	}

	// Fixing two low variables one at a time must match fixing them together.
	#[test]
	fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let r1 = <BinaryField32b as PackedField>::random(&mut rng);
		let r2 = <BinaryField32b as PackedField>::random(&mut rng);

		let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
			.evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
			.unwrap()
			.evaluate_partial_low(multilinear_query(&[r2]).to_ref())
			.unwrap();
		let eval_2 = mle
			.evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
			.unwrap();
		assert_eq!(eval_1, eval_2);
	}

	// `new` must accept fewer variables than the packing width when given one packed element.
	#[test]
	fn test_new_mle_with_tiny_nvars() {
		MultilinearExtension::new(
			1,
			vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
		)
		.unwrap();
	}
}