1use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc};
4
5use binius_field::{
6 packed::{
7 get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
8 },
9 ExtensionField, Field, PackedField, RepackedExtension,
10};
11use binius_utils::bail;
12
13use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef};
14
/// Adapter that wraps a [`MultilinearExtension`] whose evaluations are packed as `P` and tags it
/// with a second packed field type `PE`.
///
/// The where-clause requires `PE::Scalar` to be an extension field of `P::Scalar`. `PE` is only
/// recorded through [`PhantomData`]; the stored data is the inner `MultilinearExtension<P, Data>`.
/// This file implements [`MultilinearPoly<PE>`] for the adapter when `PE: RepackedExtension<P>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MLEEmbeddingAdapter<P, PE, Data = Vec<P>>(
    // The wrapped multilinear extension over the base packed field.
    MultilinearExtension<P, Data>,
    // Zero-sized marker carrying the extension packed field type.
    PhantomData<PE>,
)
where
    P: PackedField,
    PE: PackedField,
    PE::Scalar: ExtensionField<P::Scalar>,
    Data: Deref<Target = [P]>;
30
31impl<'a, P, PE, Data> MLEEmbeddingAdapter<P, PE, Data>
32where
33 P: PackedField,
34 PE: PackedField + RepackedExtension<P>,
35 PE::Scalar: ExtensionField<P::Scalar>,
36 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
37{
38 pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
39 Arc::new(self)
40 }
41}
42
43impl<P, PE, Data> From<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
44where
45 P: PackedField,
46 PE: PackedField,
47 PE::Scalar: ExtensionField<P::Scalar>,
48 Data: Deref<Target = [P]>,
49{
50 fn from(inner: MultilinearExtension<P, Data>) -> Self {
51 Self(inner, PhantomData)
52 }
53}
54
55impl<P, PE, Data> AsRef<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
56where
57 P: PackedField,
58 PE: PackedField,
59 PE::Scalar: ExtensionField<P::Scalar>,
60 Data: Deref<Target = [P]>,
61{
62 fn as_ref(&self) -> &MultilinearExtension<P, Data> {
63 &self.0
64 }
65}
66
67impl<P, PE, Data> MultilinearPoly<PE> for MLEEmbeddingAdapter<P, PE, Data>
68where
69 P: PackedField + Debug,
70 PE: PackedField + RepackedExtension<P>,
71 PE::Scalar: ExtensionField<P::Scalar>,
72 Data: Deref<Target = [P]> + Send + Sync + Debug,
73{
74 fn n_vars(&self) -> usize {
75 self.0.n_vars()
76 }
77
78 fn log_extension_degree(&self) -> usize {
79 PE::Scalar::LOG_DEGREE
80 }
81
82 fn evaluate_on_hypercube(&self, index: usize) -> Result<PE::Scalar, Error> {
83 let eval = self.0.evaluate_on_hypercube(index)?;
84 Ok(eval.into())
85 }
86
87 fn evaluate_on_hypercube_and_scale(
88 &self,
89 index: usize,
90 scalar: PE::Scalar,
91 ) -> Result<PE::Scalar, Error> {
92 let eval = self.0.evaluate_on_hypercube(index)?;
93 Ok(scalar * eval)
94 }
95
96 fn evaluate(&self, query: MultilinearQueryRef<PE>) -> Result<PE::Scalar, Error> {
97 self.0.evaluate(query)
98 }
99
100 fn evaluate_partial_low(
101 &self,
102 query: MultilinearQueryRef<PE>,
103 ) -> Result<MultilinearExtension<PE>, Error> {
104 self.0.evaluate_partial_low(query)
105 }
106
107 fn evaluate_partial_high(
108 &self,
109 query: MultilinearQueryRef<PE>,
110 ) -> Result<MultilinearExtension<PE>, Error> {
111 self.0.evaluate_partial_high(query)
112 }
113
114 fn evaluate_partial(
115 &self,
116 query: MultilinearQueryRef<PE>,
117 start_index: usize,
118 ) -> Result<MultilinearExtension<PE>, Error> {
119 self.0.evaluate_partial(query, start_index)
120 }
121
122 fn subcube_partial_low_evals(
123 &self,
124 query: MultilinearQueryRef<PE>,
125 subcube_vars: usize,
126 subcube_index: usize,
127 partial_low_evals: &mut [PE],
128 ) -> Result<(), Error> {
129 validate_subcube_partial_evals_params(
130 self.n_vars(),
131 query,
132 subcube_vars,
133 subcube_index,
134 partial_low_evals,
135 )?;
136
137 let query_n_vars = query.n_vars();
138 let subcube_start = subcube_index << (query_n_vars + subcube_vars);
139
140 for scalar_index in 0..1 << subcube_vars {
143 let evals_start = subcube_start + (scalar_index << query_n_vars);
144 let mut inner_product = PE::Scalar::ZERO;
145 for i in 0..1 << query_n_vars {
146 inner_product += get_packed_slice(query.expansion(), i)
147 * get_packed_slice(self.0.evals(), evals_start + i);
148 }
149
150 set_packed_slice(partial_low_evals, scalar_index, inner_product);
151 }
152
153 Ok(())
154 }
155
156 fn subcube_partial_high_evals(
157 &self,
158 query: MultilinearQueryRef<PE>,
159 subcube_vars: usize,
160 subcube_index: usize,
161 partial_high_evals: &mut [PE],
162 ) -> Result<(), Error> {
163 validate_subcube_partial_evals_params(
164 self.n_vars(),
165 query,
166 subcube_vars,
167 subcube_index,
168 partial_high_evals,
169 )?;
170
171 let query_n_vars = query.n_vars();
172
173 partial_high_evals.fill(PE::zero());
176
177 for query_index in 0..1 << query_n_vars {
178 let query_factor = get_packed_slice(query.expansion(), query_index);
179 let subcube_start =
180 subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
181 for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
182 *packed += PE::from_fn(|inner_index| {
183 let index = subcube_start | outer_index << PE::LOG_WIDTH | inner_index;
184 query_factor * get_packed_slice(self.0.evals(), index)
185 });
186 }
187 }
188
189 if subcube_vars < PE::LOG_WIDTH {
190 for i in 1 << subcube_vars..PE::WIDTH {
191 partial_high_evals
192 .first_mut()
193 .expect("at least one")
194 .set(i, PE::Scalar::ZERO);
195 }
196 }
197
198 Ok(())
199 }
200
201 fn subcube_evals(
202 &self,
203 subcube_vars: usize,
204 subcube_index: usize,
205 log_embedding_degree: usize,
206 evals: &mut [PE],
207 ) -> Result<(), Error> {
208 let log_extension_degree = PE::Scalar::LOG_DEGREE;
209
210 if subcube_vars > self.n_vars() {
211 bail!(Error::ArgumentRangeError {
212 arg: "subcube_vars".to_string(),
213 range: 0..self.n_vars() + 1,
214 });
215 }
216
217 const MAX_TOWER_HEIGHT: usize = 7;
221 if log_embedding_degree > log_extension_degree.min(MAX_TOWER_HEIGHT) {
222 bail!(Error::LogEmbeddingDegreeTooLarge {
223 log_embedding_degree
224 });
225 }
226
227 let correct_len = 1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH);
228 if evals.len() != correct_len {
229 bail!(Error::ArgumentRangeError {
230 arg: "evals.len()".to_string(),
231 range: correct_len..correct_len + 1,
232 });
233 }
234
235 let max_index = 1 << (self.n_vars() - subcube_vars);
236 if subcube_index >= max_index {
237 bail!(Error::ArgumentRangeError {
238 arg: "subcube_index".to_string(),
239 range: 0..max_index,
240 });
241 }
242
243 let subcube_start = subcube_index << subcube_vars;
244
245 if log_embedding_degree == 0 {
246 for i in 0..1 << subcube_vars {
248 let scalar =
250 unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
251
252 let extension_scalar = scalar.into();
253
254 unsafe {
256 set_packed_slice_unchecked(evals, i, extension_scalar);
257 }
258 }
259 } else {
260 let mut bases = [P::Scalar::default(); 1 << MAX_TOWER_HEIGHT];
262 let bases = &mut bases[0..1 << log_embedding_degree];
263
264 let bases_count = 1 << log_embedding_degree.min(subcube_vars);
265 for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) {
266 for (j, base) in bases[..bases_count].iter_mut().enumerate() {
267 *base = unsafe {
269 get_packed_slice_unchecked(
270 self.0.evals(),
271 subcube_start + (i << log_embedding_degree) + j,
272 )
273 };
274 }
275
276 let extension_scalar = PE::Scalar::from_bases_sparse(
277 bases.iter().copied(),
278 log_extension_degree - log_embedding_degree,
279 )?;
280
281 unsafe {
283 set_packed_slice_unchecked(evals, i, extension_scalar);
284 }
285 }
286 }
287
288 Ok(())
289 }
290
291 fn packed_evals(&self) -> Option<&[PE]> {
292 Some(PE::cast_exts(self.0.evals()))
293 }
294}
295
296impl<P, Data> MultilinearExtension<P, Data>
297where
298 P: PackedField,
299 Data: Deref<Target = [P]>,
300{
301 pub fn specialize<PE>(self) -> MLEEmbeddingAdapter<P, PE, Data>
302 where
303 PE: PackedField,
304 PE::Scalar: ExtensionField<P::Scalar>,
305 {
306 MLEEmbeddingAdapter::from(self)
307 }
308}
309
310impl<'a, P, Data> MultilinearExtension<P, Data>
311where
312 P: PackedField,
313 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
314{
315 pub fn specialize_arc_dyn<PE: RepackedExtension<P>>(
316 self,
317 ) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
318 self.specialize().upcast_arc_dyn()
319 }
320}
321
/// Newtype around a [`MultilinearExtension`] that is used directly over its own packed field `P`.
///
/// This file implements [`MultilinearPoly<P>`] for it; that implementation reports a
/// `log_extension_degree` of 0 and only accepts `log_embedding_degree == 0` in `subcube_evals`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MLEDirectAdapter<P, Data = Vec<P>>(MultilinearExtension<P, Data>)
where
    P: PackedField,
    Data: Deref<Target = [P]>;
