binius_math/
multilinear_extension.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6	ExtensionField, Field, PackedField,
7	as_packed_field::{AsSinglePacked, PackScalar, PackedType},
8	underlier::UnderlierType,
9	util::inner_product_par,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{
16	Error, MultilinearQueryRef, PackingDeref, fold::fold_left, fold_middle, fold_right, zero_pad,
17};
18
/// A multilinear polynomial represented by its evaluations over the boolean hypercube.
///
/// This polynomial can also be viewed as the multilinear extension of the slice of hypercube
/// evaluations. The evaluation data may be a borrowed slice, an owned vector, or any other
/// `Deref<Target = [P]>` container (e.g. `PackingDeref` over raw underliers).
///
/// The packed field width must be a power of two.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
	// The number of variables, $\mu$.
	mu: usize,
	// The evaluations of the polynomial over the boolean hypercube, `P::WIDTH` scalars per packed
	// element. The scalar at index `i` is the evaluation at the point whose `j`-th coordinate is
	// bit `j` of `i`. When `mu < P::LOG_WIDTH`, this holds exactly one packed element and only its
	// first `2^mu` lanes are meaningful.
	evals: Data,
}
32
33impl<P: PackedField> MultilinearExtension<P> {
34	pub fn zeros(n_vars: usize) -> Result<Self, Error> {
35		if n_vars < P::LOG_WIDTH {
36			bail!(Error::ArgumentRangeError {
37				arg: "n_vars".into(),
38				range: P::LOG_WIDTH..32,
39			});
40		}
41		Ok(Self {
42			mu: n_vars,
43			evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
44		})
45	}
46
47	pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
48		Self::from_values_generic(v)
49	}
50
51	pub fn into_evals(self) -> Vec<P> {
52		self.evals
53	}
54}
55
56impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
57	pub fn from_values_generic(v: Data) -> Result<Self, Error> {
58		if !v.len().is_power_of_two() {
59			bail!(Error::PowerOfTwoLengthRequired);
60		}
61		let mu = log2(v.len()) + P::LOG_WIDTH;
62
63		Ok(Self { mu, evals: v })
64	}
65
66	pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
67		if !v.len().is_power_of_two() {
68			bail!(Error::PowerOfTwoLengthRequired);
69		}
70
71		if n_vars < P::LOG_WIDTH {
72			if v.len() != 1 {
73				bail!(Error::IncorrectNumberOfVariables {
74					expected: n_vars,
75					actual: P::LOG_WIDTH + log2(v.len())
76				});
77			}
78		} else if P::LOG_WIDTH + log2(v.len()) != n_vars {
79			bail!(Error::IncorrectNumberOfVariables {
80				expected: n_vars,
81				actual: P::LOG_WIDTH + log2(v.len())
82			});
83		}
84
85		Ok(Self {
86			mu: n_vars,
87			evals: v,
88		})
89	}
90}
91
92impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
93where
94	// TODO: Add U:
95	// Divisible<u8>.
96	U: UnderlierType + PackScalar<F>,
97	F: Field,
98	Data: Deref<Target = [U]>,
99{
100	pub fn from_underliers(v: Data) -> Result<Self, Error> {
101		Self::from_values_generic(PackingDeref::new(v))
102	}
103}
104
105impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
106	pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
107		if !v.len().is_power_of_two() {
108			bail!(Error::PowerOfTwoLengthRequired);
109		}
110		let mu = log2(v.len() * P::WIDTH);
111		Ok(Self { mu, evals: v })
112	}
113}
114
115impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
116	pub const fn n_vars(&self) -> usize {
117		self.mu
118	}
119
120	pub const fn size(&self) -> usize {
121		1 << self.mu
122	}
123
124	pub fn evals(&self) -> &[P] {
125		&self.evals
126	}
127
128	pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
129		MultilinearExtension {
130			mu: self.mu,
131			evals: self.evals(),
132		}
133	}
134
135	/// Get the evaluations of the polynomial on a subcube of the hypercube of size equal to the
136	/// packing width.
137	///
138	/// # Arguments
139	///
140	/// * `index` - The index of the subcube
141	pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
142		self.evals()
143			.get(index)
144			.ok_or(Error::HypercubeIndexOutOfRange { index })
145			.copied()
146	}
147
148	pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
149		if self.size() <= index {
150			bail!(Error::HypercubeIndexOutOfRange { index })
151		}
152
153		let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
154		Ok(subcube_eval.get(index % P::WIDTH))
155	}
156}
157
158impl<P, Data> MultilinearExtension<P, Data>
159where
160	P: PackedField,
161	Data: Deref<Target = [P]> + Send + Sync,
162{
163	pub fn evaluate<'a, FE, PE>(
164		&self,
165		query: impl Into<MultilinearQueryRef<'a, PE>>,
166	) -> Result<FE, Error>
167	where
168		FE: ExtensionField<P::Scalar>,
169		PE: PackedField<Scalar = FE>,
170	{
171		let query = query.into();
172		if self.mu != query.n_vars() {
173			bail!(Error::IncorrectQuerySize {
174				expected: self.mu,
175				actual: query.n_vars()
176			});
177		}
178
179		if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
180			let evals = PackedField::iter_slice(self.evals())
181				.take(self.size())
182				.collect::<Vec<P::Scalar>>();
183			let queries = PackedField::iter_slice(query.expansion())
184				.take(1 << query.n_vars())
185				.collect::<Vec<PE::Scalar>>();
186			Ok(inner_product_par(&queries, &evals))
187		} else {
188			Ok(inner_product_par(query.expansion(), &self.evals))
189		}
190	}
191
192	#[instrument("MultilinearExtension::evaluate_partial", skip_all, level = "debug")]
193	pub fn evaluate_partial<'a, PE>(
194		&self,
195		query: impl Into<MultilinearQueryRef<'a, PE>>,
196		start_index: usize,
197	) -> Result<MultilinearExtension<PE>, Error>
198	where
199		PE: PackedField,
200		PE::Scalar: ExtensionField<P::Scalar>,
201	{
202		let query = query.into();
203		if start_index + query.n_vars() > self.mu {
204			bail!(Error::IncorrectStartIndex { expected: self.mu })
205		}
206
207		if start_index == 0 {
208			return self.evaluate_partial_low(query);
209		} else if start_index + query.n_vars() == self.mu {
210			return self.evaluate_partial_high(query);
211		}
212
213		if self.mu < query.n_vars() {
214			bail!(Error::IncorrectQuerySize {
215				expected: self.mu,
216				actual: query.n_vars()
217			});
218		}
219
220		let new_n_vars = self.mu - query.n_vars();
221		let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
222		let mut result_evals = Vec::with_capacity(result_evals_len);
223
224		fold_middle(
225			self.evals(),
226			self.mu,
227			query.expansion(),
228			query.n_vars(),
229			start_index,
230			result_evals.spare_capacity_mut(),
231		)?;
232		unsafe {
233			result_evals.set_len(result_evals_len);
234		}
235
236		MultilinearExtension::new(new_n_vars, result_evals)
237	}
238
239	/// Partially evaluate the polynomial with assignment to the high-indexed variables.
240	///
241	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu - 1})$. Given a
242	/// query vector of length $k$ representing $(z_{\mu - k + 1}, ..., z_{\mu - 1})$, this returns
243	/// the multilinear polynomial with $\mu - k$ variables,
244	/// $p(X_0, ..., X_{\mu - k}, z_{\mu - k + 1}, ..., z_{\mu - 1})$.
245	///
246	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
247	/// PE::WIDTH, i.e. 2^(\mu - k) \geq PE::WIDTH, since WIDTH is power of two
248	#[instrument(
249		"MultilinearExtension::evaluate_partial_high",
250		skip_all,
251		level = "debug"
252	)]
253	pub fn evaluate_partial_high<'a, PE>(
254		&self,
255		query: impl Into<MultilinearQueryRef<'a, PE>>,
256	) -> Result<MultilinearExtension<PE>, Error>
257	where
258		PE: PackedField,
259		PE::Scalar: ExtensionField<P::Scalar>,
260	{
261		let query = query.into();
262
263		let new_n_vars = self.mu.saturating_sub(query.n_vars());
264		let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
265		let mut result_evals = Vec::with_capacity(result_evals_len);
266
267		fold_left(
268			self.evals(),
269			self.mu,
270			query.expansion(),
271			query.n_vars(),
272			result_evals.spare_capacity_mut(),
273		)?;
274		unsafe {
275			result_evals.set_len(result_evals_len);
276		}
277
278		MultilinearExtension::new(new_n_vars, result_evals)
279	}
280
281	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
282	///
283	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
284	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
285	/// multilinear polynomial with $\mu - k$ variables,
286	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
287	///
288	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
289	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
290	#[instrument(
291		"MultilinearExtension::evaluate_partial_low",
292		skip_all,
293		level = "trace"
294	)]
295	pub fn evaluate_partial_low<'a, PE>(
296		&self,
297		query: impl Into<MultilinearQueryRef<'a, PE>>,
298	) -> Result<MultilinearExtension<PE>, Error>
299	where
300		PE: PackedField,
301		PE::Scalar: ExtensionField<P::Scalar>,
302	{
303		let query = query.into();
304
305		if self.mu < query.n_vars() {
306			bail!(Error::IncorrectQuerySize {
307				expected: self.mu,
308				actual: query.n_vars()
309			});
310		}
311
312		let new_n_vars = self.mu - query.n_vars();
313
314		let mut result =
315			zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
316		self.evaluate_partial_low_into(query, &mut result)?;
317		MultilinearExtension::new(new_n_vars, result)
318	}
319
320	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
321	///
322	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
323	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
324	/// multilinear polynomial with $\mu - k$ variables,
325	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
326	///
327	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
328	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
329	pub fn evaluate_partial_low_into<PE>(
330		&self,
331		query: MultilinearQueryRef<PE>,
332		out: &mut [PE],
333	) -> Result<(), Error>
334	where
335		PE: PackedField,
336		PE::Scalar: ExtensionField<P::Scalar>,
337	{
338		// This operation is a matrix-vector product of the matrix of multilinear coefficients with
339		// the vector of tensor product-expanded query coefficients.
340		fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
341	}
342
343	pub fn zero_pad<PE>(
344		&self,
345		n_pad_vars: usize,
346		start_index: usize,
347		nonzero_index: usize,
348	) -> Result<MultilinearExtension<PE>, Error>
349	where
350		PE: PackedField,
351		PE::Scalar: ExtensionField<P::Scalar>,
352	{
353		let init_n_vars = self.mu;
354		if start_index > init_n_vars {
355			bail!(Error::IncorrectStartIndexZeroPad { expected: self.mu })
356		}
357		let new_n_vars = init_n_vars + n_pad_vars;
358		if nonzero_index >= 1 << n_pad_vars {
359			bail!(Error::IncorrectNonZeroIndex {
360				expected: 1 << n_pad_vars,
361			});
362		}
363
364		let mut result = zeroed_vec(1 << new_n_vars);
365
366		zero_pad(&self.evals, init_n_vars, new_n_vars, start_index, nonzero_index, &mut result)?;
367		MultilinearExtension::new(new_n_vars, result)
368	}
369}
370
371impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
372	/// Convert MultilinearExtension over a scalar to a MultilinearExtension over a packed field
373	/// with single element.
374	pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
375		let packed_evals = self
376			.evals
377			.iter()
378			.map(|eval| eval.to_single_packed())
379			.collect();
380		MultilinearExtension {
381			mu: self.mu,
382			evals: packed_evals,
383		}
384	}
385}
386
/// Base-2 logarithm of `v`, rounded down.
///
/// Callers in this file only pass lengths already checked to be powers of two, for which the
/// result is exact.
///
/// # Panics
///
/// Panics if `v` is zero (rather than wrapping to a huge value in release builds).
const fn log2(v: usize) -> usize {
	v.ilog2() as usize
}
390
/// Type alias for a [`MultilinearExtension`] whose evaluations are borrowed as `&'a [P]`, the
/// shape produced by [`MultilinearExtension::to_ref`] and
/// [`MultilinearExtension::from_values_slice`].
pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
393
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		BinaryField1b, BinaryField8b, BinaryField16b as F, BinaryField32b, BinaryField128b,
		PackedBinaryField4x32b, PackedBinaryField8x1b, PackedBinaryField8x16b as P,
		PackedBinaryField16x1b, PackedBinaryField16x8b, PackedBinaryField32x1b,
		arch::OptimalUnderlier256b,
	};
	use itertools::Itertools;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::{MultilinearQuery, tensor_prod_eq_ind};

	/// Expand the tensor product of the query values.
	///
	/// `query` is a sequence of field elements $z_0, ..., z_{k-1}$.
	///
	/// This naive implementation runs in O(k 2^k) time and uses no auxiliary space beyond the
	/// 2^k-element output vector.
	fn expand_query_naive<F: Field>(query: &[F]) -> Result<Vec<F>, Error> {
		let result = (0..1 << query.len())
			.map(|i| eval_basis(query, i))
			.collect();
		Ok(result)
	}

	/// Evaluates the multilinear Lagrange basis polynomial for hypercube index `i` at `query`:
	/// the product over `j` of `query[j]` if bit `j` of `i` is set, else `1 - query[j]`.
	fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
		query
			.iter()
			.enumerate()
			.map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
			.product()
	}

	/// Build a tensor-expanded `MultilinearQuery` over packed field `P` from the scalar point
	/// `p`, using `tensor_prod_eq_ind` starting from the single-element expansion `[1]`.
	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	// The packed tensor expansion must agree with the naive basis-by-basis expansion.
	#[test]
	fn test_expand_query_impls_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let result1 = multilinear_query::<P>(&q);
		let result2 = expand_query_naive(&q).unwrap();
		assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
	}

	#[test]
	fn test_new_from_values_correspondence() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly2 = MultilinearExtension::new(8, evals).unwrap();

		assert_eq!(poly1, poly2)
	}

	// Evaluating at a hypercube vertex must return the stored value; bit `j` of `i` is the
	// `j`-th coordinate.
	#[test]
	fn test_evaluate_on_hypercube() {
		let mut values = vec![F::ZERO; 64];
		values
			.iter_mut()
			.enumerate()
			.for_each(|(i, val)| *val = F::new(i as u16));

		let poly = MultilinearExtension::from_values(values).unwrap();
		for i in 0..64 {
			let q = (0..6)
				.map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
				.collect::<Vec<_>>();
			let multilin_query = multilinear_query::<P>(&q);
			let result = poly.evaluate(multilin_query.to_ref()).unwrap();
			assert_eq!(result, F::new(i));
		}
	}

	/// Evaluate `poly` at `q` by repeatedly fixing the highest `splits[i]` variables with
	/// `evaluate_partial_high` (all but the last split), then evaluating the remaining
	/// low variables in full. `splits` must sum to `poly.n_vars()`.
	fn evaluate_split<P>(
		poly: MultilinearExtension<P>,
		q: &[P::Scalar],
		splits: &[usize],
	) -> P::Scalar
	where
		P: PackedField + 'static,
	{
		assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

		let mut partial_result = poly;
		let mut index = q.len();
		for split_vars in &splits[0..splits.len() - 1] {
			let split_vars = *split_vars;
			let query = multilinear_query(&q[index - split_vars..index]);
			partial_result = partial_result
				.evaluate_partial_high(query.to_ref())
				.unwrap();
			index -= split_vars;
		}
		let multilin_query = multilinear_query::<P>(&q[..index]);
		partial_result.evaluate(multilin_query.to_ref()).unwrap()
	}

	#[test]
	fn test_evaluate_split_is_correct() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let multilin_query = multilinear_query::<P>(&q);
		let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
		let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
		assert_eq!(result1, result2);
	}

	/// For each packed `PE` in `values`, collect the scalars in lanes
	/// `P::WIDTH * start_index .. P::WIDTH * (start_index + 1)`.
	fn get_bits<P, PE>(values: &[PE], start_index: usize) -> Vec<PE::Scalar>
	where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		values
			.iter()
			.flat_map(|v| {
				(P::WIDTH * start_index..P::WIDTH * (start_index + 1))
					.map(|i| v.get(i))
					.collect::<Vec<_>>()
			})
			.collect::<Vec<_>>()
	}

	#[test]
	fn test_evaluate_middle_32b_to_16b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_lower_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_higher_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 1);

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Get query to project on lower bits.
		let q_low = [<BinaryField1b as PackedField>::zero()];
		// Get query to project on higher bits.
		let q_hi = [<BinaryField1b as PackedField>::one()];

		// Get expanded query to project on lower bits.
		let query_low = multilinear_query::<BinaryField1b>(&q_low);
		// Get expanded query to project on higher bits.
		let query_hi = multilinear_query::<BinaryField1b>(&q_hi);

		// Get lower bits evaluations.
		let evals_low = poly.evaluate_partial(query_low.to_ref(), 4).unwrap();
		let evals_low = evals_low.evals();

		// Get higher bits evaluations.
		let evals_hi = poly.evaluate_partial(query_hi.to_ref(), 4).unwrap();
		let evals_hi = evals_hi.evals();

		assert_eq!(evals_low, expected_lower_bits);
		assert_eq!(evals_hi, expected_higher_bits);
	}

	#[test]
	fn test_evaluate_middle_32b_to_8b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_first_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_second_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 1);

		let expected_third_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 2);

		let expected_fourth_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 3);

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Get query to project on first quarter.
		let q_first = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::zero(),
		];
		// Get query to project on second quarter.
		let q_second = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::zero(),
		];
		// Get query to project on third quarter.
		let q_third = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::one(),
		];
		// Get query to project on last quarter.
		let q_fourth = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::one(),
		];

		// Get expanded query to project on first quarter.
		let query_first = multilinear_query::<BinaryField1b>(&q_first);
		// Get expanded query to project on second quarter.
		let query_second = multilinear_query::<BinaryField1b>(&q_second);
		// Get expanded query to project on third quarter.
		let query_third = multilinear_query::<BinaryField1b>(&q_third);
		// Get expanded query to project on last quarter.
		let query_fourth = multilinear_query::<BinaryField1b>(&q_fourth);

		// Get first quarter of bits evaluations.
		let evals_first_quarter = poly.evaluate_partial(query_first.to_ref(), 3).unwrap();
		let evals_first_quarter = evals_first_quarter.evals();
		// Get second quarter evaluations.
		let evals_second_quarter = poly.evaluate_partial(query_second.to_ref(), 3).unwrap();
		let evals_second_quarter = evals_second_quarter.evals();
		// Get third quarter evaluations.
		let evals_third_quarter = poly.evaluate_partial(query_third.to_ref(), 3).unwrap();
		let evals_third_quarter = evals_third_quarter.evals();
		// Get last quarter evaluations.
		let evals_fourth_quarter = poly.evaluate_partial(query_fourth.to_ref(), 3).unwrap();
		let evals_fourth_quarter = evals_fourth_quarter.evals();

		assert_eq!(evals_first_quarter, expected_first_quarter);
		assert_eq!(evals_second_quarter, expected_second_quarter);
		assert_eq!(evals_third_quarter, expected_third_quarter);
		assert_eq!(evals_fourth_quarter, expected_fourth_quarter);
	}

	// Zero-padding followed by projecting the new variables onto `nonzero_index` must recover
	// the original evaluations.
	#[test]
	fn test_zeropad_8b_to_32b_project() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField8x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();
		let expected_out = PackedBinaryField8x1b::iter_slice(&values).collect::<Vec<_>>();

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Quadruple the number of elements.
		let n_pad_vars = 2;
		// Pad each block of 8 consecutive bits to 32 bits.
		let start_index = 3;
		// Pad to the right.
		let nonzero_index = 0;
		// Pad the polynomial.
		let padded_poly = poly
			.zero_pad::<BinaryField1b>(n_pad_vars, start_index, nonzero_index)
			.unwrap();

		// Now, project to the first quarter of each block of 32 bits.
		let query_first = multilinear_query::<BinaryField1b>(&[
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::zero(),
		]);
		let projected_poly = padded_poly
			.evaluate_partial(query_first.to_ref(), 3)
			.unwrap();

		// Assert that the projection of the padded polynomials is equal to the original.
		assert_eq!(expected_out, projected_poly.evals());
	}

	// Fixing low then high variables must match fixing the middle run then the low variables.
	#[test]
	fn test_evaluate_partial_match_evaluate_partial_low() {
		type P = PackedBinaryField16x8b;
		type F = BinaryField8b;

		let mut rng = StdRng::seed_from_u64(0);

		let n_vars: usize = 7;

		let values = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars.saturating_sub(P::LOG_WIDTH)))
			.collect();

		let query_n_minus_1 = repeat_with(|| <F as PackedField>::random(&mut rng))
			.take(n_vars - 1)
			.collect::<Vec<_>>();

		let (q_low, q_high) = query_n_minus_1.split_at(query_n_minus_1.len() / 2);

		// Get expanded query to project on lower bits.
		let query_low = multilinear_query::<P>(q_low);
		// Get expanded query to project on higher bits.
		let query_high = multilinear_query::<P>(q_high);

		let me = MultilinearExtension::from_values(values).unwrap();

		assert_eq!(
			me.evaluate_partial_low(query_low.to_ref())
				.unwrap()
				.evaluate_partial_low(query_high.to_ref())
				.unwrap(),
			me.evaluate_partial(query_high.to_ref(), q_low.len())
				.unwrap()
				.evaluate_partial_low(query_low.to_ref())
				.unwrap()
		)
	}

	#[test]
	fn test_evaluate_partial_high_packed() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(256 >> P::LOG_WIDTH)
			.collect::<Vec<_>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<BinaryField128b>>();
		let multilin_query = multilinear_query::<BinaryField128b>(&q);

		let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

		// The final split has a number of coefficients less than the packing width
		let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
		let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
		assert!(partial_eval.n_vars() < P::LOG_WIDTH);

		let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
		let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
		assert_eq!(eval, expected);
	}

	// Partial evaluations must work when the polynomial is smaller than one packed element.
	#[test]
	fn test_evaluate_partial_low_high_smaller_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let q = repeat_with(|| Field::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<F>>();

		let query = multilinear_query::<P>(&q);

		let packed = P::from_scalars(values);
		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		let eval = me.evaluate(&query).unwrap();

		let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
		let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

		let eval_l_h = me
			.evaluate_partial_high(&query_high)
			.unwrap()
			.evaluate_partial_low(&query_low)
			.unwrap()
			.evals()[0]
			.get(0);

		assert_eq!(eval, eval_l_h);
	}

	// Hypercube lookups must respect `size()` rather than the packed lane count.
	#[test]
	fn test_evaluate_on_hypercube_small_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let packed = P::from_scalars(values.clone());

		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

		assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
	}

	// Fully fixing all variables via either partial method must equal a full evaluation.
	#[test]
	fn test_evaluate_partial_high_low_evaluate_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let me = MultilinearExtension::from_values(values).unwrap();

		let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
			.take(me.n_vars())
			.collect::<Vec<_>>();

		let query = multilinear_query(&q);

		let eval = me
			.evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
				query.to_ref(),
			)
			.unwrap();

		assert_eq!(
			me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
		assert_eq!(
			me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
	}

	// Two single-variable low folds must equal one two-variable low fold.
	#[test]
	fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let r1 = <BinaryField32b as PackedField>::random(&mut rng);
		let r2 = <BinaryField32b as PackedField>::random(&mut rng);

		let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
			.evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
			.unwrap()
			.evaluate_partial_low(multilinear_query(&[r2]).to_ref())
			.unwrap();
		let eval_2 = mle
			.evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
			.unwrap();
		assert_eq!(eval_1, eval_2);
	}

	// `new` must accept fewer variables than the packing width when given one packed element.
	#[test]
	fn test_new_mle_with_tiny_nvars() {
		MultilinearExtension::new(
			1,
			vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
		)
		.unwrap();
	}
}