1use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6 as_packed_field::{AsSinglePacked, PackScalar, PackedType},
7 underlier::UnderlierType,
8 util::inner_product_par,
9 ExtensionField, Field, PackedField,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{fold::fold_left, fold_middle, fold_right, Error, MultilinearQueryRef, PackingDeref};
16
/// A multilinear polynomial stored as its evaluations over the boolean hypercube.
///
/// Evaluations are kept in packed form (`P`). When `mu >= P::LOG_WIDTH` the storage holds
/// `2^mu / P::WIDTH` packed elements; when `mu < P::LOG_WIDTH` the constructors require a
/// single packed element, of which only the first `2^mu` scalars are meaningful.
///
/// `Data` is the backing storage: an owned `Vec<P>` by default, or any other type that
/// dereferences to `[P]` (for example a borrowed slice).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
    /// Number of variables of the polynomial (reported by `n_vars()`).
    mu: usize,
    /// Packed evaluations over the hypercube.
    evals: Data,
}
30
31impl<P: PackedField> MultilinearExtension<P> {
32 pub fn zeros(n_vars: usize) -> Result<Self, Error> {
33 if n_vars < P::LOG_WIDTH {
34 bail!(Error::ArgumentRangeError {
35 arg: "n_vars".into(),
36 range: P::LOG_WIDTH..32,
37 });
38 }
39 Ok(Self {
40 mu: n_vars,
41 evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
42 })
43 }
44
45 pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
46 Self::from_values_generic(v)
47 }
48
49 pub fn into_evals(self) -> Vec<P> {
50 self.evals
51 }
52}
53
54impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
55 pub fn from_values_generic(v: Data) -> Result<Self, Error> {
56 if !v.len().is_power_of_two() {
57 bail!(Error::PowerOfTwoLengthRequired);
58 }
59 let mu = log2(v.len()) + P::LOG_WIDTH;
60
61 Ok(Self { mu, evals: v })
62 }
63
64 pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
65 if !v.len().is_power_of_two() {
66 bail!(Error::PowerOfTwoLengthRequired);
67 }
68
69 if n_vars < P::LOG_WIDTH {
70 if v.len() != 1 {
71 bail!(Error::IncorrectNumberOfVariables {
72 expected: n_vars,
73 actual: P::LOG_WIDTH + log2(v.len())
74 });
75 }
76 } else if P::LOG_WIDTH + log2(v.len()) != n_vars {
77 bail!(Error::IncorrectNumberOfVariables {
78 expected: n_vars,
79 actual: P::LOG_WIDTH + log2(v.len())
80 });
81 }
82
83 Ok(Self {
84 mu: n_vars,
85 evals: v,
86 })
87 }
88}
89
90impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
91where
92 U: UnderlierType + PackScalar<F>,
94 F: Field,
95 Data: Deref<Target = [U]>,
96{
97 pub fn from_underliers(v: Data) -> Result<Self, Error> {
98 Self::from_values_generic(PackingDeref::new(v))
99 }
100}
101
102impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
103 pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
104 if !v.len().is_power_of_two() {
105 bail!(Error::PowerOfTwoLengthRequired);
106 }
107 let mu = log2(v.len() * P::WIDTH);
108 Ok(Self { mu, evals: v })
109 }
110}
111
112impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
113 pub const fn n_vars(&self) -> usize {
114 self.mu
115 }
116
117 pub const fn size(&self) -> usize {
118 1 << self.mu
119 }
120
121 pub fn evals(&self) -> &[P] {
122 &self.evals
123 }
124
125 pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
126 MultilinearExtension {
127 mu: self.mu,
128 evals: self.evals(),
129 }
130 }
131
132 pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
139 self.evals()
140 .get(index)
141 .ok_or(Error::HypercubeIndexOutOfRange { index })
142 .copied()
143 }
144
145 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
146 if self.size() <= index {
147 bail!(Error::HypercubeIndexOutOfRange { index })
148 }
149
150 let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
151 Ok(subcube_eval.get(index % P::WIDTH))
152 }
153}
154
155impl<P, Data> MultilinearExtension<P, Data>
156where
157 P: PackedField,
158 Data: Deref<Target = [P]> + Send + Sync,
159{
160 pub fn evaluate<'a, FE, PE>(
161 &self,
162 query: impl Into<MultilinearQueryRef<'a, PE>>,
163 ) -> Result<FE, Error>
164 where
165 FE: ExtensionField<P::Scalar>,
166 PE: PackedField<Scalar = FE>,
167 {
168 let query = query.into();
169 if self.mu != query.n_vars() {
170 bail!(Error::IncorrectQuerySize { expected: self.mu });
171 }
172
173 if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
174 let evals = PackedField::iter_slice(self.evals())
175 .take(self.size())
176 .collect::<Vec<P::Scalar>>();
177 let querys = PackedField::iter_slice(query.expansion())
178 .take(1 << query.n_vars())
179 .collect::<Vec<PE::Scalar>>();
180 Ok(inner_product_par(&querys, &evals))
181 } else {
182 Ok(inner_product_par(query.expansion(), &self.evals))
183 }
184 }
185
186 #[instrument("MultilinearExtension::evaluate_partial", skip_all, level = "debug")]
187 pub fn evaluate_partial<'a, PE>(
188 &self,
189 query: impl Into<MultilinearQueryRef<'a, PE>>,
190 start_index: usize,
191 ) -> Result<MultilinearExtension<PE>, Error>
192 where
193 PE: PackedField,
194 PE::Scalar: ExtensionField<P::Scalar>,
195 {
196 let query = query.into();
197 if start_index + query.n_vars() > self.mu {
198 bail!(Error::IncorrectStartIndex { expected: self.mu })
199 }
200
201 if start_index == 0 {
202 return self.evaluate_partial_low(query);
203 } else if start_index + query.n_vars() == self.mu {
204 return self.evaluate_partial_high(query);
205 }
206
207 if self.mu < query.n_vars() {
208 bail!(Error::IncorrectQuerySize { expected: self.mu });
209 }
210
211 let new_n_vars = self.mu - query.n_vars();
212 let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
213 let mut result_evals = Vec::with_capacity(result_evals_len);
214
215 fold_middle(
216 self.evals(),
217 self.mu,
218 query.expansion(),
219 query.n_vars(),
220 start_index,
221 result_evals.spare_capacity_mut(),
222 )?;
223 unsafe {
224 result_evals.set_len(result_evals_len);
225 }
226
227 MultilinearExtension::new(new_n_vars, result_evals)
228 }
229
230 #[instrument(
240 "MultilinearExtension::evaluate_partial_high",
241 skip_all,
242 level = "debug"
243 )]
244 pub fn evaluate_partial_high<'a, PE>(
245 &self,
246 query: impl Into<MultilinearQueryRef<'a, PE>>,
247 ) -> Result<MultilinearExtension<PE>, Error>
248 where
249 PE: PackedField,
250 PE::Scalar: ExtensionField<P::Scalar>,
251 {
252 let query = query.into();
253
254 let new_n_vars = self.mu.saturating_sub(query.n_vars());
255 let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
256 let mut result_evals = Vec::with_capacity(result_evals_len);
257
258 fold_left(
259 self.evals(),
260 self.mu,
261 query.expansion(),
262 query.n_vars(),
263 result_evals.spare_capacity_mut(),
264 )?;
265 unsafe {
266 result_evals.set_len(result_evals_len);
267 }
268
269 MultilinearExtension::new(new_n_vars, result_evals)
270 }
271
272 #[instrument(
282 "MultilinearExtension::evaluate_partial_low",
283 skip_all,
284 level = "trace"
285 )]
286 pub fn evaluate_partial_low<'a, PE>(
287 &self,
288 query: impl Into<MultilinearQueryRef<'a, PE>>,
289 ) -> Result<MultilinearExtension<PE>, Error>
290 where
291 PE: PackedField,
292 PE::Scalar: ExtensionField<P::Scalar>,
293 {
294 let query = query.into();
295
296 if self.mu < query.n_vars() {
297 bail!(Error::IncorrectQuerySize { expected: self.mu });
298 }
299
300 let new_n_vars = self.mu - query.n_vars();
301
302 let mut result =
303 zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
304 self.evaluate_partial_low_into(query, &mut result)?;
305 MultilinearExtension::new(new_n_vars, result)
306 }
307
308 pub fn evaluate_partial_low_into<PE>(
318 &self,
319 query: MultilinearQueryRef<PE>,
320 out: &mut [PE],
321 ) -> Result<(), Error>
322 where
323 PE: PackedField,
324 PE::Scalar: ExtensionField<P::Scalar>,
325 {
326 fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
329 }
330}
331
332impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
333 pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
335 let packed_evals = self
336 .evals
337 .iter()
338 .map(|eval| eval.to_single_packed())
339 .collect();
340 MultilinearExtension {
341 mu: self.mu,
342 evals: packed_evals,
343 }
344 }
345}
346
/// Floor of the base-2 logarithm of `v`.
///
/// Callers in this module pass only powers of two, for which this is the exact logarithm.
///
/// # Panics
/// Panics if `v == 0`, in every build mode. The previous
/// `63 - (v as u64).leading_zeros()` form silently wrapped to `usize::MAX` for 0 in release
/// builds, and it also assumed a 64-bit word.
const fn log2(v: usize) -> usize {
    v.ilog2() as usize
}
350
/// A [`MultilinearExtension`] whose packed evaluations are borrowed as a slice instead of
/// owned in a `Vec`.
pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
353
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        arch::OptimalUnderlier256b, BinaryField128b, BinaryField16b as F, BinaryField1b,
        BinaryField32b, BinaryField8b, PackedBinaryField16x1b, PackedBinaryField16x8b,
        PackedBinaryField32x1b, PackedBinaryField4x32b, PackedBinaryField8x16b as P,
        PackedBinaryField8x1b,
    };
    use itertools::Itertools;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{tensor_prod_eq_ind, MultilinearQuery};

    /// Reference expansion of `query` into all `2^query.len()` basis values, computed one
    /// index at a time with [`eval_basis`].
    fn expand_query_naive<F: Field>(query: &[F]) -> Result<Vec<F>, Error> {
        let result = (0..1 << query.len())
            .map(|i| eval_basis(query, i))
            .collect();
        Ok(result)
    }

    /// Product over `j` of `query[j]` when bit `j` of `i` is set, and `1 - query[j]` otherwise.
    fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
        query
            .iter()
            .enumerate()
            .map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
            .product()
    }

    /// Builds a packed `MultilinearQuery` for the point `p`. The first packed element is seeded
    /// with `P::set_single(ONE)` and then expanded in place with `tensor_prod_eq_ind`.
    fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
        let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
        result[0] = P::set_single(P::Scalar::ONE);
        tensor_prod_eq_ind(0, &mut result, p).unwrap();
        MultilinearQuery::with_expansion(p.len(), result).unwrap()
    }

    // The packed query expansion must agree scalar-for-scalar with the naive reference.
    #[test]
    fn test_expand_query_impls_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<F>>();
        let result1 = multilinear_query::<P>(&q);
        let result2 = expand_query_naive(&q).unwrap();
        assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
    }

    // `from_values` (variable count inferred) and `new` (variable count given) must agree.
    #[test]
    fn test_new_from_values_correspondence() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| Field::random(&mut rng))
            .take(256)
            .collect::<Vec<F>>();
        let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
        let poly2 = MultilinearExtension::new(8, evals).unwrap();

        assert_eq!(poly1, poly2)
    }

    // Evaluating at a 0/1 point must return the stored value at the matching hypercube index.
    #[test]
    fn test_evaluate_on_hypercube() {
        let mut values = vec![F::ZERO; 64];
        values
            .iter_mut()
            .enumerate()
            .for_each(|(i, val)| *val = F::new(i as u16));

        let poly = MultilinearExtension::from_values(values).unwrap();
        for i in 0..64 {
            let q = (0..6)
                .map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
                .collect::<Vec<_>>();
            let multilin_query = multilinear_query::<P>(&q);
            let result = poly.evaluate(multilin_query.to_ref()).unwrap();
            assert_eq!(result, F::new(i));
        }
    }

    /// Evaluates `poly` at `q` piecewise. For each entry of `splits` except the last, it fixes
    /// that many of the highest remaining variables with `evaluate_partial_high`, taking their
    /// coordinates from the top of `q`. It then evaluates what is left with `evaluate`.
    fn evaluate_split<P>(
        poly: MultilinearExtension<P>,
        q: &[P::Scalar],
        splits: &[usize],
    ) -> P::Scalar
    where
        P: PackedField + 'static,
    {
        assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

        let mut partial_result = poly;
        let mut index = q.len();
        for split_vars in &splits[0..splits.len() - 1] {
            let split_vars = *split_vars;
            let query = multilinear_query(&q[index - split_vars..index]);
            partial_result = partial_result
                .evaluate_partial_high(query.to_ref())
                .unwrap();
            index -= split_vars;
        }
        let multilin_query = multilinear_query::<P>(&q[..index]);
        partial_result.evaluate(multilin_query.to_ref()).unwrap()
    }

    // Piecewise evaluation via `evaluate_partial_high` must match a single `evaluate`.
    #[test]
    fn test_evaluate_split_is_correct() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| Field::random(&mut rng))
            .take(256)
            .collect::<Vec<F>>();
        let poly = MultilinearExtension::from_values(evals).unwrap();
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<F>>();
        let multilin_query = multilinear_query::<P>(&q);
        let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
        let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
        assert_eq!(result1, result2);
    }

    /// From each packed value in `values`, takes the `P::WIDTH` scalars at positions
    /// `P::WIDTH * start_index .. P::WIDTH * (start_index + 1)` and concatenates them.
    fn get_bits<P, PE>(values: &[PE], start_index: usize) -> Vec<PE::Scalar>
    where
        P: PackedField,
        PE: PackedField<Scalar: ExtensionField<P::Scalar>>,
    {
        let new_vals = values
            .iter()
            .flat_map(|v| {
                (P::WIDTH * start_index..P::WIDTH * (start_index + 1))
                    .map(|i| v.get(i))
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();

        new_vals
    }

    // Fixing variable 4 (the top bit of the 32-lane index) to 0 or 1 with a middle fold must
    // select the lower or upper 16 lanes of every packed element.
    #[test]
    fn test_evaluate_middle_32b_to_16b() {
        let mut rng = StdRng::seed_from_u64(0);
        let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
            .take(1 << 2)
            .collect::<Vec<_>>();

        let expected_lower_bits =
            get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 0);

        let expected_higher_bits =
            get_bits::<PackedBinaryField16x1b, PackedBinaryField32x1b>(&values, 1);

        let poly = MultilinearExtension::from_values(values).unwrap();

        let q_low = [<BinaryField1b as PackedField>::zero()];
        let q_hi = [<BinaryField1b as PackedField>::one()];

        let query_low = multilinear_query::<BinaryField1b>(&q_low);
        let query_hi = multilinear_query::<BinaryField1b>(&q_hi);

        let evals_low = poly.evaluate_partial(query_low.to_ref(), 4).unwrap();
        let evals_low = evals_low.evals();

        let evals_hi = poly.evaluate_partial(query_hi.to_ref(), 4).unwrap();
        let evals_hi = evals_hi.evals();

        assert_eq!(evals_low, expected_lower_bits);
        assert_eq!(evals_hi, expected_higher_bits);
    }

    // Fixing variables 3 and 4 to each 0/1 combination with a middle fold must select the
    // corresponding 8-lane quarter of every packed element (variable 3 is the low bit).
    #[test]
    fn test_evaluate_middle_32b_to_8b() {
        let mut rng = StdRng::seed_from_u64(0);
        let values = repeat_with(|| PackedBinaryField32x1b::random(&mut rng))
            .take(1 << 2)
            .collect::<Vec<_>>();

        let expected_first_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 0);

        let expected_second_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 1);

        let expected_third_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 2);

        let expected_fourth_quarter =
            get_bits::<PackedBinaryField8x1b, PackedBinaryField32x1b>(&values, 3);

        let poly = MultilinearExtension::from_values(values).unwrap();

        let q_first = [
            <BinaryField1b as PackedField>::zero(),
            <BinaryField1b as PackedField>::zero(),
        ];
        let q_second = [
            <BinaryField1b as PackedField>::one(),
            <BinaryField1b as PackedField>::zero(),
        ];
        let q_third = [
            <BinaryField1b as PackedField>::zero(),
            <BinaryField1b as PackedField>::one(),
        ];
        let q_fourth = [
            <BinaryField1b as PackedField>::one(),
            <BinaryField1b as PackedField>::one(),
        ];

        let query_first = multilinear_query::<BinaryField1b>(&q_first);
        let query_second = multilinear_query::<BinaryField1b>(&q_second);
        let query_third = multilinear_query::<BinaryField1b>(&q_third);
        let query_fourth = multilinear_query::<BinaryField1b>(&q_fourth);

        let evals_first_quarter = poly.evaluate_partial(query_first.to_ref(), 3).unwrap();
        let evals_first_quarter = evals_first_quarter.evals();
        let evals_second_quarter = poly.evaluate_partial(query_second.to_ref(), 3).unwrap();
        let evals_second_quarter = evals_second_quarter.evals();
        let evals_third_quarter = poly.evaluate_partial(query_third.to_ref(), 3).unwrap();
        let evals_third_quarter = evals_third_quarter.evals();
        let evals_fourth_quarter = poly.evaluate_partial(query_fourth.to_ref(), 3).unwrap();
        let evals_fourth_quarter = evals_fourth_quarter.evals();

        assert_eq!(evals_first_quarter, expected_first_quarter);
        assert_eq!(evals_second_quarter, expected_second_quarter);
        assert_eq!(evals_third_quarter, expected_third_quarter);
        assert_eq!(evals_fourth_quarter, expected_fourth_quarter);
    }

    // Two ways of fixing the low half and then the next half of the variables must agree:
    // 1. two `evaluate_partial_low` calls;
    // 2. `evaluate_partial` at `start_index = q_low.len()` (a middle fold, since one variable
    //    stays above) followed by `evaluate_partial_low`.
    #[test]
    fn test_evaluate_partial_match_evaluate_partial_low() {
        type P = PackedBinaryField16x8b;
        type F = BinaryField8b;

        let mut rng = StdRng::seed_from_u64(0);

        let n_vars: usize = 7;

        let values = repeat_with(|| P::random(&mut rng))
            .take(1 << (n_vars.saturating_sub(P::LOG_WIDTH)))
            .collect();

        let query_n_minus_1 = repeat_with(|| <F as PackedField>::random(&mut rng))
            .take(n_vars - 1)
            .collect::<Vec<_>>();

        let (q_low, q_high) = query_n_minus_1.split_at(query_n_minus_1.len() / 2);

        let query_low = multilinear_query::<P>(q_low);
        let query_high = multilinear_query::<P>(q_high);

        let me = MultilinearExtension::from_values(values).unwrap();

        assert_eq!(
            me.evaluate_partial_low(query_low.to_ref())
                .unwrap()
                .evaluate_partial_low(query_high.to_ref())
                .unwrap(),
            me.evaluate_partial(query_high.to_ref(), q_low.len())
                .unwrap()
                .evaluate_partial_low(query_low.to_ref())
                .unwrap()
        )
    }

    // `evaluate_partial_high` may leave fewer variables than `P::LOG_WIDTH`; the result must
    // still evaluate to the same value as a full `evaluate`.
    #[test]
    fn test_evaluate_partial_high_packed() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| P::random(&mut rng))
            .take(256 >> P::LOG_WIDTH)
            .collect::<Vec<_>>();
        let poly = MultilinearExtension::from_values(evals).unwrap();
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<BinaryField128b>>();
        let multilin_query = multilinear_query::<BinaryField128b>(&q);

        let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

        let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
        let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
        assert!(partial_eval.n_vars() < P::LOG_WIDTH);

        let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
        let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
        assert_eq!(eval, expected);
    }

    // For a polynomial stored in a single packed element (n_vars < P::LOG_WIDTH), fixing the
    // high variable and then the low ones must match a full `evaluate`.
    #[test]
    fn test_evaluate_partial_low_high_smaller_than_packed_width() {
        type P = PackedBinaryField16x8b;

        type F = BinaryField8b;

        let n_vars = 3;

        let mut rng = StdRng::seed_from_u64(0);

        let values = repeat_with(|| Field::random(&mut rng))
            .take(1 << n_vars)
            .collect::<Vec<F>>();

        let q = repeat_with(|| Field::random(&mut rng))
            .take(n_vars)
            .collect::<Vec<F>>();

        let query = multilinear_query::<P>(&q);

        let packed = P::from_scalars(values);
        let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

        let eval = me.evaluate(&query).unwrap();

        let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
        let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

        let eval_l_h = me
            .evaluate_partial_high(&query_high)
            .unwrap()
            .evaluate_partial_low(&query_low)
            .unwrap()
            .evals()[0]
            .get(0);

        assert_eq!(eval, eval_l_h);
    }

    // Hypercube lookups on a polynomial smaller than one packed element must read the right
    // scalar and reject indices >= 2^n_vars.
    #[test]
    fn test_evaluate_on_hypercube_small_than_packed_width() {
        type P = PackedBinaryField16x8b;

        type F = BinaryField8b;

        let n_vars = 3;

        let mut rng = StdRng::seed_from_u64(0);

        let values = repeat_with(|| Field::random(&mut rng))
            .take(1 << n_vars)
            .collect::<Vec<F>>();

        let packed = P::from_scalars(values.clone());

        let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

        assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

        assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
    }

    // Fixing *all* variables with either `evaluate_partial_low` or `evaluate_partial_high`
    // must leave a constant equal to the full `evaluate`.
    #[test]
    fn test_evaluate_partial_high_low_evaluate_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
            .take(1 << 8)
            .collect();

        let me = MultilinearExtension::from_values(values).unwrap();

        let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
            .take(me.n_vars())
            .collect::<Vec<_>>();

        let query = multilinear_query(&q);

        let eval = me
            .evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
                query.to_ref(),
            )
            .unwrap();

        assert_eq!(
            me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
                .unwrap()
                .evals[0]
                .get(0),
            eval
        );
        assert_eq!(
            me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
                .unwrap()
                .evals[0]
                .get(0),
            eval
        );
    }

    // Fixing the two lowest variables one at a time must equal fixing both in one call.
    #[test]
    fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
            .take(1 << 8)
            .collect();

        let mle = MultilinearExtension::from_values(values).unwrap();
        let r1 = <BinaryField32b as PackedField>::random(&mut rng);
        let r2 = <BinaryField32b as PackedField>::random(&mut rng);

        let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
            .evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
            .unwrap()
            .evaluate_partial_low(multilinear_query(&[r2]).to_ref())
            .unwrap();
        let eval_2 = mle
            .evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
            .unwrap();
        assert_eq!(eval_1, eval_2);
    }

    // `new` must accept a variable count smaller than the packing width when given exactly one
    // packed element.
    #[test]
    fn test_new_mle_with_tiny_nvars() {
        MultilinearExtension::new(
            1,
            vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
        )
        .unwrap();
    }
}