329
330impl<'a, P, Data> MLEDirectAdapter<P, Data>
331where
332 P: PackedField,
333 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
334{
335 pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<P> + Send + Sync + 'a> {
336 Arc::new(self)
337 }
338}
339
340impl<P, Data> From<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
341where
342 P: PackedField,
343 Data: Deref<Target = [P]>,
344{
345 fn from(inner: MultilinearExtension<P, Data>) -> Self {
346 Self(inner)
347 }
348}
349
350impl<P, Data> AsRef<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
351where
352 P: PackedField,
353 Data: Deref<Target = [P]>,
354{
355 fn as_ref(&self) -> &MultilinearExtension<P, Data> {
356 &self.0
357 }
358}
359
360impl<F, P, Data> MultilinearPoly<P> for MLEDirectAdapter<P, Data>
361where
362 F: Field,
363 P: PackedField<Scalar = F>,
364 Data: Deref<Target = [P]> + Send + Sync + Debug,
365{
366 #[inline]
367 fn n_vars(&self) -> usize {
368 self.0.n_vars()
369 }
370
371 #[inline]
372 fn log_extension_degree(&self) -> usize {
373 0
374 }
375
376 fn evaluate_on_hypercube(&self, index: usize) -> Result<F, Error> {
377 self.0.evaluate_on_hypercube(index)
378 }
379
380 fn evaluate_on_hypercube_and_scale(&self, index: usize, scalar: F) -> Result<F, Error> {
381 let eval = self.0.evaluate_on_hypercube(index)?;
382 Ok(scalar * eval)
383 }
384
385 fn evaluate(&self, query: MultilinearQueryRef<P>) -> Result<F, Error> {
386 self.0.evaluate(query)
387 }
388
389 fn evaluate_partial_low(
390 &self,
391 query: MultilinearQueryRef<P>,
392 ) -> Result<MultilinearExtension<P>, Error> {
393 self.0.evaluate_partial_low(query)
394 }
395
396 fn evaluate_partial_high(
397 &self,
398 query: MultilinearQueryRef<P>,
399 ) -> Result<MultilinearExtension<P>, Error> {
400 self.0.evaluate_partial_high(query)
401 }
402
403 fn evaluate_partial(
404 &self,
405 query: MultilinearQueryRef<P>,
406 start_index: usize,
407 ) -> Result<MultilinearExtension<P>, Error> {
408 self.0.evaluate_partial(query, start_index)
409 }
410
411 fn subcube_partial_low_evals(
412 &self,
413 query: MultilinearQueryRef<P>,
414 subcube_vars: usize,
415 subcube_index: usize,
416 partial_low_evals: &mut [P],
417 ) -> Result<(), Error> {
418 validate_subcube_partial_evals_params(
420 self.n_vars(),
421 query,
422 subcube_vars,
423 subcube_index,
424 partial_low_evals,
425 )?;
426
427 let query_n_vars = query.n_vars();
428 let subcube_start = subcube_index << (query_n_vars + subcube_vars);
429
430 for scalar_index in 0..1 << subcube_vars {
432 let evals_start = subcube_start + (scalar_index << query_n_vars);
433 let mut inner_product = F::ZERO;
434 for i in 0..1 << query_n_vars {
435 inner_product += get_packed_slice(query.expansion(), i)
436 * get_packed_slice(self.0.evals(), evals_start + i);
437 }
438
439 set_packed_slice(partial_low_evals, scalar_index, inner_product);
440 }
441
442 Ok(())
443 }
444
445 fn subcube_partial_high_evals(
446 &self,
447 query: MultilinearQueryRef<P>,
448 subcube_vars: usize,
449 subcube_index: usize,
450 partial_high_evals: &mut [P],
451 ) -> Result<(), Error> {
452 validate_subcube_partial_evals_params(
454 self.n_vars(),
455 query,
456 subcube_vars,
457 subcube_index,
458 partial_high_evals,
459 )?;
460
461 let query_n_vars = query.n_vars();
462
463 partial_high_evals.fill(P::zero());
466
467 for query_index in 0..1 << query_n_vars {
468 let query_factor = get_packed_slice(query.expansion(), query_index);
469 let subcube_start =
470 subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
471 for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
472 *packed += P::from_fn(|inner_index| {
473 let index = subcube_start | outer_index << P::LOG_WIDTH | inner_index;
474 query_factor * get_packed_slice(self.0.evals(), index)
475 });
476 }
477 }
478
479 if subcube_vars < P::LOG_WIDTH {
480 for i in 1 << subcube_vars..P::WIDTH {
481 partial_high_evals
482 .first_mut()
483 .expect("at least one")
484 .set(i, P::Scalar::ZERO);
485 }
486 }
487
488 Ok(())
489 }
490
491 fn subcube_evals(
492 &self,
493 subcube_vars: usize,
494 subcube_index: usize,
495 log_embedding_degree: usize,
496 evals: &mut [P],
497 ) -> Result<(), Error> {
498 let n_vars = self.n_vars();
499 if subcube_vars > n_vars {
500 bail!(Error::ArgumentRangeError {
501 arg: "subcube_vars".to_string(),
502 range: 0..n_vars + 1,
503 });
504 }
505
506 if log_embedding_degree != 0 {
507 bail!(Error::LogEmbeddingDegreeTooLarge {
508 log_embedding_degree
509 });
510 }
511
512 let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
513 if evals.len() != correct_len {
514 bail!(Error::ArgumentRangeError {
515 arg: "evals.len()".to_string(),
516 range: correct_len..correct_len + 1,
517 });
518 }
519
520 let max_index = 1 << (n_vars - subcube_vars);
521 if subcube_index >= max_index {
522 bail!(Error::ArgumentRangeError {
523 arg: "subcube_index".to_string(),
524 range: 0..max_index,
525 });
526 }
527
528 if subcube_vars < P::LOG_WIDTH {
529 let subcube_start = subcube_index << subcube_vars;
530 for i in 0..1 << subcube_vars {
531 let scalar =
533 unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
534
535 unsafe {
537 set_packed_slice_unchecked(evals, i, scalar);
538 }
539 }
540 } else {
541 let range = subcube_index << (subcube_vars - P::LOG_WIDTH)
542 ..(subcube_index + 1) << (subcube_vars - P::LOG_WIDTH);
543 evals.copy_from_slice(&self.0.evals()[range]);
544 }
545
546 Ok(())
547 }
548
549 fn packed_evals(&self) -> Option<&[P]> {
550 Some(self.0.evals())
551 }
552}
553
554fn validate_subcube_partial_evals_params<P: PackedField>(
555 n_vars: usize,
556 query: MultilinearQueryRef<P>,
557 subcube_vars: usize,
558 subcube_index: usize,
559 partial_evals: &[P],
560) -> Result<(), Error> {
561 let query_n_vars = query.n_vars();
562 if query_n_vars + subcube_vars > n_vars {
563 bail!(Error::ArgumentRangeError {
564 arg: "query.n_vars() + subcube_vars".into(),
565 range: 0..n_vars,
566 });
567 }
568
569 let max_index = 1 << (n_vars - query_n_vars - subcube_vars);
570 if subcube_index >= max_index {
571 bail!(Error::ArgumentRangeError {
572 arg: "subcube_index".into(),
573 range: 0..max_index,
574 });
575 }
576
577 let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
578 if partial_evals.len() != correct_len {
579 bail!(Error::ArgumentRangeError {
580 arg: "partial_evals.len()".to_string(),
581 range: correct_len..correct_len + 1,
582 });
583 }
584
585 Ok(())
586}
587
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        arch::OptimalUnderlier256b, as_packed_field::PackedType, BinaryField128b, BinaryField16b,
        BinaryField32b, BinaryField8b, PackedBinaryField16x8b, PackedBinaryField1x128b,
        PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b,
        PackedBinaryField8x16b, PackedExtension, PackedField, PackedFieldIndexable,
    };
    use rand::prelude::*;

    use super::*;
    use crate::{tensor_prod_eq_ind, MultilinearQuery};

    type F = BinaryField16b;
    type P = PackedBinaryField8x16b;

    /// Builds a `MultilinearQuery` for the point `p` by seeding the expansion buffer with a
    /// single one and expanding it with `tensor_prod_eq_ind`.
    fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
        let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
        result[0] = P::set_single(P::Scalar::ONE);
        tensor_prod_eq_ind(0, &mut result, p).unwrap();
        MultilinearQuery::with_expansion(p.len(), result).unwrap()
    }

    /// For an embedding adapter, subcube 0 of the partial low/high evaluations must agree with
    /// the corresponding entries of `evaluate_partial_low` / `evaluate_partial_high`.
    #[test]
    fn test_evaluate_subcube_and_evaluate_partial_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let poly = MultilinearExtension::from_values(
            repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
                .take(1 << 8)
                .collect(),
        )
        .unwrap()
        .specialize::<PackedBinaryField1x128b>();

        let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
            .take(6)
            .collect::<Vec<_>>();
        let query = multilinear_query(&q);

        let partial_low = poly.evaluate_partial_low(query.to_ref()).unwrap();
        let partial_high = poly.evaluate_partial_high(query.to_ref()).unwrap();

        let mut subcube_partial_low = vec![PackedBinaryField1x128b::zero(); 16];
        let mut subcube_partial_high = vec![PackedBinaryField1x128b::zero(); 16];
        poly.subcube_partial_low_evals(query.to_ref(), 4, 0, &mut subcube_partial_low)
            .unwrap();
        poly.subcube_partial_high_evals(query.to_ref(), 4, 0, &mut subcube_partial_high)
            .unwrap();

        for (idx, subcube_partial_low) in PackedField::iter_slice(&subcube_partial_low).enumerate()
        {
            assert_eq!(subcube_partial_low, partial_low.evaluate_on_hypercube(idx).unwrap(),);
        }

        for (idx, subcube_partial_high) in
            PackedField::iter_slice(&subcube_partial_high).enumerate()
        {
            assert_eq!(subcube_partial_high, partial_high.evaluate_on_hypercube(idx).unwrap(),);
        }
    }

    /// With a subcube narrower than the packed width, only the first `2^subcube_vars` lanes of
    /// the output carry values and the remaining lanes read as zero.
    #[test]
    fn test_evaluate_subcube_smaller_than_packed_width() {
        let mut rng = StdRng::seed_from_u64(0);
        let poly = MultilinearExtension::new(
            2,
            vec![PackedBinaryField64x8b::from_scalars(
                [2, 2, 9, 9].map(BinaryField8b::new),
            )],
        )
        .unwrap()
        .specialize::<PackedBinaryField4x128b>();

        let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
            .take(1)
            .collect::<Vec<_>>();
        let query = multilinear_query(&q);

        let mut subcube_partial_low = vec![PackedBinaryField4x128b::zero(); 1];
        let mut subcube_partial_high = vec![PackedBinaryField4x128b::zero(); 1];
        poly.subcube_partial_low_evals(query.to_ref(), 1, 0, &mut subcube_partial_low)
            .unwrap();
        poly.subcube_partial_high_evals(query.to_ref(), 1, 0, &mut subcube_partial_high)
            .unwrap();

        let expected_partial_high = BinaryField128b::new(2)
            + (BinaryField128b::new(2) + BinaryField128b::new(9)) * q.first().unwrap();

        assert_eq!(get_packed_slice(&subcube_partial_low, 0), BinaryField128b::new(2));
        assert_eq!(get_packed_slice(&subcube_partial_low, 1), BinaryField128b::new(9));
        assert_eq!(get_packed_slice(&subcube_partial_low, 2), BinaryField128b::ZERO);
        assert_eq!(get_packed_slice(&subcube_partial_low, 3), BinaryField128b::ZERO);
        assert_eq!(get_packed_slice(&subcube_partial_high, 0), expected_partial_high);
        assert_eq!(get_packed_slice(&subcube_partial_high, 1), expected_partial_high);
        assert_eq!(get_packed_slice(&subcube_partial_high, 2), BinaryField128b::ZERO);
        assert_eq!(get_packed_slice(&subcube_partial_high, 3), BinaryField128b::ZERO);
    }

    /// Checks `subcube_evals` byte by byte: each output scalar must hold the expected base
    /// values at the positions selected by `log_embedding_degree`, with zeros elsewhere.
    #[test]
    fn test_subcube_evals_embeds_correctly() {
        let mut rng = StdRng::seed_from_u64(0);

        type P = PackedBinaryField16x8b;
        type PE = PackedBinaryField1x128b;

        let packed_count = 4;
        let values: Vec<_> = repeat_with(|| P::random(&mut rng))
            .take(1 << packed_count)
            .collect();

        let mle = MultilinearExtension::from_values(values).unwrap();
        let mles = MLEEmbeddingAdapter::<P, PE, _>::from(mle);

        let bytes_values = P::unpack_scalars(mles.0.evals());

        let n_vars = packed_count + P::LOG_WIDTH;
        let mut evals = vec![PE::zero(); 1 << n_vars];
        for subcube_vars in 0..n_vars {
            for subcube_index in 0..1 << (n_vars - subcube_vars) {
                for log_embedding_degree in 0..=4 {
                    let evals_subcube = &mut evals
                        [0..1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH)];

                    mles.subcube_evals(
                        subcube_vars,
                        subcube_index,
                        log_embedding_degree,
                        evals_subcube,
                    )
                    .unwrap();

                    let bytes_evals = P::unpack_scalars(
                        <PE as PackedExtension<BinaryField8b>>::cast_bases(evals_subcube),
                    );

                    let shift = 4 - log_embedding_degree;
                    let skip_mask = (1 << shift) - 1;
                    for (i, &b_evals) in bytes_evals.iter().enumerate() {
                        let b_values = if i & skip_mask == 0 && i < 1 << (subcube_vars + shift) {
                            bytes_values[(subcube_index << subcube_vars) + (i >> shift)]
                        } else {
                            BinaryField8b::ZERO
                        };
                        assert_eq!(b_evals, b_values);
                    }
                }
            }
        }
    }

    /// For a direct adapter, every subcube of the partial low/high evaluations must agree with
    /// the matching slice of `evaluate_partial_low` / `evaluate_partial_high`.
    #[test]
    fn test_subcube_partial_and_evaluate_partial_conform() {
        let mut rng = StdRng::seed_from_u64(0);
        let n_vars = 12;
        let evals = repeat_with(|| P::random(&mut rng))
            .take(1 << (n_vars - P::LOG_WIDTH))
            .collect::<Vec<_>>();
        let mle = MultilinearExtension::from_values(evals).unwrap();
        let mles = MLEDirectAdapter::from(mle);
        let q = repeat_with(|| Field::random(&mut rng))
            .take(6)
            .collect::<Vec<F>>();
        let query = multilinear_query(&q);
        let partial_low_eval = mles.evaluate_partial_low(query.to_ref()).unwrap();
        let partial_high_eval = mles.evaluate_partial_high(query.to_ref()).unwrap();

        let subcube_vars = 4;
        let mut subcube_partial_low_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
        let mut subcube_partial_high_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
        // There are 2^(n_vars - query.n_vars() - subcube_vars) valid subcube indices; visit all
        // of them rather than only the first `n_vars - query.n_vars() - subcube_vars`.
        for subcube_index in 0..1 << (n_vars - query.n_vars() - subcube_vars) {
            mles.subcube_partial_low_evals(
                query.to_ref(),
                subcube_vars,
                subcube_index,
                &mut subcube_partial_low_evals,
            )
            .unwrap();
            mles.subcube_partial_high_evals(
                query.to_ref(),
                subcube_vars,
                subcube_index,
                &mut subcube_partial_high_evals,
            )
            .unwrap();
            for hypercube_idx in 0..(1 << subcube_vars) {
                assert_eq!(
                    get_packed_slice(&subcube_partial_low_evals, hypercube_idx),
                    partial_low_eval
                        .evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
                        .unwrap()
                );
                assert_eq!(
                    get_packed_slice(&subcube_partial_high_evals, hypercube_idx),
                    partial_high_eval
                        .evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
                        .unwrap()
                );
            }
        }
    }

    /// `packed_evals` must expose the original base-field data, and `subcube_evals` over the
    /// whole cube at the full extension degree must produce the same packed values.
    #[test]
    fn test_packed_evals_against_subcube_evals() {
        type U = OptimalUnderlier256b;
        type P = PackedType<U, BinaryField32b>;
        type PExt = PackedType<U, BinaryField128b>;

        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| P::random(&mut rng))
            .take(2)
            .collect::<Vec<_>>();
        let mle = MultilinearExtension::from_values(evals.clone()).unwrap();
        let poly = MLEEmbeddingAdapter::from(mle);
        assert_eq!(
            <PExt as PackedExtension<BinaryField32b>>::cast_bases(poly.packed_evals().unwrap()),
            &evals
        );

        let mut evals_out = vec![PExt::zero(); 2];
        poly.subcube_evals(poly.n_vars(), 0, poly.log_extension_degree(), evals_out.as_mut_slice())
            .unwrap();
        assert_eq!(evals_out, poly.packed_evals().unwrap());
    }
}