binius_math/
multilinear_extension.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6	as_packed_field::{AsSinglePacked, PackScalar, PackedType},
7	underlier::UnderlierType,
8	util::inner_product_par,
9	ExtensionField, Field, PackedField,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{fold::fold_left, fold_middle, fold_right, Error, MultilinearQueryRef, PackingDeref};
16
/// A multilinear polynomial represented by its evaluations over the boolean hypercube.
///
/// This polynomial can also be viewed as the multilinear extension of the slice of hypercube
/// evaluations. The evaluation data may be either a borrowed or owned slice.
///
/// The packed field width must be a power of two.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
	/// The number of variables, $\mu$.
	mu: usize,
	/// The evaluations of the polynomial over the boolean hypercube, packed into `P` elements.
	///
	/// Scalar `i` of the flattened slice is the evaluation at the hypercube point whose
	/// variable `X_j` equals bit `j` of `i` (so `X_0` is the least significant bit).
	/// When `mu < P::LOG_WIDTH` the data is a single packed element, of which only the first
	/// `2^mu` scalars are meaningful; the remaining lanes are padding.
	evals: Data,
}
30
31impl<P: PackedField> MultilinearExtension<P> {
32	pub fn zeros(n_vars: usize) -> Result<Self, Error> {
33		if n_vars < P::LOG_WIDTH {
34			bail!(Error::ArgumentRangeError {
35				arg: "n_vars".into(),
36				range: P::LOG_WIDTH..32,
37			});
38		}
39		Ok(Self {
40			mu: n_vars,
41			evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
42		})
43	}
44
45	pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
46		Self::from_values_generic(v)
47	}
48
49	pub fn into_evals(self) -> Vec<P> {
50		self.evals
51	}
52}
53
54impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
55	pub fn from_values_generic(v: Data) -> Result<Self, Error> {
56		if !v.len().is_power_of_two() {
57			bail!(Error::PowerOfTwoLengthRequired);
58		}
59		let mu = log2(v.len()) + P::LOG_WIDTH;
60
61		Ok(Self { mu, evals: v })
62	}
63
64	pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
65		if !v.len().is_power_of_two() {
66			bail!(Error::PowerOfTwoLengthRequired);
67		}
68
69		if n_vars < P::LOG_WIDTH {
70			if v.len() != 1 {
71				bail!(Error::IncorrectNumberOfVariables {
72					expected: n_vars,
73					actual: P::LOG_WIDTH + log2(v.len())
74				});
75			}
76		} else if P::LOG_WIDTH + log2(v.len()) != n_vars {
77			bail!(Error::IncorrectNumberOfVariables {
78				expected: n_vars,
79				actual: P::LOG_WIDTH + log2(v.len())
80			});
81		}
82
83		Ok(Self {
84			mu: n_vars,
85			evals: v,
86		})
87	}
88}
89
90impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
91where
92	// TODO: Add U: Divisible<u8>.
93	U: UnderlierType + PackScalar<F>,
94	F: Field,
95	Data: Deref<Target = [U]>,
96{
97	pub fn from_underliers(v: Data) -> Result<Self, Error> {
98		Self::from_values_generic(PackingDeref::new(v))
99	}
100}
101
102impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
103	pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
104		if !v.len().is_power_of_two() {
105			bail!(Error::PowerOfTwoLengthRequired);
106		}
107		let mu = log2(v.len() * P::WIDTH);
108		Ok(Self { mu, evals: v })
109	}
110}
111
112impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
113	pub const fn n_vars(&self) -> usize {
114		self.mu
115	}
116
117	pub const fn size(&self) -> usize {
118		1 << self.mu
119	}
120
121	pub fn evals(&self) -> &[P] {
122		&self.evals
123	}
124
125	pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
126		MultilinearExtension {
127			mu: self.mu,
128			evals: self.evals(),
129		}
130	}
131
132	/// Get the evaluations of the polynomial on a subcube of the hypercube of size equal to the
133	/// packing width.
134	///
135	/// # Arguments
136	///
137	/// * `index` - The index of the subcube
138	pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
139		self.evals()
140			.get(index)
141			.ok_or(Error::HypercubeIndexOutOfRange { index })
142			.copied()
143	}
144
145	pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
146		if self.size() <= index {
147			bail!(Error::HypercubeIndexOutOfRange { index })
148		}
149
150		let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
151		Ok(subcube_eval.get(index % P::WIDTH))
152	}
153}
154
155impl<P, Data> MultilinearExtension<P, Data>
156where
157	P: PackedField,
158	Data: Deref<Target = [P]> + Send + Sync,
159{
160	pub fn evaluate<'a, FE, PE>(
161		&self,
162		query: impl Into<MultilinearQueryRef<'a, PE>>,
163	) -> Result<FE, Error>
164	where
165		FE: ExtensionField<P::Scalar>,
166		PE: PackedField<Scalar = FE>,
167	{
168		let query = query.into();
169		if self.mu != query.n_vars() {
170			bail!(Error::IncorrectQuerySize { expected: self.mu });
171		}
172
173		if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
174			let evals = PackedField::iter_slice(self.evals())
175				.take(self.size())
176				.collect::<Vec<P::Scalar>>();
177			let querys = PackedField::iter_slice(query.expansion())
178				.take(1 << query.n_vars())
179				.collect::<Vec<PE::Scalar>>();
180			Ok(inner_product_par(&querys, &evals))
181		} else {
182			Ok(inner_product_par(query.expansion(), &self.evals))
183		}
184	}
185
186	#[instrument("MultilinearExtension::evaluate_partial", skip_all, level = "debug")]
187	pub fn evaluate_partial<'a, PE>(
188		&self,
189		query: impl Into<MultilinearQueryRef<'a, PE>>,
190		start_index: usize,
191	) -> Result<MultilinearExtension<PE>, Error>
192	where
193		PE: PackedField,
194		PE::Scalar: ExtensionField<P::Scalar>,
195	{
196		let query = query.into();
197		if start_index + query.n_vars() > self.mu {
198			bail!(Error::IncorrectStartIndex { expected: self.mu })
199		}
200
201		if start_index == 0 {
202			return self.evaluate_partial_low(query);
203		} else if start_index + query.n_vars() == self.mu {
204			return self.evaluate_partial_high(query);
205		}
206
207		if self.mu < query.n_vars() {
208			bail!(Error::IncorrectQuerySize { expected: self.mu });
209		}
210
211		let new_n_vars = self.mu - query.n_vars();
212		let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
213		let mut result_evals = Vec::with_capacity(result_evals_len);
214
215		fold_middle(
216			self.evals(),
217			self.mu,
218			query.expansion(),
219			query.n_vars(),
220			start_index,
221			result_evals.spare_capacity_mut(),
222		)?;
223		unsafe {
224			result_evals.set_len(result_evals_len);
225		}
226
227		MultilinearExtension::new(new_n_vars, result_evals)
228	}
229
230	/// Partially evaluate the polynomial with assignment to the high-indexed variables.
231	///
232	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu - 1})$. Given a query
233	/// vector of length $k$ representing $(z_{\mu - k + 1}, ..., z_{\mu - 1})$, this returns the
234	/// multilinear polynomial with $\mu - k$ variables,
235	/// $p(X_0, ..., X_{\mu - k}, z_{\mu - k + 1}, ..., z_{\mu - 1})$.
236	///
237	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
238	/// PE::WIDTH, i.e. 2^(\mu - k) \geq PE::WIDTH, since WIDTH is power of two
239	#[instrument(
240		"MultilinearExtension::evaluate_partial_high",
241		skip_all,
242		level = "debug"
243	)]
244	pub fn evaluate_partial_high<'a, PE>(
245		&self,
246		query: impl Into<MultilinearQueryRef<'a, PE>>,
247	) -> Result<MultilinearExtension<PE>, Error>
248	where
249		PE: PackedField,
250		PE::Scalar: ExtensionField<P::Scalar>,
251	{
252		let query = query.into();
253
254		let new_n_vars = self.mu.saturating_sub(query.n_vars());
255		let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
256		let mut result_evals = Vec::with_capacity(result_evals_len);
257
258		fold_left(
259			self.evals(),
260			self.mu,
261			query.expansion(),
262			query.n_vars(),
263			result_evals.spare_capacity_mut(),
264		)?;
265		unsafe {
266			result_evals.set_len(result_evals_len);
267		}
268
269		MultilinearExtension::new(new_n_vars, result_evals)
270	}
271
272	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
273	///
274	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
275	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
276	/// multilinear polynomial with $\mu - k$ variables,
277	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
278	///
279	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
280	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
281	#[instrument(
282		"MultilinearExtension::evaluate_partial_low",
283		skip_all,
284		level = "trace"
285	)]
286	pub fn evaluate_partial_low<'a, PE>(
287		&self,
288		query: impl Into<MultilinearQueryRef<'a, PE>>,
289	) -> Result<MultilinearExtension<PE>, Error>
290	where
291		PE: PackedField,
292		PE::Scalar: ExtensionField<P::Scalar>,
293	{
294		let query = query.into();
295
296		if self.mu < query.n_vars() {
297			bail!(Error::IncorrectQuerySize { expected: self.mu });
298		}
299
300		let new_n_vars = self.mu - query.n_vars();
301
302		let mut result =
303			zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
304		self.evaluate_partial_low_into(query, &mut result)?;
305		MultilinearExtension::new(new_n_vars, result)
306	}
307
308	/// Partially evaluate the polynomial with assignment to the low-indexed variables.
309	///
310	/// The polynomial is multilinear with $\mu$ variables, $p(X_0, ..., X_{\mu-1}$. Given a query
311	/// vector of length $k$ representing $(z_0, ..., z_{k-1})$, this returns the
312	/// multilinear polynomial with $\mu - k$ variables,
313	/// $p(z_0, ..., z_{k-1}, X_k, ..., X_{\mu - 1})$.
314	///
315	/// REQUIRES: the size of the resulting polynomial must have a length which is a multiple of
316	/// P::WIDTH, i.e. 2^(\mu - k) \geq P::WIDTH, since WIDTH is power of two
317	pub fn evaluate_partial_low_into<PE>(
318		&self,
319		query: MultilinearQueryRef<PE>,
320		out: &mut [PE],
321	) -> Result<(), Error>
322	where
323		PE: PackedField,
324		PE::Scalar: ExtensionField<P::Scalar>,
325	{
326		// This operation is a matrix-vector product of the matrix of multilinear coefficients with
327		// the vector of tensor product-expanded query coefficients.
328		fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
329	}
330}
331
332impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
333	/// Convert MultilinearExtension over a scalar to a MultilinearExtension over a packed field with single element.
334	pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
335		let packed_evals = self
336			.evals
337			.iter()
338			.map(|eval| eval.to_single_packed())
339			.collect();
340		MultilinearExtension {
341			mu: self.mu,
342			evals: packed_evals,
343		}
344	}
345}
346
/// Floor of the base-2 logarithm of `v`.
///
/// All callers in this module first check that `v` is a power of two, so the result is exact.
///
/// # Panics
///
/// Panics if `v == 0`. The previous `63 - leading_zeros` formula underflowed there: it panicked
/// in debug builds and silently returned a garbage value in release builds.
const fn log2(v: usize) -> usize {
	v.ilog2() as usize
}
350
/// Type alias for the common pattern of a [`MultilinearExtension`] backed by borrowed data.
///
/// The evaluations are a borrowed `&'a [P]` slice. This is the type returned by
/// [`MultilinearExtension::to_ref`] and [`MultilinearExtension::from_values_slice`].
pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
353
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		arch::OptimalUnderlier256b, BinaryField128b, BinaryField16b as F, BinaryField1b,
		BinaryField32b, BinaryField8b, PackedBinaryField16x1b, PackedBinaryField16x8b,
		PackedBinaryField32x1b, PackedBinaryField4x32b, PackedBinaryField8x16b as P,
		PackedBinaryField8x1b,
	};
	use itertools::Itertools;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;
	use crate::{tensor_prod_eq_ind, MultilinearQuery};

	/// Expand the tensor product of the query values.
	///
	/// `query` is a sequence of field elements $z_0, ..., z_{k-1}$. Entry `i` of the result is
	/// the Lagrange basis polynomial of hypercube point `i` evaluated at the query (see
	/// [`eval_basis`]).
	///
	/// This naive reference implementation runs in O(k 2^k) time and allocates the full
	/// O(2^k)-element output.
	fn expand_query_naive<F: Field>(query: &[F]) -> Result<Vec<F>, Error> {
		let result = (0..1 << query.len())
			.map(|i| eval_basis(query, i))
			.collect();
		Ok(result)
	}

	/// Evaluates the Lagrange basis polynomial of hypercube point `i` at the queried point.
	///
	/// The result is the product over `j` of `z_j` if bit `j` of `i` is set, and `1 - z_j`
	/// otherwise.
	fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
		query
			.iter()
			.enumerate()
			.map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
			.product()
	}

	/// Builds a packed [`MultilinearQuery`] for the point `p`.
	///
	/// The buffer is seeded with a one in its first scalar and expanded in place with
	/// [`tensor_prod_eq_ind`].
	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	#[test]
	fn test_expand_query_impls_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let result1 = multilinear_query::<P>(&q);
		let result2 = expand_query_naive(&q).unwrap();
		assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
	}

	#[test]
	fn test_new_from_values_correspondence() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly2 = MultilinearExtension::new(8, evals).unwrap();

		assert_eq!(poly1, poly2)
	}

	#[test]
	fn test_evaluate_on_hypercube() {
		let mut values = vec![F::ZERO; 64];
		values
			.iter_mut()
			.enumerate()
			.for_each(|(i, val)| *val = F::new(i as u16));

		let poly = MultilinearExtension::from_values(values).unwrap();
		for i in 0..64 {
			let q = (0..6)
				.map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
				.collect::<Vec<_>>();
			let multilin_query = multilinear_query::<P>(&q);
			let result = poly.evaluate(multilin_query.to_ref()).unwrap();
			assert_eq!(result, F::new(i));
		}
	}

	/// Evaluates `poly` at `q` by repeatedly fixing the highest variables, `splits[..n-1]`
	/// variables at a time, and finishing with a full evaluation over the remaining variables.
	fn evaluate_split<P>(
		poly: MultilinearExtension<P>,
		q: &[P::Scalar],
		splits: &[usize],
	) -> P::Scalar
	where
		P: PackedField + 'static,
	{
		assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

		let mut partial_result = poly;
		let mut index = q.len();
		for split_vars in &splits[0..splits.len() - 1] {
			let split_vars = *split_vars;
			let query = multilinear_query(&q[index - split_vars..index]);
			partial_result = partial_result
				.evaluate_partial_high(query.to_ref())
				.unwrap();
			index -= split_vars;
		}
		let multilin_query = multilinear_query::<P>(&q[..index]);
		partial_result.evaluate(multilin_query.to_ref()).unwrap()
	}

	#[test]
	fn test_evaluate_split_is_correct() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let multilin_query = multilinear_query::<P>(&q);
		let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
		let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
		assert_eq!(result1, result2);
	}

	/// Collects, from every packed value in `values`, the `start_index`-th run of `P::WIDTH`
	/// consecutive scalars.
	fn get_bits<P, PE>(values: &[PE], start_index: usize) -> Vec<PE::Scalar>
	where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		values
			.iter()
			.flat_map(|v| {
				(P::WIDTH * start_index..P::WIDTH * (start_index + 1)).map(move |i| v.get(i))
			})
			.collect()
	}

	#[test]
	fn test_evaluate_middle_32b_to_16b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_lower_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_higher_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 1);

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Get query to project on lower bits.
		let q_low = [<BinaryField1b as PackedField>::zero()];
		// Get query to project on higher bits.
		let q_hi = [<BinaryField1b as PackedField>::one()];

		// Get expanded query to project on lower bits.
		let query_low = multilinear_query::<BinaryField1b>(&q_low);
		// Get expanded query to project on higher bits.
		let query_hi = multilinear_query::<BinaryField1b>(&q_hi);

		// Get lower bits evaluations.
		let evals_low = poly.evaluate_partial(query_low.to_ref(), 4).unwrap();
		let evals_low = evals_low.evals();

		// Get higher bits evaluations.
		let evals_hi = poly.evaluate_partial(query_hi.to_ref(), 4).unwrap();
		let evals_hi = evals_hi.evals();

		assert_eq!(evals_low, expected_lower_bits);
		assert_eq!(evals_hi, expected_higher_bits);
	}

	#[test]
	fn test_evaluate_middle_32b_to_8b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_first_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_second_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 1);

		let expected_third_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 2);

		let expected_fourth_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 3);

		let poly = MultilinearExtension::from_values(values).unwrap();

		// Get query to project on first quarter.
		let q_first = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::zero(),
		];
		// Get query to project on second quarter.
		let q_second = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::zero(),
		];
		// Get query to project on third quarter.
		let q_third = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::one(),
		];
		// Get query to project on last quarter.
		let q_fourth = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::one(),
		];

		// Get expanded query to project on first quarter.
		let query_first = multilinear_query::<BinaryField1b>(&q_first);
		// Get expanded query to project on second quarter.
		let query_second = multilinear_query::<BinaryField1b>(&q_second);
		// Get expanded query to project on third quarter.
		let query_third = multilinear_query::<BinaryField1b>(&q_third);
		// Get expanded query to project on last quarter.
		let query_fourth = multilinear_query::<BinaryField1b>(&q_fourth);

		// Get first quarter of bits evaluations.
		let evals_first_quarter = poly.evaluate_partial(query_first.to_ref(), 3).unwrap();
		let evals_first_quarter = evals_first_quarter.evals();
		// Get second quarter evaluations.
		let evals_second_quarter = poly.evaluate_partial(query_second.to_ref(), 3).unwrap();
		let evals_second_quarter = evals_second_quarter.evals();
		// Get third quarter evaluations.
		let evals_third_quarter = poly.evaluate_partial(query_third.to_ref(), 3).unwrap();
		let evals_third_quarter = evals_third_quarter.evals();
		// Get last quarter evaluations.
		let evals_fourth_quarter = poly.evaluate_partial(query_fourth.to_ref(), 3).unwrap();
		let evals_fourth_quarter = evals_fourth_quarter.evals();

		assert_eq!(evals_first_quarter, expected_first_quarter);
		assert_eq!(evals_second_quarter, expected_second_quarter);
		assert_eq!(evals_third_quarter, expected_third_quarter);
		assert_eq!(evals_fourth_quarter, expected_fourth_quarter);
	}

	#[test]
	fn test_evaluate_partial_match_evaluate_partial_low() {
		type P = PackedBinaryField16x8b;
		type F = BinaryField8b;

		let mut rng = StdRng::seed_from_u64(0);

		let n_vars: usize = 7;

		let values = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars.saturating_sub(P::LOG_WIDTH)))
			.collect();

		let query_n_minus_1 = repeat_with(|| <F as PackedField>::random(&mut rng))
			.take(n_vars - 1)
			.collect::<Vec<_>>();

		let (q_low, q_high) = query_n_minus_1.split_at(query_n_minus_1.len() / 2);

		// Get expanded query to project on lower bits.
		let query_low = multilinear_query::<P>(q_low);
		// Get expanded query to project on higher bits.
		let query_high = multilinear_query::<P>(q_high);

		let me = MultilinearExtension::from_values(values).unwrap();

		assert_eq!(
			me.evaluate_partial_low(query_low.to_ref())
				.unwrap()
				.evaluate_partial_low(query_high.to_ref())
				.unwrap(),
			me.evaluate_partial(query_high.to_ref(), q_low.len())
				.unwrap()
				.evaluate_partial_low(query_low.to_ref())
				.unwrap()
		)
	}

	#[test]
	fn test_evaluate_partial_high_packed() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(256 >> P::LOG_WIDTH)
			.collect::<Vec<_>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<BinaryField128b>>();
		let multilin_query = multilinear_query::<BinaryField128b>(&q);

		let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

		// The final split has a number of coefficients less than the packing width
		let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
		let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
		assert!(partial_eval.n_vars() < P::LOG_WIDTH);

		let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
		let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
		assert_eq!(eval, expected);
	}

	#[test]
	fn test_evaluate_partial_low_high_smaller_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let q = repeat_with(|| Field::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<F>>();

		let query = multilinear_query::<P>(&q);

		let packed = P::from_scalars(values);
		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		let eval = me.evaluate(&query).unwrap();

		let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
		let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

		let eval_l_h = me
			.evaluate_partial_high(&query_high)
			.unwrap()
			.evaluate_partial_low(&query_low)
			.unwrap()
			.evals()[0]
			.get(0);

		assert_eq!(eval, eval_l_h);
	}

	#[test]
	fn test_evaluate_on_hypercube_small_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let packed = P::from_scalars(values.clone());

		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

		assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
	}

	#[test]
	fn test_evaluate_partial_high_low_evaluate_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let me = MultilinearExtension::from_values(values).unwrap();

		let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
			.take(me.n_vars())
			.collect::<Vec<_>>();

		let query = multilinear_query(&q);

		let eval = me
			.evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
				query.to_ref(),
			)
			.unwrap();

		assert_eq!(
			me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
		assert_eq!(
			me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
	}

	#[test]
	fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let r1 = <BinaryField32b as PackedField>::random(&mut rng);
		let r2 = <BinaryField32b as PackedField>::random(&mut rng);

		let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
			.evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
			.unwrap()
			.evaluate_partial_low(multilinear_query(&[r2]).to_ref())
			.unwrap();
		let eval_2 = mle
			.evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
			.unwrap();
		assert_eq!(eval_1, eval_2);
	}

	#[test]
	fn test_new_mle_with_tiny_nvars() {
		MultilinearExtension::new(
			1,
			vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
		)
		.unwrap();
	}
}