1use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6 ExtensionField, Field, PackedField,
7 as_packed_field::{AsSinglePacked, PackScalar, PackedType},
8 underlier::UnderlierType,
9 util::inner_product_par,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{
16 Error, MultilinearQueryRef, PackingDeref, fold::fold_left, fold_middle, fold_right, zero_pad,
17};
18
19#[derive(Debug, Clone, PartialEq, Eq)]
26pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
27 mu: usize,
29 evals: Data,
31}
32
33impl<P: PackedField> MultilinearExtension<P> {
34 pub fn zeros(n_vars: usize) -> Result<Self, Error> {
35 if n_vars < P::LOG_WIDTH {
36 bail!(Error::ArgumentRangeError {
37 arg: "n_vars".into(),
38 range: P::LOG_WIDTH..32,
39 });
40 }
41 Ok(Self {
42 mu: n_vars,
43 evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
44 })
45 }
46
47 pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
48 Self::from_values_generic(v)
49 }
50
51 pub fn into_evals(self) -> Vec<P> {
52 self.evals
53 }
54}
55
56impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
57 pub fn from_values_generic(v: Data) -> Result<Self, Error> {
58 if !v.len().is_power_of_two() {
59 bail!(Error::PowerOfTwoLengthRequired);
60 }
61 let mu = log2(v.len()) + P::LOG_WIDTH;
62
63 Ok(Self { mu, evals: v })
64 }
65
66 pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
67 if !v.len().is_power_of_two() {
68 bail!(Error::PowerOfTwoLengthRequired);
69 }
70
71 if n_vars < P::LOG_WIDTH {
72 if v.len() != 1 {
73 bail!(Error::IncorrectNumberOfVariables {
74 expected: n_vars,
75 actual: P::LOG_WIDTH + log2(v.len())
76 });
77 }
78 } else if P::LOG_WIDTH + log2(v.len()) != n_vars {
79 bail!(Error::IncorrectNumberOfVariables {
80 expected: n_vars,
81 actual: P::LOG_WIDTH + log2(v.len())
82 });
83 }
84
85 Ok(Self {
86 mu: n_vars,
87 evals: v,
88 })
89 }
90}
91
92impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
93where
94 U: UnderlierType + PackScalar<F>,
97 F: Field,
98 Data: Deref<Target = [U]>,
99{
100 pub fn from_underliers(v: Data) -> Result<Self, Error> {
101 Self::from_values_generic(PackingDeref::new(v))
102 }
103}
104
105impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
106 pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
107 if !v.len().is_power_of_two() {
108 bail!(Error::PowerOfTwoLengthRequired);
109 }
110 let mu = log2(v.len() * P::WIDTH);
111 Ok(Self { mu, evals: v })
112 }
113}
114
115impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
116 pub const fn n_vars(&self) -> usize {
117 self.mu
118 }
119
120 pub const fn size(&self) -> usize {
121 1 << self.mu
122 }
123
124 pub fn evals(&self) -> &[P] {
125 &self.evals
126 }
127
128 pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
129 MultilinearExtension {
130 mu: self.mu,
131 evals: self.evals(),
132 }
133 }
134
135 pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
142 self.evals()
143 .get(index)
144 .ok_or(Error::HypercubeIndexOutOfRange { index })
145 .copied()
146 }
147
148 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
149 if self.size() <= index {
150 bail!(Error::HypercubeIndexOutOfRange { index })
151 }
152
153 let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
154 Ok(subcube_eval.get(index % P::WIDTH))
155 }
156}
157
158impl<P, Data> MultilinearExtension<P, Data>
159where
160 P: PackedField,
161 Data: Deref<Target = [P]> + Send + Sync,
162{
163 pub fn evaluate<'a, FE, PE>(
164 &self,
165 query: impl Into<MultilinearQueryRef<'a, PE>>,
166 ) -> Result<FE, Error>
167 where
168 FE: ExtensionField<P::Scalar>,
169 PE: PackedField<Scalar = FE>,
170 {
171 let query = query.into();
172 if self.mu != query.n_vars() {
173 bail!(Error::IncorrectQuerySize {
174 expected: self.mu,
175 actual: query.n_vars()
176 });
177 }
178
179 if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
180 let evals = PackedField::iter_slice(self.evals())
181 .take(self.size())
182 .collect::<Vec<P::Scalar>>();
183 let queries = PackedField::iter_slice(query.expansion())
184 .take(1 << query.n_vars())
185 .collect::<Vec<PE::Scalar>>();
186 Ok(inner_product_par(&queries, &evals))
187 } else {
188 Ok(inner_product_par(query.expansion(), &self.evals))
189 }
190 }
191
192 #[instrument("MultilinearExtension::evaluate_partial", skip_all, level = "debug")]
193 pub fn evaluate_partial<'a, PE>(
194 &self,
195 query: impl Into<MultilinearQueryRef<'a, PE>>,
196 start_index: usize,
197 ) -> Result<MultilinearExtension<PE>, Error>
198 where
199 PE: PackedField,
200 PE::Scalar: ExtensionField<P::Scalar>,
201 {
202 let query = query.into();
203 if start_index + query.n_vars() > self.mu {
204 bail!(Error::IncorrectStartIndex { expected: self.mu })
205 }
206
207 if start_index == 0 {
208 return self.evaluate_partial_low(query);
209 } else if start_index + query.n_vars() == self.mu {
210 return self.evaluate_partial_high(query);
211 }
212
213 if self.mu < query.n_vars() {
214 bail!(Error::IncorrectQuerySize {
215 expected: self.mu,
216 actual: query.n_vars()
217 });
218 }
219
220 let new_n_vars = self.mu - query.n_vars();
221 let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
222 let mut result_evals = Vec::with_capacity(result_evals_len);
223
224 fold_middle(
225 self.evals(),
226 self.mu,
227 query.expansion(),
228 query.n_vars(),
229 start_index,
230 result_evals.spare_capacity_mut(),
231 )?;
232 unsafe {
233 result_evals.set_len(result_evals_len);
234 }
235
236 MultilinearExtension::new(new_n_vars, result_evals)
237 }
238
239 #[instrument(
249 "MultilinearExtension::evaluate_partial_high",
250 skip_all,
251 level = "debug"
252 )]
253 pub fn evaluate_partial_high<'a, PE>(
254 &self,
255 query: impl Into<MultilinearQueryRef<'a, PE>>,
256 ) -> Result<MultilinearExtension<PE>, Error>
257 where
258 PE: PackedField,
259 PE::Scalar: ExtensionField<P::Scalar>,
260 {
261 let query = query.into();
262
263 let new_n_vars = self.mu.saturating_sub(query.n_vars());
264 let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
265 let mut result_evals = Vec::with_capacity(result_evals_len);
266
267 fold_left(
268 self.evals(),
269 self.mu,
270 query.expansion(),
271 query.n_vars(),
272 result_evals.spare_capacity_mut(),
273 )?;
274 unsafe {
275 result_evals.set_len(result_evals_len);
276 }
277
278 MultilinearExtension::new(new_n_vars, result_evals)
279 }
280
281 #[instrument(
291 "MultilinearExtension::evaluate_partial_low",
292 skip_all,
293 level = "trace"
294 )]
295 pub fn evaluate_partial_low<'a, PE>(
296 &self,
297 query: impl Into<MultilinearQueryRef<'a, PE>>,
298 ) -> Result<MultilinearExtension<PE>, Error>
299 where
300 PE: PackedField,
301 PE::Scalar: ExtensionField<P::Scalar>,
302 {
303 let query = query.into();
304
305 if self.mu < query.n_vars() {
306 bail!(Error::IncorrectQuerySize {
307 expected: self.mu,
308 actual: query.n_vars()
309 });
310 }
311
312 let new_n_vars = self.mu - query.n_vars();
313
314 let mut result =
315 zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
316 self.evaluate_partial_low_into(query, &mut result)?;
317 MultilinearExtension::new(new_n_vars, result)
318 }
319
320 pub fn evaluate_partial_low_into<PE>(
330 &self,
331 query: MultilinearQueryRef<PE>,
332 out: &mut [PE],
333 ) -> Result<(), Error>
334 where
335 PE: PackedField,
336 PE::Scalar: ExtensionField<P::Scalar>,
337 {
338 fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
341 }
342
343 pub fn zero_pad<PE>(
344 &self,
345 n_pad_vars: usize,
346 start_index: usize,
347 nonzero_index: usize,
348 ) -> Result<MultilinearExtension<PE>, Error>
349 where
350 PE: PackedField,
351 PE::Scalar: ExtensionField<P::Scalar>,
352 {
353 let init_n_vars = self.mu;
354 if start_index > init_n_vars {
355 bail!(Error::IncorrectStartIndexZeroPad { expected: self.mu })
356 }
357 let new_n_vars = init_n_vars + n_pad_vars;
358 if nonzero_index >= 1 << n_pad_vars {
359 bail!(Error::IncorrectNonZeroIndex {
360 expected: 1 << n_pad_vars,
361 });
362 }
363
364 let mut result = zeroed_vec(1 << new_n_vars);
365
366 zero_pad(&self.evals, init_n_vars, new_n_vars, start_index, nonzero_index, &mut result)?;
367 MultilinearExtension::new(new_n_vars, result)
368 }
369}
370
371impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
372 pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
375 let packed_evals = self
376 .evals
377 .iter()
378 .map(|eval| eval.to_single_packed())
379 .collect();
380 MultilinearExtension {
381 mu: self.mu,
382 evals: packed_evals,
383 }
384 }
385}
386
/// Floor of the base-2 logarithm of `v`.
///
/// The value is widened to `u64` first, so the result does not depend on the target's pointer
/// width. Callers in this module only pass non-zero, power-of-two lengths; `v == 0`
/// overflows the subtraction.
const fn log2(v: usize) -> usize {
	let highest_bit = u64::BITS as usize - 1;
	highest_bit - (v as u64).leading_zeros() as usize
}
390
/// A [`MultilinearExtension`] that borrows its packed evaluations as a slice instead of owning
/// them.
pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
393
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		BinaryField1b, BinaryField8b, BinaryField16b as F, BinaryField32b, BinaryField128b,
		PackedBinaryField4x32b, PackedBinaryField8x1b, PackedBinaryField8x16b as P,
		PackedBinaryField16x1b, PackedBinaryField16x8b, PackedBinaryField32x1b,
		arch::OptimalUnderlier256b,
	};
	use itertools::Itertools;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::{MultilinearQuery, tensor_prod_eq_ind};

	/// Reference (unpacked) tensor expansion of `query`: entry `i` is the `i`-th basis
	/// polynomial evaluated at `query` (see `eval_basis`).
	fn expand_query_naive<F: Field>(query: &[F]) -> Result<Vec<F>, Error> {
		let result = (0..1 << query.len())
			.map(|i| eval_basis(query, i))
			.collect();
		Ok(result)
	}

	/// Product over variables `j` of `v_j` when bit `j` of `i` is set, and `1 - v_j` otherwise.
	fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
		query
			.iter()
			.enumerate()
			.map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
			.product()
	}

	/// Builds a packed `MultilinearQuery` for the point `p` using `tensor_prod_eq_ind`.
	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		// At least one packed element, even when `p` has fewer variables than `P::LOG_WIDTH`.
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	// The packed tensor expansion must agree with the naive scalar reference.
	#[test]
	fn test_expand_query_impls_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let result1 = multilinear_query::<P>(&q);
		let result2 = expand_query_naive(&q).unwrap();
		assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
	}

	// `from_values` (inferred variable count) and `new` (explicit count) must agree.
	#[test]
	fn test_new_from_values_correspondence() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly2 = MultilinearExtension::new(8, evals).unwrap();

		assert_eq!(poly1, poly2)
	}

	// Evaluating at each boolean vertex must return the stored value at that index.
	#[test]
	fn test_evaluate_on_hypercube() {
		let mut values = vec![F::ZERO; 64];
		values
			.iter_mut()
			.enumerate()
			.for_each(|(i, val)| *val = F::new(i as u16));

		let poly = MultilinearExtension::from_values(values).unwrap();
		for i in 0..64 {
			// Bit `j` of `i` selects the value of variable `j`.
			let q = (0..6)
				.map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
				.collect::<Vec<_>>();
			let multilin_query = multilinear_query::<P>(&q);
			let result = poly.evaluate(multilin_query.to_ref()).unwrap();
			assert_eq!(result, F::new(i));
		}
	}

	/// Evaluates `poly` at `q` by fixing the highest variables in chunks of sizes
	/// `splits[..splits.len() - 1]` with `evaluate_partial_high`. The remaining low variables
	/// are then evaluated in full.
	fn evaluate_split<P>(
		poly: MultilinearExtension<P>,
		q: &[P::Scalar],
		splits: &[usize],
	) -> P::Scalar
	where
		P: PackedField + 'static,
	{
		assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

		let mut partial_result = poly;
		let mut index = q.len();
		for split_vars in &splits[0..splits.len() - 1] {
			let split_vars = *split_vars;
			let query = multilinear_query(&q[index - split_vars..index]);
			partial_result = partial_result
				.evaluate_partial_high(query.to_ref())
				.unwrap();
			index -= split_vars;
		}
		let multilin_query = multilinear_query::<P>(&q[..index]);
		partial_result.evaluate(multilin_query.to_ref()).unwrap()
	}

	// Split (high-first partial) evaluation must equal a single full evaluation.
	#[test]
	fn test_evaluate_split_is_correct() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| Field::random(&mut rng))
			.take(256)
			.collect::<Vec<F>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<F>>();
		let multilin_query = multilinear_query::<P>(&q);
		let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
		let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
		assert_eq!(result1, result2);
	}

	/// From each packed value, collects the scalars in lanes
	/// `P::WIDTH * start_index..P::WIDTH * (start_index + 1)`.
	fn get_bits<P, PE>(values: &[PE], start_index: usize) -> Vec<PE::Scalar>
	where
		P: PackedField,
		PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
	{
		values
			.iter()
			.flat_map(|v| {
				(P::WIDTH * start_index..P::WIDTH * (start_index + 1))
					.map(|i| v.get(i))
					.collect::<Vec<_>>()
			})
			.collect::<Vec<_>>()
	}

	// Fixing the variable at index 4 to 0 / 1 selects the lower / upper 16 lanes of each
	// 32-lane packed value.
	#[test]
	fn test_evaluate_middle_32b_to_16b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_lower_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_higher_bits =
			get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 1);

		let poly = MultilinearExtension::from_values(values).unwrap();

		let q_low = [<BinaryField1b as PackedField>::zero()];
		let q_hi = [<BinaryField1b as PackedField>::one()];

		let query_low = multilinear_query::<BinaryField1b>(&q_low);
		let query_hi = multilinear_query::<BinaryField1b>(&q_hi);

		let evals_low = poly.evaluate_partial(query_low.to_ref(), 4).unwrap();
		let evals_low = evals_low.evals();

		let evals_hi = poly.evaluate_partial(query_hi.to_ref(), 4).unwrap();
		let evals_hi = evals_hi.evals();

		assert_eq!(evals_low, expected_lower_bits);
		assert_eq!(evals_hi, expected_higher_bits);
	}

	// Fixing variables 3 and 4 to each boolean pair selects the matching 8-lane quarter of
	// each 32-lane packed value.
	#[test]
	fn test_evaluate_middle_32b_to_8b() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();

		let expected_first_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 0);

		let expected_second_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 1);

		let expected_third_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 2);

		let expected_fourth_quarter =
			get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 3);

		let poly = MultilinearExtension::from_values(values).unwrap();

		let q_first = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::zero(),
		];
		let q_second = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::zero(),
		];
		let q_third = [
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::one(),
		];
		let q_fourth = [
			<BinaryField1b as PackedField>::one(),
			<BinaryField1b as PackedField>::one(),
		];

		let query_first = multilinear_query::<BinaryField1b>(&q_first);
		let query_second = multilinear_query::<BinaryField1b>(&q_second);
		let query_third = multilinear_query::<BinaryField1b>(&q_third);
		let query_fourth = multilinear_query::<BinaryField1b>(&q_fourth);

		let evals_first_quarter = poly.evaluate_partial(query_first.to_ref(), 3).unwrap();
		let evals_first_quarter = evals_first_quarter.evals();
		let evals_second_quarter = poly.evaluate_partial(query_second.to_ref(), 3).unwrap();
		let evals_second_quarter = evals_second_quarter.evals();
		let evals_third_quarter = poly.evaluate_partial(query_third.to_ref(), 3).unwrap();
		let evals_third_quarter = evals_third_quarter.evals();
		let evals_fourth_quarter = poly.evaluate_partial(query_fourth.to_ref(), 3).unwrap();
		let evals_fourth_quarter = evals_fourth_quarter.evals();

		assert_eq!(evals_first_quarter, expected_first_quarter);
		assert_eq!(evals_second_quarter, expected_second_quarter);
		assert_eq!(evals_third_quarter, expected_third_quarter);
		assert_eq!(evals_fourth_quarter, expected_fourth_quarter);
	}

	// Zero-padding two variables at index 3 and then fixing those same variables to
	// `nonzero_index` (0, 0) must recover the original evaluations.
	#[test]
	fn test_zeropad_8b_to_32b_project() {
		let mut rng = StdRng::seed_from_u64(0);
		let values = repeat_with(|| PackedBinaryField8x1b::random(&mut rng))
			.take(1 << 2)
			.collect::<Vec<_>>();
		let expected_out = PackedBinaryField8x1b::iter_slice(&values).collect::<Vec<_>>();

		let poly = MultilinearExtension::from_values(values).unwrap();

		let n_pad_vars = 2;
		let start_index = 3;
		let nonzero_index = 0;
		let padded_poly = poly
			.zero_pad::<BinaryField1b>(n_pad_vars, start_index, nonzero_index)
			.unwrap();

		let query_first = multilinear_query::<BinaryField1b>(&[
			<BinaryField1b as PackedField>::zero(),
			<BinaryField1b as PackedField>::zero(),
		]);
		let projected_poly = padded_poly
			.evaluate_partial(query_first.to_ref(), 3)
			.unwrap();

		assert_eq!(expected_out, projected_poly.evals());
	}

	// Low-then-low partial evaluation must match middle-then-low partial evaluation for the
	// same overall point.
	#[test]
	fn test_evaluate_partial_match_evaluate_partial_low() {
		type P = PackedBinaryField16x8b;
		type F = BinaryField8b;

		let mut rng = StdRng::seed_from_u64(0);

		let n_vars: usize = 7;

		let values = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars.saturating_sub(P::LOG_WIDTH)))
			.collect();

		let query_n_minus_1 = repeat_with(|| <F as PackedField>::random(&mut rng))
			.take(n_vars - 1)
			.collect::<Vec<_>>();

		let (q_low, q_high) = query_n_minus_1.split_at(query_n_minus_1.len() / 2);

		let query_low = multilinear_query::<P>(q_low);
		let query_high = multilinear_query::<P>(q_high);

		let me = MultilinearExtension::from_values(values).unwrap();

		assert_eq!(
			me.evaluate_partial_low(query_low.to_ref())
				.unwrap()
				.evaluate_partial_low(query_high.to_ref())
				.unwrap(),
			me.evaluate_partial(query_high.to_ref(), q_low.len())
				.unwrap()
				.evaluate_partial_low(query_low.to_ref())
				.unwrap()
		)
	}

	// `evaluate_partial_high` may leave fewer variables than the input packing width; the
	// follow-up full evaluation must still match.
	#[test]
	fn test_evaluate_partial_high_packed() {
		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(256 >> P::LOG_WIDTH)
			.collect::<Vec<_>>();
		let poly = MultilinearExtension::from_values(evals).unwrap();
		let q = repeat_with(|| Field::random(&mut rng))
			.take(8)
			.collect::<Vec<BinaryField128b>>();
		let multilin_query = multilinear_query::<BinaryField128b>(&q);

		let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

		let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
		let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
		assert!(partial_eval.n_vars() < P::LOG_WIDTH);

		let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
		let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
		assert_eq!(eval, expected);
	}

	// A polynomial smaller than one packed element: high-then-low partial evaluation must
	// match full evaluation.
	#[test]
	fn test_evaluate_partial_low_high_smaller_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let q = repeat_with(|| Field::random(&mut rng))
			.take(n_vars)
			.collect::<Vec<F>>();

		let query = multilinear_query::<P>(&q);

		let packed = P::from_scalars(values);
		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		let eval = me.evaluate(&query).unwrap();

		let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
		let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

		let eval_l_h = me
			.evaluate_partial_high(&query_high)
			.unwrap()
			.evaluate_partial_low(&query_low)
			.unwrap()
			.evals()[0]
			.get(0);

		assert_eq!(eval, eval_l_h);
	}

	// For sub-packed-width polynomials, hypercube lookups are bounded by `2^n_vars`, not by
	// the packed lane count.
	#[test]
	fn test_evaluate_on_hypercube_small_than_packed_width() {
		type P = PackedBinaryField16x8b;

		type F = BinaryField8b;

		let n_vars = 3;

		let mut rng = StdRng::seed_from_u64(0);

		let values = repeat_with(|| Field::random(&mut rng))
			.take(1 << n_vars)
			.collect::<Vec<F>>();

		let packed = P::from_scalars(values.clone());

		let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

		assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

		assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
	}

	// Fixing all variables via the low or the high partial path must equal `evaluate`.
	#[test]
	fn test_evaluate_partial_high_low_evaluate_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let me = MultilinearExtension::from_values(values).unwrap();

		let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
			.take(me.n_vars())
			.collect::<Vec<_>>();

		let query = multilinear_query(&q);

		let eval = me
			.evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
				query.to_ref(),
			)
			.unwrap();

		assert_eq!(
			me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
		assert_eq!(
			me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
				.unwrap()
				.evals[0]
				.get(0),
			eval
		);
	}

	// Two single-variable low folds must equal one two-variable low fold.
	#[test]
	fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
			.take(1 << 8)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let r1 = <BinaryField32b as PackedField>::random(&mut rng);
		let r2 = <BinaryField32b as PackedField>::random(&mut rng);

		let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
			.evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
			.unwrap()
			.evaluate_partial_low(multilinear_query(&[r2]).to_ref())
			.unwrap();
		let eval_2 = mle
			.evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
			.unwrap();
		assert_eq!(eval_1, eval_2);
	}

	// `new` accepts a variable count below the packing width when given one packed element.
	#[test]
	fn test_new_mle_with_tiny_nvars() {
		MultilinearExtension::new(
			1,
			vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
		)
		.unwrap();
	}
}