1use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6 as_packed_field::{AsSinglePacked, PackScalar, PackedType},
7 underlier::UnderlierType,
8 util::inner_product_par,
9 ExtensionField, Field, PackedField,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{
16 fold::fold_left, fold_middle, fold_right, zero_pad, Error, MultilinearQueryRef, PackingDeref,
17};
18
/// A multilinear polynomial stored as its evaluations over the boolean hypercube.
///
/// The `2^mu` hypercube evaluations are stored packed into elements of `P`, so `evals` normally
/// has `2^(mu - P::LOG_WIDTH)` elements. When `mu < P::LOG_WIDTH` a single packed element is
/// stored and only its first `2^mu` scalars are meaningful.
///
/// `Data` is any container dereferencing to `[P]` (owned `Vec<P>` by default, or a borrowed
/// slice).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
    /// Number of variables of the multilinear.
    mu: usize,
    /// Packed hypercube evaluations.
    evals: Data,
}
32
33impl<P: PackedField> MultilinearExtension<P> {
34 pub fn zeros(n_vars: usize) -> Result<Self, Error> {
35 if n_vars < P::LOG_WIDTH {
36 bail!(Error::ArgumentRangeError {
37 arg: "n_vars".into(),
38 range: P::LOG_WIDTH..32,
39 });
40 }
41 Ok(Self {
42 mu: n_vars,
43 evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
44 })
45 }
46
47 pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
48 Self::from_values_generic(v)
49 }
50
51 pub fn into_evals(self) -> Vec<P> {
52 self.evals
53 }
54}
55
56impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
57 pub fn from_values_generic(v: Data) -> Result<Self, Error> {
58 if !v.len().is_power_of_two() {
59 bail!(Error::PowerOfTwoLengthRequired);
60 }
61 let mu = log2(v.len()) + P::LOG_WIDTH;
62
63 Ok(Self { mu, evals: v })
64 }
65
66 pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
67 if !v.len().is_power_of_two() {
68 bail!(Error::PowerOfTwoLengthRequired);
69 }
70
71 if n_vars < P::LOG_WIDTH {
72 if v.len() != 1 {
73 bail!(Error::IncorrectNumberOfVariables {
74 expected: n_vars,
75 actual: P::LOG_WIDTH + log2(v.len())
76 });
77 }
78 } else if P::LOG_WIDTH + log2(v.len()) != n_vars {
79 bail!(Error::IncorrectNumberOfVariables {
80 expected: n_vars,
81 actual: P::LOG_WIDTH + log2(v.len())
82 });
83 }
84
85 Ok(Self {
86 mu: n_vars,
87 evals: v,
88 })
89 }
90}
91
92impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
93where
94 U: UnderlierType + PackScalar<F>,
96 F: Field,
97 Data: Deref<Target = [U]>,
98{
99 pub fn from_underliers(v: Data) -> Result<Self, Error> {
100 Self::from_values_generic(PackingDeref::new(v))
101 }
102}
103
104impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
105 pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
106 if !v.len().is_power_of_two() {
107 bail!(Error::PowerOfTwoLengthRequired);
108 }
109 let mu = log2(v.len() * P::WIDTH);
110 Ok(Self { mu, evals: v })
111 }
112}
113
114impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
115 pub const fn n_vars(&self) -> usize {
116 self.mu
117 }
118
119 pub const fn size(&self) -> usize {
120 1 << self.mu
121 }
122
123 pub fn evals(&self) -> &[P] {
124 &self.evals
125 }
126
127 pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
128 MultilinearExtension {
129 mu: self.mu,
130 evals: self.evals(),
131 }
132 }
133
134 pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
141 self.evals()
142 .get(index)
143 .ok_or(Error::HypercubeIndexOutOfRange { index })
144 .copied()
145 }
146
147 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
148 if self.size() <= index {
149 bail!(Error::HypercubeIndexOutOfRange { index })
150 }
151
152 let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
153 Ok(subcube_eval.get(index % P::WIDTH))
154 }
155}
156
157impl<P, Data> MultilinearExtension<P, Data>
158where
159 P: PackedField,
160 Data: Deref<Target = [P]> + Send + Sync,
161{
162 pub fn evaluate<'a, FE, PE>(
163 &self,
164 query: impl Into<MultilinearQueryRef<'a, PE>>,
165 ) -> Result<FE, Error>
166 where
167 FE: ExtensionField<P::Scalar>,
168 PE: PackedField<Scalar = FE>,
169 {
170 let query = query.into();
171 if self.mu != query.n_vars() {
172 bail!(Error::IncorrectQuerySize { expected: self.mu });
173 }
174
175 if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
176 let evals = PackedField::iter_slice(self.evals())
177 .take(self.size())
178 .collect::<Vec<P::Scalar>>();
179 let querys = PackedField::iter_slice(query.expansion())
180 .take(1 << query.n_vars())
181 .collect::<Vec<PE::Scalar>>();
182 Ok(inner_product_par(&querys, &evals))
183 } else {
184 Ok(inner_product_par(query.expansion(), &self.evals))
185 }
186 }
187
188 #[instrument("MultilinearExtension::evaluate_partial", skip_all, level = "debug")]
189 pub fn evaluate_partial<'a, PE>(
190 &self,
191 query: impl Into<MultilinearQueryRef<'a, PE>>,
192 start_index: usize,
193 ) -> Result<MultilinearExtension<PE>, Error>
194 where
195 PE: PackedField,
196 PE::Scalar: ExtensionField<P::Scalar>,
197 {
198 let query = query.into();
199 if start_index + query.n_vars() > self.mu {
200 bail!(Error::IncorrectStartIndex { expected: self.mu })
201 }
202
203 if start_index == 0 {
204 return self.evaluate_partial_low(query);
205 } else if start_index + query.n_vars() == self.mu {
206 return self.evaluate_partial_high(query);
207 }
208
209 if self.mu < query.n_vars() {
210 bail!(Error::IncorrectQuerySize { expected: self.mu });
211 }
212
213 let new_n_vars = self.mu - query.n_vars();
214 let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
215 let mut result_evals = Vec::with_capacity(result_evals_len);
216
217 fold_middle(
218 self.evals(),
219 self.mu,
220 query.expansion(),
221 query.n_vars(),
222 start_index,
223 result_evals.spare_capacity_mut(),
224 )?;
225 unsafe {
226 result_evals.set_len(result_evals_len);
227 }
228
229 MultilinearExtension::new(new_n_vars, result_evals)
230 }
231
232 #[instrument(
242 "MultilinearExtension::evaluate_partial_high",
243 skip_all,
244 level = "debug"
245 )]
246 pub fn evaluate_partial_high<'a, PE>(
247 &self,
248 query: impl Into<MultilinearQueryRef<'a, PE>>,
249 ) -> Result<MultilinearExtension<PE>, Error>
250 where
251 PE: PackedField,
252 PE::Scalar: ExtensionField<P::Scalar>,
253 {
254 let query = query.into();
255
256 let new_n_vars = self.mu.saturating_sub(query.n_vars());
257 let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
258 let mut result_evals = Vec::with_capacity(result_evals_len);
259
260 fold_left(
261 self.evals(),
262 self.mu,
263 query.expansion(),
264 query.n_vars(),
265 result_evals.spare_capacity_mut(),
266 )?;
267 unsafe {
268 result_evals.set_len(result_evals_len);
269 }
270
271 MultilinearExtension::new(new_n_vars, result_evals)
272 }
273
274 #[instrument(
284 "MultilinearExtension::evaluate_partial_low",
285 skip_all,
286 level = "trace"
287 )]
288 pub fn evaluate_partial_low<'a, PE>(
289 &self,
290 query: impl Into<MultilinearQueryRef<'a, PE>>,
291 ) -> Result<MultilinearExtension<PE>, Error>
292 where
293 PE: PackedField,
294 PE::Scalar: ExtensionField<P::Scalar>,
295 {
296 let query = query.into();
297
298 if self.mu < query.n_vars() {
299 bail!(Error::IncorrectQuerySize { expected: self.mu });
300 }
301
302 let new_n_vars = self.mu - query.n_vars();
303
304 let mut result =
305 zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
306 self.evaluate_partial_low_into(query, &mut result)?;
307 MultilinearExtension::new(new_n_vars, result)
308 }
309
310 pub fn evaluate_partial_low_into<PE>(
320 &self,
321 query: MultilinearQueryRef<PE>,
322 out: &mut [PE],
323 ) -> Result<(), Error>
324 where
325 PE: PackedField,
326 PE::Scalar: ExtensionField<P::Scalar>,
327 {
328 fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
331 }
332
333 pub fn zero_pad<PE>(
334 &self,
335 n_pad_vars: usize,
336 start_index: usize,
337 nonzero_index: usize,
338 ) -> Result<MultilinearExtension<PE>, Error>
339 where
340 PE: PackedField,
341 PE::Scalar: ExtensionField<P::Scalar>,
342 {
343 let init_n_vars = self.mu;
344 if start_index > init_n_vars {
345 bail!(Error::IncorrectStartIndexZeroPad { expected: self.mu })
346 }
347 let new_n_vars = init_n_vars + n_pad_vars;
348 if nonzero_index >= 1 << n_pad_vars {
349 bail!(Error::IncorrectNonZeroIndex {
350 expected: 1 << n_pad_vars,
351 });
352 }
353
354 let mut result = zeroed_vec(1 << new_n_vars);
355
356 zero_pad(&self.evals, init_n_vars, new_n_vars, start_index, nonzero_index, &mut result)?;
357 MultilinearExtension::new(new_n_vars, result)
358 }
359}
360
361impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
362 pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
364 let packed_evals = self
365 .evals
366 .iter()
367 .map(|eval| eval.to_single_packed())
368 .collect();
369 MultilinearExtension {
370 mu: self.mu,
371 evals: packed_evals,
372 }
373 }
374}
375
/// Floor of the base-2 logarithm of `v`.
///
/// Callers in this module only pass powers of two, for which the result is exact.
///
/// # Panics
///
/// Panics if `v == 0`. The previous `63 - leading_zeros` formulation underflowed here and
/// silently wrapped in release builds.
const fn log2(v: usize) -> usize {
    v.ilog2() as usize
}
379
/// A [`MultilinearExtension`] that borrows its packed evaluations from a `&'a [P]` slice.
///
/// This is the type returned by [`MultilinearExtension::to_ref`].
pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
382
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        arch::OptimalUnderlier256b, BinaryField128b, BinaryField16b as F, BinaryField1b,
        BinaryField32b, BinaryField8b, PackedBinaryField16x1b, PackedBinaryField16x8b,
        PackedBinaryField32x1b, PackedBinaryField4x32b, PackedBinaryField8x16b as P,
        PackedBinaryField8x1b,
    };
    use itertools::Itertools;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{tensor_prod_eq_ind, MultilinearQuery};

    /// Reference (unpacked) tensor expansion of `query`: entry `i` is `eval_basis(query, i)`.
    fn expand_query_naive<F: Field>(query: &[F]) -> Result<Vec<F>, Error> {
        let result = (0..1 << query.len())
            .map(|i| eval_basis(query, i))
            .collect();
        Ok(result)
    }

    /// Product over `j` of `query[j]` when bit `j` of `i` is set, and `1 - query[j]` otherwise.
    fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
        query
            .iter()
            .enumerate()
            .map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
            .product()
    }

    /// Builds a packed `MultilinearQuery` for the point `p`.
    ///
    /// `result[0]` is seeded with `P::set_single(ONE)` and expanded in place with
    /// `tensor_prod_eq_ind`.
    fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
        let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
        result[0] = P::set_single(P::Scalar::ONE);
        tensor_prod_eq_ind(0, &mut result, p).unwrap();
        MultilinearQuery::with_expansion(p.len(), result).unwrap()
    }

    // The packed expansion from `multilinear_query` must match the naive reference.
    #[test]
    fn test_expand_query_impls_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<F>>();
        let result1 = multilinear_query::<P>(&q);
        let result2 = expand_query_naive(&q).unwrap();
        assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
    }

    // `from_values` (inferred n_vars) and `new` (explicit n_vars) must agree.
    #[test]
    fn test_new_from_values_correspondence() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| Field::random(&mut rng))
            .take(256)
            .collect::<Vec<F>>();
        let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
        let poly2 = MultilinearExtension::new(8, evals).unwrap();

        assert_eq!(poly1, poly2)
    }

    // Evaluating at a 0/1 point (the bits of `i`) must return the stored value `i`.
    #[test]
    fn test_evaluate_on_hypercube() {
        let mut values = vec![F::ZERO; 64];
        values
            .iter_mut()
            .enumerate()
            .for_each(|(i, val)| *val = F::new(i as u16));

        let poly = MultilinearExtension::from_values(values).unwrap();
        for i in 0..64 {
            let q = (0..6)
                .map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
                .collect::<Vec<_>>();
            let multilin_query = multilinear_query::<P>(&q);
            let result = poly.evaluate(multilin_query.to_ref()).unwrap();
            assert_eq!(result, F::new(i));
        }
    }

    /// Evaluates `poly` at `q` by repeatedly fixing the highest `splits[k]` variables with
    /// `evaluate_partial_high`, then evaluating the remainder with the last split.
    fn evaluate_split<P>(
        poly: MultilinearExtension<P>,
        q: &[P::Scalar],
        splits: &[usize],
    ) -> P::Scalar
    where
        P: PackedField + 'static,
    {
        assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

        let mut partial_result = poly;
        let mut index = q.len();
        for split_vars in &splits[0..splits.len() - 1] {
            let split_vars = *split_vars;
            let query = multilinear_query(&q[index - split_vars..index]);
            partial_result = partial_result
                .evaluate_partial_high(query.to_ref())
                .unwrap();
            index -= split_vars;
        }
        let multilin_query = multilinear_query::<P>(&q[..index]);
        partial_result.evaluate(multilin_query.to_ref()).unwrap()
    }

    // Full evaluation must equal a sequence of high partial evaluations.
    #[test]
    fn test_evaluate_split_is_correct() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| Field::random(&mut rng))
            .take(256)
            .collect::<Vec<F>>();
        let poly = MultilinearExtension::from_values(evals).unwrap();
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<F>>();
        let multilin_query = multilinear_query::<P>(&q);
        let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
        let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
        assert_eq!(result1, result2);
    }

    /// From each packed `PE` in `values`, collects the scalars with lane indices
    /// `P::WIDTH * start_index .. P::WIDTH * (start_index + 1)`.
    fn get_bits<P, PE>(values: &[PE], start_index: usize) -> Vec<PE::Scalar>
    where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let new_vals = values
            .iter()
            .flat_map(|v| {
                (P::WIDTH * start_index..P::WIDTH * (start_index + 1))
                    .map(|i| v.get(i))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        new_vals
    }

    // Fixing variable 4 to 0 / 1 must select the lower / upper 16 lanes of each 32x1b element.
    #[test]
    fn test_evaluate_middle_32b_to_16b() {
        let mut rng = StdRng::seed_from_u64(0);
        let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
            .take(1 << 2)
            .collect::<Vec<_>>();

        let expected_lower_bits =
            get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 0);

        let expected_higher_bits =
            get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 1);

        let poly = MultilinearExtension::from_values(values).unwrap();

        let q_low = [<BinaryField1b as PackedField>::zero()];
        let q_hi = [<BinaryField1b as PackedField>::one()];

        let query_low = multilinear_query::<BinaryField1b>(&q_low);
        let query_hi = multilinear_query::<BinaryField1b>(&q_hi);

        let evals_low = poly.evaluate_partial(query_low.to_ref(), 4).unwrap();
        let evals_low = evals_low.evals();

        let evals_hi = poly.evaluate_partial(query_hi.to_ref(), 4).unwrap();
        let evals_hi = evals_hi.evals();

        assert_eq!(evals_low, expected_lower_bits);
        assert_eq!(evals_hi, expected_higher_bits);
    }

    // Fixing variables 3..5 to each 0/1 pair must select the matching 8-lane quarter.
    #[test]
    fn test_evaluate_middle_32b_to_8b() {
        let mut rng = StdRng::seed_from_u64(0);
        let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
            .take(1 << 2)
            .collect::<Vec<_>>();

        let expected_first_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 0);

        let expected_second_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 1);

        let expected_third_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 2);

        let expected_fourth_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 3);

        let poly = MultilinearExtension::from_values(values).unwrap();

        let q_first = [
            <BinaryField1b as PackedField>::zero(),
            <BinaryField1b as PackedField>::zero(),
        ];
        let q_second = [
            <BinaryField1b as PackedField>::one(),
            <BinaryField1b as PackedField>::zero(),
        ];
        let q_third = [
            <BinaryField1b as PackedField>::zero(),
            <BinaryField1b as PackedField>::one(),
        ];
        let q_fourth = [
            <BinaryField1b as PackedField>::one(),
            <BinaryField1b as PackedField>::one(),
        ];

        let query_first = multilinear_query::<BinaryField1b>(&q_first);
        let query_second = multilinear_query::<BinaryField1b>(&q_second);
        let query_third = multilinear_query::<BinaryField1b>(&q_third);
        let query_fourth = multilinear_query::<BinaryField1b>(&q_fourth);

        let evals_first_quarter = poly.evaluate_partial(query_first.to_ref(), 3).unwrap();
        let evals_first_quarter = evals_first_quarter.evals();
        let evals_second_quarter = poly.evaluate_partial(query_second.to_ref(), 3).unwrap();
        let evals_second_quarter = evals_second_quarter.evals();
        let evals_third_quarter = poly.evaluate_partial(query_third.to_ref(), 3).unwrap();
        let evals_third_quarter = evals_third_quarter.evals();
        let evals_fourth_quarter = poly.evaluate_partial(query_fourth.to_ref(), 3).unwrap();
        let evals_fourth_quarter = evals_fourth_quarter.evals();

        assert_eq!(evals_first_quarter, expected_first_quarter);
        assert_eq!(evals_second_quarter, expected_second_quarter);
        assert_eq!(evals_third_quarter, expected_third_quarter);
        assert_eq!(evals_fourth_quarter, expected_fourth_quarter);
    }

    // Zero-padding 2 variables at index 3 (nonzero_index 0) and then fixing those variables
    // to 0 must recover the original values.
    #[test]
    fn test_zeropad_8b_to_32b_project() {
        let mut rng = StdRng::seed_from_u64(0);
        let values = repeat_with(|| PackedBinaryField8x1b::random(&mut rng))
            .take(1 << 2)
            .collect::<Vec<_>>();
        let expected_out = PackedBinaryField8x1b::iter_slice(&values).collect::<Vec<_>>();

        let poly = MultilinearExtension::from_values(values).unwrap();

        let n_pad_vars = 2;
        let start_index = 3;
        let nonzero_index = 0;
        let padded_poly = poly
            .zero_pad::<BinaryField1b>(n_pad_vars, start_index, nonzero_index)
            .unwrap();

        let query_first = multilinear_query::<BinaryField1b>(&[
            <BinaryField1b as PackedField>::zero(),
            <BinaryField1b as PackedField>::zero(),
        ]);
        let projected_poly = padded_poly
            .evaluate_partial(query_first.to_ref(), 3)
            .unwrap();

        assert_eq!(expected_out, projected_poly.evals());
    }

    // Low-then-low and middle-then-low partial evaluations over the same variables must agree.
    #[test]
    fn test_evaluate_partial_match_evaluate_partial_low() {
        type P = PackedBinaryField16x8b;
        type F = BinaryField8b;

        let mut rng = StdRng::seed_from_u64(0);

        let n_vars: usize = 7;

        let values = repeat_with(|| P::random(&mut rng))
            .take(1 << (n_vars.saturating_sub(P::LOG_WIDTH)))
            .collect();

        let query_n_minus_1 = repeat_with(|| <F as PackedField>::random(&mut rng))
            .take(n_vars - 1)
            .collect::<Vec<_>>();

        let (q_low, q_high) = query_n_minus_1.split_at(query_n_minus_1.len() / 2);

        let query_low = multilinear_query::<P>(q_low);
        let query_high = multilinear_query::<P>(q_high);

        let me = MultilinearExtension::from_values(values).unwrap();

        assert_eq!(
            me.evaluate_partial_low(query_low.to_ref())
                .unwrap()
                .evaluate_partial_low(query_high.to_ref())
                .unwrap(),
            me.evaluate_partial(query_high.to_ref(), q_low.len())
                .unwrap()
                .evaluate_partial_low(query_low.to_ref())
                .unwrap()
        )
    }

    // A high partial evaluation leaving fewer variables than P::LOG_WIDTH must still be
    // evaluable and match the full evaluation.
    #[test]
    fn test_evaluate_partial_high_packed() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| P::random(&mut rng))
            .take(256 >> P::LOG_WIDTH)
            .collect::<Vec<_>>();
        let poly = MultilinearExtension::from_values(evals).unwrap();
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<BinaryField128b>>();
        let multilin_query = multilinear_query::<BinaryField128b>(&q);

        let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

        let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
        let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
        assert!(partial_eval.n_vars() < P::LOG_WIDTH);

        let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
        let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
        assert_eq!(eval, expected);
    }

    // High-then-low partial evaluation of a multilinear smaller than one packed element must
    // match the direct evaluation.
    #[test]
    fn test_evaluate_partial_low_high_smaller_than_packed_width() {
        type P = PackedBinaryField16x8b;

        type F = BinaryField8b;

        let n_vars = 3;

        let mut rng = StdRng::seed_from_u64(0);

        let values = repeat_with(|| Field::random(&mut rng))
            .take(1 << n_vars)
            .collect::<Vec<F>>();

        let q = repeat_with(|| Field::random(&mut rng))
            .take(n_vars)
            .collect::<Vec<F>>();

        let query = multilinear_query::<P>(&q);

        let packed = P::from_scalars(values);
        let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

        let eval = me.evaluate(&query).unwrap();

        let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
        let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

        let eval_l_h = me
            .evaluate_partial_high(&query_high)
            .unwrap()
            .evaluate_partial_low(&query_low)
            .unwrap()
            .evals()[0]
            .get(0);

        assert_eq!(eval, eval_l_h);
    }

    // Hypercube lookup on a sub-packed-width multilinear: valid indices return the stored
    // scalar, and `1 << n_vars` is out of range.
    #[test]
    fn test_evaluate_on_hypercube_small_than_packed_width() {
        type P = PackedBinaryField16x8b;

        type F = BinaryField8b;

        let n_vars = 3;

        let mut rng = StdRng::seed_from_u64(0);

        let values = repeat_with(|| Field::random(&mut rng))
            .take(1 << n_vars)
            .collect::<Vec<F>>();

        let packed = P::from_scalars(values.clone());

        let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

        assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

        assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
    }

    // Fixing all variables with either partial evaluation must match `evaluate`.
    #[test]
    fn test_evaluate_partial_high_low_evaluate_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
            .take(1 << 8)
            .collect();

        let me = MultilinearExtension::from_values(values).unwrap();

        let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
            .take(me.n_vars())
            .collect::<Vec<_>>();

        let query = multilinear_query(&q);

        let eval = me
            .evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
                query.to_ref(),
            )
            .unwrap();

        assert_eq!(
            me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
                .unwrap()
                .evals[0]
                .get(0),
            eval
        );
        assert_eq!(
            me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
                .unwrap()
                .evals[0]
                .get(0),
            eval
        );
    }

    // Two single-variable low evaluations must equal one two-variable low evaluation.
    #[test]
    fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
            .take(1 << 8)
            .collect();

        let mle = MultilinearExtension::from_values(values).unwrap();
        let r1 = <BinaryField32b as PackedField>::random(&mut rng);
        let r2 = <BinaryField32b as PackedField>::random(&mut rng);

        let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
            .evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
            .unwrap()
            .evaluate_partial_low(multilinear_query(&[r2]).to_ref())
            .unwrap();
        let eval_2 = mle
            .evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
            .unwrap();
        assert_eq!(eval_1, eval_2);
    }

    // `new` must accept n_vars smaller than the packed width when given one packed element.
    #[test]
    fn test_new_mle_with_tiny_nvars() {
        MultilinearExtension::new(
            1,
            vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
        )
        .unwrap();
    }
}