1use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc};
4
5use binius_field::{
6 packed::{
7 get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
8 },
9 ExtensionField, Field, PackedField, RepackedExtension,
10};
11use binius_utils::bail;
12
13use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef};
14
/// Adapter around a [`MultilinearExtension`] whose evaluations are packed as `P`, tagged
/// with an extension packing `PE` whose scalar field extends `P::Scalar`.
///
/// Field `0` is the wrapped multilinear extension. Field `1` is a zero-sized
/// [`PhantomData`] marker recording `PE`; no `PE` values are stored.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MLEEmbeddingAdapter<P, PE, Data = Vec<P>>(
    MultilinearExtension<P, Data>,
    PhantomData<PE>,
)
where
    P: PackedField,
    PE: PackedField,
    PE::Scalar: ExtensionField<P::Scalar>,
    Data: Deref<Target = [P]>;
30
31impl<'a, P, PE, Data> MLEEmbeddingAdapter<P, PE, Data>
32where
33 P: PackedField,
34 PE: PackedField + RepackedExtension<P>,
35 PE::Scalar: ExtensionField<P::Scalar>,
36 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
37{
38 pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
39 Arc::new(self)
40 }
41}
42
43impl<P, PE, Data> From<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
44where
45 P: PackedField,
46 PE: PackedField,
47 PE::Scalar: ExtensionField<P::Scalar>,
48 Data: Deref<Target = [P]>,
49{
50 fn from(inner: MultilinearExtension<P, Data>) -> Self {
51 Self(inner, PhantomData)
52 }
53}
54
55impl<P, PE, Data> AsRef<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
56where
57 P: PackedField,
58 PE: PackedField,
59 PE::Scalar: ExtensionField<P::Scalar>,
60 Data: Deref<Target = [P]>,
61{
62 fn as_ref(&self) -> &MultilinearExtension<P, Data> {
63 &self.0
64 }
65}
66
67impl<P, PE, Data> MultilinearPoly<PE> for MLEEmbeddingAdapter<P, PE, Data>
68where
69 P: PackedField + Debug,
70 PE: PackedField + RepackedExtension<P>,
71 PE::Scalar: ExtensionField<P::Scalar>,
72 Data: Deref<Target = [P]> + Send + Sync + Debug,
73{
74 fn n_vars(&self) -> usize {
75 self.0.n_vars()
76 }
77
78 fn log_extension_degree(&self) -> usize {
79 PE::Scalar::LOG_DEGREE
80 }
81
82 fn evaluate_on_hypercube(&self, index: usize) -> Result<PE::Scalar, Error> {
83 let eval = self.0.evaluate_on_hypercube(index)?;
84 Ok(eval.into())
85 }
86
87 fn evaluate_on_hypercube_and_scale(
88 &self,
89 index: usize,
90 scalar: PE::Scalar,
91 ) -> Result<PE::Scalar, Error> {
92 let eval = self.0.evaluate_on_hypercube(index)?;
93 Ok(scalar * eval)
94 }
95
96 fn evaluate(&self, query: MultilinearQueryRef<PE>) -> Result<PE::Scalar, Error> {
97 self.0.evaluate(query)
98 }
99
100 fn evaluate_partial_low(
101 &self,
102 query: MultilinearQueryRef<PE>,
103 ) -> Result<MultilinearExtension<PE>, Error> {
104 self.0.evaluate_partial_low(query)
105 }
106
107 fn evaluate_partial_high(
108 &self,
109 query: MultilinearQueryRef<PE>,
110 ) -> Result<MultilinearExtension<PE>, Error> {
111 self.0.evaluate_partial_high(query)
112 }
113
114 fn evaluate_partial(
115 &self,
116 query: MultilinearQueryRef<PE>,
117 start_index: usize,
118 ) -> Result<MultilinearExtension<PE>, Error> {
119 self.0.evaluate_partial(query, start_index)
120 }
121
122 fn zero_pad(
123 &self,
124 n_pad_vars: usize,
125 start_index: usize,
126 nonzero_index: usize,
127 ) -> Result<MultilinearExtension<PE>, Error> {
128 self.0.zero_pad(n_pad_vars, start_index, nonzero_index)
129 }
130
131 fn subcube_partial_low_evals(
132 &self,
133 query: MultilinearQueryRef<PE>,
134 subcube_vars: usize,
135 subcube_index: usize,
136 partial_low_evals: &mut [PE],
137 ) -> Result<(), Error> {
138 validate_subcube_partial_evals_params(
139 self.n_vars(),
140 query,
141 subcube_vars,
142 subcube_index,
143 partial_low_evals,
144 )?;
145
146 let query_n_vars = query.n_vars();
147 let subcube_start = subcube_index << (query_n_vars + subcube_vars);
148
149 for scalar_index in 0..1 << subcube_vars {
152 let evals_start = subcube_start + (scalar_index << query_n_vars);
153 let mut inner_product = PE::Scalar::ZERO;
154 for i in 0..1 << query_n_vars {
155 inner_product += get_packed_slice(query.expansion(), i)
156 * get_packed_slice(self.0.evals(), evals_start + i);
157 }
158
159 set_packed_slice(partial_low_evals, scalar_index, inner_product);
160 }
161
162 Ok(())
163 }
164
165 fn subcube_partial_high_evals(
166 &self,
167 query: MultilinearQueryRef<PE>,
168 subcube_vars: usize,
169 subcube_index: usize,
170 partial_high_evals: &mut [PE],
171 ) -> Result<(), Error> {
172 validate_subcube_partial_evals_params(
173 self.n_vars(),
174 query,
175 subcube_vars,
176 subcube_index,
177 partial_high_evals,
178 )?;
179
180 let query_n_vars = query.n_vars();
181
182 partial_high_evals.fill(PE::zero());
185
186 for query_index in 0..1 << query_n_vars {
187 let query_factor = get_packed_slice(query.expansion(), query_index);
188 let subcube_start =
189 subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
190 for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
191 *packed += PE::from_fn(|inner_index| {
192 let index = subcube_start | outer_index << PE::LOG_WIDTH | inner_index;
193 query_factor * get_packed_slice(self.0.evals(), index)
194 });
195 }
196 }
197
198 if subcube_vars < PE::LOG_WIDTH {
199 for i in 1 << subcube_vars..PE::WIDTH {
200 partial_high_evals
201 .first_mut()
202 .expect("at least one")
203 .set(i, PE::Scalar::ZERO);
204 }
205 }
206
207 Ok(())
208 }
209
210 fn subcube_evals(
211 &self,
212 subcube_vars: usize,
213 subcube_index: usize,
214 log_embedding_degree: usize,
215 evals: &mut [PE],
216 ) -> Result<(), Error> {
217 let log_extension_degree = PE::Scalar::LOG_DEGREE;
218
219 if subcube_vars > self.n_vars() {
220 bail!(Error::ArgumentRangeError {
221 arg: "subcube_vars".to_string(),
222 range: 0..self.n_vars() + 1,
223 });
224 }
225
226 const MAX_TOWER_HEIGHT: usize = 7;
230 if log_embedding_degree > log_extension_degree.min(MAX_TOWER_HEIGHT) {
231 bail!(Error::LogEmbeddingDegreeTooLarge {
232 log_embedding_degree
233 });
234 }
235
236 let correct_len = 1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH);
237 if evals.len() != correct_len {
238 bail!(Error::ArgumentRangeError {
239 arg: "evals.len()".to_string(),
240 range: correct_len..correct_len + 1,
241 });
242 }
243
244 let max_index = 1 << (self.n_vars() - subcube_vars);
245 if subcube_index >= max_index {
246 bail!(Error::ArgumentRangeError {
247 arg: "subcube_index".to_string(),
248 range: 0..max_index,
249 });
250 }
251
252 let subcube_start = subcube_index << subcube_vars;
253
254 if log_embedding_degree == 0 {
255 for i in 0..1 << subcube_vars {
257 let scalar =
259 unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
260
261 let extension_scalar = scalar.into();
262
263 unsafe {
265 set_packed_slice_unchecked(evals, i, extension_scalar);
266 }
267 }
268 } else {
269 let mut bases = [P::Scalar::default(); 1 << MAX_TOWER_HEIGHT];
271 let bases = &mut bases[0..1 << log_embedding_degree];
272
273 let bases_count = 1 << log_embedding_degree.min(subcube_vars);
274 for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) {
275 for (j, base) in bases[..bases_count].iter_mut().enumerate() {
276 *base = unsafe {
278 get_packed_slice_unchecked(
279 self.0.evals(),
280 subcube_start + (i << log_embedding_degree) + j,
281 )
282 };
283 }
284
285 let extension_scalar = PE::Scalar::from_bases_sparse(
286 bases.iter().copied(),
287 log_extension_degree - log_embedding_degree,
288 )?;
289
290 unsafe {
292 set_packed_slice_unchecked(evals, i, extension_scalar);
293 }
294 }
295 }
296
297 Ok(())
298 }
299
300 fn packed_evals(&self) -> Option<&[PE]> {
301 Some(PE::cast_exts(self.0.evals()))
302 }
303}
304
305impl<P, Data> MultilinearExtension<P, Data>
306where
307 P: PackedField,
308 Data: Deref<Target = [P]>,
309{
310 pub fn specialize<PE>(self) -> MLEEmbeddingAdapter<P, PE, Data>
311 where
312 PE: PackedField,
313 PE::Scalar: ExtensionField<P::Scalar>,
314 {
315 MLEEmbeddingAdapter::from(self)
316 }
317}
318
319impl<'a, P, Data> MultilinearExtension<P, Data>
320where
321 P: PackedField,
322 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
323{
324 pub fn specialize_arc_dyn<PE: RepackedExtension<P>>(
325 self,
326 ) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
327 self.specialize().upcast_arc_dyn()
328 }
329}
330
/// Adapter around a [`MultilinearExtension`] packed as `P` that is used over the same
/// packing `P`, with no field extension involved.
///
/// Field `0` is the wrapped multilinear extension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MLEDirectAdapter<P, Data = Vec<P>>(MultilinearExtension<P, Data>)
where
    P: PackedField,
    Data: Deref<Target = [P]>;
338
339impl<'a, P, Data> MLEDirectAdapter<P, Data>
340where
341 P: PackedField,
342 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
343{
344 pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<P> + Send + Sync + 'a> {
345 Arc::new(self)
346 }
347}
348
349impl<P, Data> From<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
350where
351 P: PackedField,
352 Data: Deref<Target = [P]>,
353{
354 fn from(inner: MultilinearExtension<P, Data>) -> Self {
355 Self(inner)
356 }
357}
358
359impl<P, Data> AsRef<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
360where
361 P: PackedField,
362 Data: Deref<Target = [P]>,
363{
364 fn as_ref(&self) -> &MultilinearExtension<P, Data> {
365 &self.0
366 }
367}
368
369impl<F, P, Data> MultilinearPoly<P> for MLEDirectAdapter<P, Data>
370where
371 F: Field,
372 P: PackedField<Scalar = F>,
373 Data: Deref<Target = [P]> + Send + Sync + Debug,
374{
375 #[inline]
376 fn n_vars(&self) -> usize {
377 self.0.n_vars()
378 }
379
380 #[inline]
381 fn log_extension_degree(&self) -> usize {
382 0
383 }
384
385 fn evaluate_on_hypercube(&self, index: usize) -> Result<F, Error> {
386 self.0.evaluate_on_hypercube(index)
387 }
388
389 fn evaluate_on_hypercube_and_scale(&self, index: usize, scalar: F) -> Result<F, Error> {
390 let eval = self.0.evaluate_on_hypercube(index)?;
391 Ok(scalar * eval)
392 }
393
394 fn evaluate(&self, query: MultilinearQueryRef<P>) -> Result<F, Error> {
395 self.0.evaluate(query)
396 }
397
398 fn evaluate_partial_low(
399 &self,
400 query: MultilinearQueryRef<P>,
401 ) -> Result<MultilinearExtension<P>, Error> {
402 self.0.evaluate_partial_low(query)
403 }
404
405 fn evaluate_partial_high(
406 &self,
407 query: MultilinearQueryRef<P>,
408 ) -> Result<MultilinearExtension<P>, Error> {
409 self.0.evaluate_partial_high(query)
410 }
411
412 fn evaluate_partial(
413 &self,
414 query: MultilinearQueryRef<P>,
415 start_index: usize,
416 ) -> Result<MultilinearExtension<P>, Error> {
417 self.0.evaluate_partial(query, start_index)
418 }
419
420 fn zero_pad(
421 &self,
422 n_pad_vars: usize,
423 start_index: usize,
424 nonzero_padding: usize,
425 ) -> Result<MultilinearExtension<P>, Error> {
426 self.0.zero_pad(n_pad_vars, start_index, nonzero_padding)
427 }
428
429 fn subcube_partial_low_evals(
430 &self,
431 query: MultilinearQueryRef<P>,
432 subcube_vars: usize,
433 subcube_index: usize,
434 partial_low_evals: &mut [P],
435 ) -> Result<(), Error> {
436 validate_subcube_partial_evals_params(
438 self.n_vars(),
439 query,
440 subcube_vars,
441 subcube_index,
442 partial_low_evals,
443 )?;
444
445 let query_n_vars = query.n_vars();
446 let subcube_start = subcube_index << (query_n_vars + subcube_vars);
447
448 for scalar_index in 0..1 << subcube_vars {
450 let evals_start = subcube_start + (scalar_index << query_n_vars);
451 let mut inner_product = F::ZERO;
452 for i in 0..1 << query_n_vars {
453 inner_product += get_packed_slice(query.expansion(), i)
454 * get_packed_slice(self.0.evals(), evals_start + i);
455 }
456
457 set_packed_slice(partial_low_evals, scalar_index, inner_product);
458 }
459
460 Ok(())
461 }
462
463 fn subcube_partial_high_evals(
464 &self,
465 query: MultilinearQueryRef<P>,
466 subcube_vars: usize,
467 subcube_index: usize,
468 partial_high_evals: &mut [P],
469 ) -> Result<(), Error> {
470 validate_subcube_partial_evals_params(
472 self.n_vars(),
473 query,
474 subcube_vars,
475 subcube_index,
476 partial_high_evals,
477 )?;
478
479 let query_n_vars = query.n_vars();
480
481 partial_high_evals.fill(P::zero());
484
485 for query_index in 0..1 << query_n_vars {
486 let query_factor = get_packed_slice(query.expansion(), query_index);
487 let subcube_start =
488 subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
489 for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
490 *packed += P::from_fn(|inner_index| {
491 let index = subcube_start | outer_index << P::LOG_WIDTH | inner_index;
492 query_factor * get_packed_slice(self.0.evals(), index)
493 });
494 }
495 }
496
497 if subcube_vars < P::LOG_WIDTH {
498 for i in 1 << subcube_vars..P::WIDTH {
499 partial_high_evals
500 .first_mut()
501 .expect("at least one")
502 .set(i, P::Scalar::ZERO);
503 }
504 }
505
506 Ok(())
507 }
508
509 fn subcube_evals(
510 &self,
511 subcube_vars: usize,
512 subcube_index: usize,
513 log_embedding_degree: usize,
514 evals: &mut [P],
515 ) -> Result<(), Error> {
516 let n_vars = self.n_vars();
517 if subcube_vars > n_vars {
518 bail!(Error::ArgumentRangeError {
519 arg: "subcube_vars".to_string(),
520 range: 0..n_vars + 1,
521 });
522 }
523
524 if log_embedding_degree != 0 {
525 bail!(Error::LogEmbeddingDegreeTooLarge {
526 log_embedding_degree
527 });
528 }
529
530 let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
531 if evals.len() != correct_len {
532 bail!(Error::ArgumentRangeError {
533 arg: "evals.len()".to_string(),
534 range: correct_len..correct_len + 1,
535 });
536 }
537
538 let max_index = 1 << (n_vars - subcube_vars);
539 if subcube_index >= max_index {
540 bail!(Error::ArgumentRangeError {
541 arg: "subcube_index".to_string(),
542 range: 0..max_index,
543 });
544 }
545
546 if subcube_vars < P::LOG_WIDTH {
547 let subcube_start = subcube_index << subcube_vars;
548 for i in 0..1 << subcube_vars {
549 let scalar =
551 unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
552
553 unsafe {
555 set_packed_slice_unchecked(evals, i, scalar);
556 }
557 }
558 } else {
559 let range = subcube_index << (subcube_vars - P::LOG_WIDTH)
560 ..(subcube_index + 1) << (subcube_vars - P::LOG_WIDTH);
561 evals.copy_from_slice(&self.0.evals()[range]);
562 }
563
564 Ok(())
565 }
566
567 fn packed_evals(&self) -> Option<&[P]> {
568 Some(self.0.evals())
569 }
570}
571
572fn validate_subcube_partial_evals_params<P: PackedField>(
573 n_vars: usize,
574 query: MultilinearQueryRef<P>,
575 subcube_vars: usize,
576 subcube_index: usize,
577 partial_evals: &[P],
578) -> Result<(), Error> {
579 let query_n_vars = query.n_vars();
580 if query_n_vars + subcube_vars > n_vars {
581 bail!(Error::ArgumentRangeError {
582 arg: "query.n_vars() + subcube_vars".into(),
583 range: 0..n_vars,
584 });
585 }
586
587 let max_index = 1 << (n_vars - query_n_vars - subcube_vars);
588 if subcube_index >= max_index {
589 bail!(Error::ArgumentRangeError {
590 arg: "subcube_index".into(),
591 range: 0..max_index,
592 });
593 }
594
595 let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
596 if partial_evals.len() != correct_len {
597 bail!(Error::ArgumentRangeError {
598 arg: "partial_evals.len()".to_string(),
599 range: correct_len..correct_len + 1,
600 });
601 }
602
603 Ok(())
604}
605
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        arch::OptimalUnderlier256b, as_packed_field::PackedType, BinaryField128b, BinaryField16b,
        BinaryField32b, BinaryField8b, PackedBinaryField16x8b, PackedBinaryField1x128b,
        PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b,
        PackedBinaryField8x16b, PackedExtension, PackedField, PackedFieldIndexable,
    };
    use rand::prelude::*;

    use super::*;
    use crate::{tensor_prod_eq_ind, MultilinearQuery};

    type F = BinaryField16b;
    type P = PackedBinaryField8x16b;

    /// Builds a multilinear query for the point `p` by tensor-expanding it into a
    /// freshly allocated buffer.
    fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
        let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
        result[0] = P::set_single(P::Scalar::ONE);
        tensor_prod_eq_ind(0, &mut result, p).unwrap();
        MultilinearQuery::with_expansion(p.len(), result).unwrap()
    }

    /// The subcube partial evaluations of the embedding adapter must agree with the
    /// hypercube values of `evaluate_partial_low` / `evaluate_partial_high`.
    #[test]
    fn test_evaluate_subcube_and_evaluate_partial_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let poly = MultilinearExtension::from_values(
            repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
                .take(1 << 8)
                .collect(),
        )
        .unwrap()
        .specialize::<PackedBinaryField1x128b>();

        let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
            .take(6)
            .collect::<Vec<_>>();
        let query = multilinear_query(&q);

        let partial_low = poly.evaluate_partial_low(query.to_ref()).unwrap();
        let partial_high = poly.evaluate_partial_high(query.to_ref()).unwrap();

        let mut subcube_partial_low = vec![PackedBinaryField1x128b::zero(); 16];
        let mut subcube_partial_high = vec![PackedBinaryField1x128b::zero(); 16];
        poly.subcube_partial_low_evals(query.to_ref(), 4, 0, &mut subcube_partial_low)
            .unwrap();
        poly.subcube_partial_high_evals(query.to_ref(), 4, 0, &mut subcube_partial_high)
            .unwrap();

        for (idx, subcube_partial_low) in PackedField::iter_slice(&subcube_partial_low).enumerate()
        {
            assert_eq!(subcube_partial_low, partial_low.evaluate_on_hypercube(idx).unwrap(),);
        }

        for (idx, subcube_partial_high) in
            PackedField::iter_slice(&subcube_partial_high).enumerate()
        {
            assert_eq!(subcube_partial_high, partial_high.evaluate_on_hypercube(idx).unwrap(),);
        }
    }

    /// When the subcube is smaller than one packed element, the lanes past the subcube
    /// must read as zero for both the low and the high partial evaluations.
    #[test]
    fn test_evaluate_subcube_smaller_than_packed_width() {
        let mut rng = StdRng::seed_from_u64(0);
        let poly = MultilinearExtension::new(
            2,
            vec![PackedBinaryField64x8b::from_scalars(
                [2, 2, 9, 9].map(BinaryField8b::new),
            )],
        )
        .unwrap()
        .specialize::<PackedBinaryField4x128b>();

        let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
            .take(1)
            .collect::<Vec<_>>();
        let query = multilinear_query(&q);

        let mut subcube_partial_low = vec![PackedBinaryField4x128b::zero(); 1];
        let mut subcube_partial_high = vec![PackedBinaryField4x128b::zero(); 1];
        poly.subcube_partial_low_evals(query.to_ref(), 1, 0, &mut subcube_partial_low)
            .unwrap();
        poly.subcube_partial_high_evals(query.to_ref(), 1, 0, &mut subcube_partial_high)
            .unwrap();

        let expected_partial_high = BinaryField128b::new(2)
            + (BinaryField128b::new(2) + BinaryField128b::new(9)) * q.first().unwrap();

        assert_eq!(get_packed_slice(&subcube_partial_low, 0), BinaryField128b::new(2));
        assert_eq!(get_packed_slice(&subcube_partial_low, 1), BinaryField128b::new(9));
        assert_eq!(get_packed_slice(&subcube_partial_low, 2), BinaryField128b::ZERO);
        assert_eq!(get_packed_slice(&subcube_partial_low, 3), BinaryField128b::ZERO);
        assert_eq!(get_packed_slice(&subcube_partial_high, 0), expected_partial_high);
        assert_eq!(get_packed_slice(&subcube_partial_high, 1), expected_partial_high);
        assert_eq!(get_packed_slice(&subcube_partial_high, 2), BinaryField128b::ZERO);
        assert_eq!(get_packed_slice(&subcube_partial_high, 3), BinaryField128b::ZERO);
    }

    /// `subcube_evals` must place each base-field byte at the expected byte offset of the
    /// extension scalars for every embedding degree, and leave the other bytes zero.
    #[test]
    fn test_subcube_evals_embeds_correctly() {
        let mut rng = StdRng::seed_from_u64(0);

        type P = PackedBinaryField16x8b;
        type PE = PackedBinaryField1x128b;

        // log2 of the number of packed elements backing the polynomial.
        let log_packed_count = 4;
        let values: Vec<_> = repeat_with(|| P::random(&mut rng))
            .take(1 << log_packed_count)
            .collect();

        let mle = MultilinearExtension::from_values(values).unwrap();
        let mles = MLEEmbeddingAdapter::<P, PE, _>::from(mle);

        let bytes_values = P::unpack_scalars(mles.0.evals());

        let n_vars = log_packed_count + P::LOG_WIDTH;
        let mut evals = vec![PE::zero(); 1 << n_vars];
        for subcube_vars in 0..n_vars {
            for subcube_index in 0..1 << (n_vars - subcube_vars) {
                for log_embedding_degree in 0..=4 {
                    let evals_subcube = &mut evals
                        [0..1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH)];

                    mles.subcube_evals(
                        subcube_vars,
                        subcube_index,
                        log_embedding_degree,
                        evals_subcube,
                    )
                    .unwrap();

                    let bytes_evals = P::unpack_scalars(
                        <PE as PackedExtension<BinaryField8b>>::cast_bases(evals_subcube),
                    );

                    // Embedded bytes sit `2^shift` bytes apart inside each extension scalar.
                    let shift = 4 - log_embedding_degree;
                    let skip_mask = (1 << shift) - 1;
                    for (i, &b_evals) in bytes_evals.iter().enumerate() {
                        let b_values = if i & skip_mask == 0 && i < 1 << (subcube_vars + shift) {
                            bytes_values[(subcube_index << subcube_vars) + (i >> shift)]
                        } else {
                            BinaryField8b::ZERO
                        };
                        assert_eq!(b_evals, b_values);
                    }
                }
            }
        }
    }

    /// The subcube partial evaluations of the direct adapter must agree with the
    /// hypercube values of `evaluate_partial_low` / `evaluate_partial_high` for every
    /// subcube index.
    #[test]
    fn test_subcube_partial_and_evaluate_partial_conform() {
        let mut rng = StdRng::seed_from_u64(0);
        let n_vars = 12;
        let evals = repeat_with(|| P::random(&mut rng))
            .take(1 << (n_vars - P::LOG_WIDTH))
            .collect::<Vec<_>>();
        let mle = MultilinearExtension::from_values(evals).unwrap();
        let mles = MLEDirectAdapter::from(mle);
        let q = repeat_with(|| Field::random(&mut rng))
            .take(6)
            .collect::<Vec<F>>();
        let query = multilinear_query(&q);
        let partial_low_eval = mles.evaluate_partial_low(query.to_ref()).unwrap();
        let partial_high_eval = mles.evaluate_partial_high(query.to_ref()).unwrap();

        let subcube_vars = 4;
        let mut subcube_partial_low_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
        let mut subcube_partial_high_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
        // There are `2^(n_vars - query.n_vars() - subcube_vars)` subcubes; visit all of them.
        for subcube_index in 0..1 << (n_vars - query.n_vars() - subcube_vars) {
            mles.subcube_partial_low_evals(
                query.to_ref(),
                subcube_vars,
                subcube_index,
                &mut subcube_partial_low_evals,
            )
            .unwrap();
            mles.subcube_partial_high_evals(
                query.to_ref(),
                subcube_vars,
                subcube_index,
                &mut subcube_partial_high_evals,
            )
            .unwrap();
            for hypercube_idx in 0..(1 << subcube_vars) {
                assert_eq!(
                    get_packed_slice(&subcube_partial_low_evals, hypercube_idx),
                    partial_low_eval
                        .evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
                        .unwrap()
                );
                assert_eq!(
                    get_packed_slice(&subcube_partial_high_evals, hypercube_idx),
                    partial_high_eval
                        .evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
                        .unwrap()
                );
            }
        }
    }

    /// `packed_evals` of the embedding adapter must be a reinterpretation of the original
    /// data and must match what `subcube_evals` produces for the full cube.
    #[test]
    fn test_packed_evals_against_subcube_evals() {
        type U = OptimalUnderlier256b;
        type P = PackedType<U, BinaryField32b>;
        type PExt = PackedType<U, BinaryField128b>;

        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| P::random(&mut rng))
            .take(2)
            .collect::<Vec<_>>();
        let mle = MultilinearExtension::from_values(evals.clone()).unwrap();
        let poly = MLEEmbeddingAdapter::from(mle);
        assert_eq!(
            <PExt as PackedExtension<BinaryField32b>>::cast_bases(poly.packed_evals().unwrap()),
            &evals
        );

        let mut evals_out = vec![PExt::zero(); 2];
        poly.subcube_evals(poly.n_vars(), 0, poly.log_extension_degree(), evals_out.as_mut_slice())
            .unwrap();
        assert_eq!(evals_out, poly.packed_evals().unwrap());
    }
}