1use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc};
4
5use binius_field::{
6 packed::{
7 get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
8 },
9 ExtensionField, Field, PackedField, RepackedExtension,
10};
11use binius_utils::bail;
12
13use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef};
14
15#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct MLEEmbeddingAdapter<P, PE, Data = Vec<P>>(
22 MultilinearExtension<P, Data>,
23 PhantomData<PE>,
24)
25where
26 P: PackedField,
27 PE: PackedField,
28 PE::Scalar: ExtensionField<P::Scalar>,
29 Data: Deref<Target = [P]>;
30
31impl<'a, P, PE, Data> MLEEmbeddingAdapter<P, PE, Data>
32where
33 P: PackedField,
34 PE: PackedField + RepackedExtension<P>,
35 PE::Scalar: ExtensionField<P::Scalar>,
36 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
37{
38 pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
39 Arc::new(self)
40 }
41}
42
43impl<P, PE, Data> From<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
44where
45 P: PackedField,
46 PE: PackedField,
47 PE::Scalar: ExtensionField<P::Scalar>,
48 Data: Deref<Target = [P]>,
49{
50 fn from(inner: MultilinearExtension<P, Data>) -> Self {
51 Self(inner, PhantomData)
52 }
53}
54
55impl<P, PE, Data> AsRef<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
56where
57 P: PackedField,
58 PE: PackedField,
59 PE::Scalar: ExtensionField<P::Scalar>,
60 Data: Deref<Target = [P]>,
61{
62 fn as_ref(&self) -> &MultilinearExtension<P, Data> {
63 &self.0
64 }
65}
66
67impl<P, PE, Data> MultilinearPoly<PE> for MLEEmbeddingAdapter<P, PE, Data>
68where
69 P: PackedField + Debug,
70 PE: PackedField + RepackedExtension<P>,
71 PE::Scalar: ExtensionField<P::Scalar>,
72 Data: Deref<Target = [P]> + Send + Sync + Debug,
73{
74 fn n_vars(&self) -> usize {
75 self.0.n_vars()
76 }
77
78 fn log_extension_degree(&self) -> usize {
79 PE::Scalar::LOG_DEGREE
80 }
81
82 fn evaluate_on_hypercube(&self, index: usize) -> Result<PE::Scalar, Error> {
83 let eval = self.0.evaluate_on_hypercube(index)?;
84 Ok(eval.into())
85 }
86
87 fn evaluate_on_hypercube_and_scale(
88 &self,
89 index: usize,
90 scalar: PE::Scalar,
91 ) -> Result<PE::Scalar, Error> {
92 let eval = self.0.evaluate_on_hypercube(index)?;
93 Ok(scalar * eval)
94 }
95
96 fn evaluate(&self, query: MultilinearQueryRef<PE>) -> Result<PE::Scalar, Error> {
97 self.0.evaluate(query)
98 }
99
100 fn evaluate_partial_low(
101 &self,
102 query: MultilinearQueryRef<PE>,
103 ) -> Result<MultilinearExtension<PE>, Error> {
104 self.0.evaluate_partial_low(query)
105 }
106
107 fn evaluate_partial_high(
108 &self,
109 query: MultilinearQueryRef<PE>,
110 ) -> Result<MultilinearExtension<PE>, Error> {
111 self.0.evaluate_partial_high(query)
112 }
113
114 fn subcube_partial_low_evals(
115 &self,
116 query: MultilinearQueryRef<PE>,
117 subcube_vars: usize,
118 subcube_index: usize,
119 partial_low_evals: &mut [PE],
120 ) -> Result<(), Error> {
121 validate_subcube_partial_evals_params(
122 self.n_vars(),
123 query,
124 subcube_vars,
125 subcube_index,
126 partial_low_evals,
127 )?;
128
129 let query_n_vars = query.n_vars();
130 let subcube_start = subcube_index << (query_n_vars + subcube_vars);
131
132 for scalar_index in 0..1 << subcube_vars {
135 let evals_start = subcube_start + (scalar_index << query_n_vars);
136 let mut inner_product = PE::Scalar::ZERO;
137 for i in 0..1 << query_n_vars {
138 inner_product += get_packed_slice(query.expansion(), i)
139 * get_packed_slice(self.0.evals(), evals_start + i);
140 }
141
142 set_packed_slice(partial_low_evals, scalar_index, inner_product);
143 }
144
145 Ok(())
146 }
147
148 fn subcube_partial_high_evals(
149 &self,
150 query: MultilinearQueryRef<PE>,
151 subcube_vars: usize,
152 subcube_index: usize,
153 partial_high_evals: &mut [PE],
154 ) -> Result<(), Error> {
155 validate_subcube_partial_evals_params(
156 self.n_vars(),
157 query,
158 subcube_vars,
159 subcube_index,
160 partial_high_evals,
161 )?;
162
163 let query_n_vars = query.n_vars();
164
165 partial_high_evals.fill(PE::zero());
168
169 for query_index in 0..1 << query_n_vars {
170 let query_factor = get_packed_slice(query.expansion(), query_index);
171 let subcube_start =
172 subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
173 for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
174 *packed += PE::from_fn(|inner_index| {
175 let index = subcube_start | outer_index << PE::LOG_WIDTH | inner_index;
176 query_factor * get_packed_slice(self.0.evals(), index)
177 });
178 }
179 }
180
181 if subcube_vars < PE::LOG_WIDTH {
182 for i in 1 << subcube_vars..PE::WIDTH {
183 partial_high_evals
184 .first_mut()
185 .expect("at least one")
186 .set(i, PE::Scalar::ZERO);
187 }
188 }
189
190 Ok(())
191 }
192
193 fn subcube_evals(
194 &self,
195 subcube_vars: usize,
196 subcube_index: usize,
197 log_embedding_degree: usize,
198 evals: &mut [PE],
199 ) -> Result<(), Error> {
200 let log_extension_degree = PE::Scalar::LOG_DEGREE;
201
202 if subcube_vars > self.n_vars() {
203 bail!(Error::ArgumentRangeError {
204 arg: "subcube_vars".to_string(),
205 range: 0..self.n_vars() + 1,
206 });
207 }
208
209 const MAX_TOWER_HEIGHT: usize = 7;
213 if log_embedding_degree > log_extension_degree.min(MAX_TOWER_HEIGHT) {
214 bail!(Error::LogEmbeddingDegreeTooLarge {
215 log_embedding_degree
216 });
217 }
218
219 let correct_len = 1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH);
220 if evals.len() != correct_len {
221 bail!(Error::ArgumentRangeError {
222 arg: "evals.len()".to_string(),
223 range: correct_len..correct_len + 1,
224 });
225 }
226
227 let max_index = 1 << (self.n_vars() - subcube_vars);
228 if subcube_index >= max_index {
229 bail!(Error::ArgumentRangeError {
230 arg: "subcube_index".to_string(),
231 range: 0..max_index,
232 });
233 }
234
235 let subcube_start = subcube_index << subcube_vars;
236
237 if log_embedding_degree == 0 {
238 for i in 0..1 << subcube_vars {
240 let scalar =
242 unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
243
244 let extension_scalar = scalar.into();
245
246 unsafe {
248 set_packed_slice_unchecked(evals, i, extension_scalar);
249 }
250 }
251 } else {
252 let mut bases = [P::Scalar::default(); 1 << MAX_TOWER_HEIGHT];
254 let bases = &mut bases[0..1 << log_embedding_degree];
255
256 let bases_count = 1 << log_embedding_degree.min(subcube_vars);
257 for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) {
258 for (j, base) in bases[..bases_count].iter_mut().enumerate() {
259 *base = unsafe {
261 get_packed_slice_unchecked(
262 self.0.evals(),
263 subcube_start + (i << log_embedding_degree) + j,
264 )
265 };
266 }
267
268 let extension_scalar = PE::Scalar::from_bases_sparse(
269 bases,
270 log_extension_degree - log_embedding_degree,
271 )?;
272
273 unsafe {
275 set_packed_slice_unchecked(evals, i, extension_scalar);
276 }
277 }
278 }
279
280 Ok(())
281 }
282
283 fn packed_evals(&self) -> Option<&[PE]> {
284 Some(PE::cast_exts(self.0.evals()))
285 }
286}
287
288impl<P, Data> MultilinearExtension<P, Data>
289where
290 P: PackedField,
291 Data: Deref<Target = [P]>,
292{
293 pub fn specialize<PE>(self) -> MLEEmbeddingAdapter<P, PE, Data>
294 where
295 PE: PackedField,
296 PE::Scalar: ExtensionField<P::Scalar>,
297 {
298 MLEEmbeddingAdapter::from(self)
299 }
300}
301
302impl<'a, P, Data> MultilinearExtension<P, Data>
303where
304 P: PackedField,
305 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
306{
307 pub fn specialize_arc_dyn<PE: RepackedExtension<P>>(
308 self,
309 ) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
310 self.specialize().upcast_arc_dyn()
311 }
312}
313
314#[derive(Debug, Clone, PartialEq, Eq)]
317pub struct MLEDirectAdapter<P, Data = Vec<P>>(MultilinearExtension<P, Data>)
318where
319 P: PackedField,
320 Data: Deref<Target = [P]>;
321
322impl<'a, P, Data> MLEDirectAdapter<P, Data>
323where
324 P: PackedField,
325 Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
326{
327 pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<P> + Send + Sync + 'a> {
328 Arc::new(self)
329 }
330}
331
332impl<P, Data> From<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
333where
334 P: PackedField,
335 Data: Deref<Target = [P]>,
336{
337 fn from(inner: MultilinearExtension<P, Data>) -> Self {
338 Self(inner)
339 }
340}
341
342impl<P, Data> AsRef<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
343where
344 P: PackedField,
345 Data: Deref<Target = [P]>,
346{
347 fn as_ref(&self) -> &MultilinearExtension<P, Data> {
348 &self.0
349 }
350}
351
352impl<F, P, Data> MultilinearPoly<P> for MLEDirectAdapter<P, Data>
353where
354 F: Field,
355 P: PackedField<Scalar = F>,
356 Data: Deref<Target = [P]> + Send + Sync + Debug,
357{
358 #[inline]
359 fn n_vars(&self) -> usize {
360 self.0.n_vars()
361 }
362
363 #[inline]
364 fn log_extension_degree(&self) -> usize {
365 0
366 }
367
368 fn evaluate_on_hypercube(&self, index: usize) -> Result<F, Error> {
369 self.0.evaluate_on_hypercube(index)
370 }
371
372 fn evaluate_on_hypercube_and_scale(&self, index: usize, scalar: F) -> Result<F, Error> {
373 let eval = self.0.evaluate_on_hypercube(index)?;
374 Ok(scalar * eval)
375 }
376
377 fn evaluate(&self, query: MultilinearQueryRef<P>) -> Result<F, Error> {
378 self.0.evaluate(query)
379 }
380
381 fn evaluate_partial_low(
382 &self,
383 query: MultilinearQueryRef<P>,
384 ) -> Result<MultilinearExtension<P>, Error> {
385 self.0.evaluate_partial_low(query)
386 }
387
388 fn evaluate_partial_high(
389 &self,
390 query: MultilinearQueryRef<P>,
391 ) -> Result<MultilinearExtension<P>, Error> {
392 self.0.evaluate_partial_high(query)
393 }
394
395 fn subcube_partial_low_evals(
396 &self,
397 query: MultilinearQueryRef<P>,
398 subcube_vars: usize,
399 subcube_index: usize,
400 partial_low_evals: &mut [P],
401 ) -> Result<(), Error> {
402 validate_subcube_partial_evals_params(
404 self.n_vars(),
405 query,
406 subcube_vars,
407 subcube_index,
408 partial_low_evals,
409 )?;
410
411 let query_n_vars = query.n_vars();
412 let subcube_start = subcube_index << (query_n_vars + subcube_vars);
413
414 for scalar_index in 0..1 << subcube_vars {
416 let evals_start = subcube_start + (scalar_index << query_n_vars);
417 let mut inner_product = F::ZERO;
418 for i in 0..1 << query_n_vars {
419 inner_product += get_packed_slice(query.expansion(), i)
420 * get_packed_slice(self.0.evals(), evals_start + i);
421 }
422
423 set_packed_slice(partial_low_evals, scalar_index, inner_product);
424 }
425
426 Ok(())
427 }
428
429 fn subcube_partial_high_evals(
430 &self,
431 query: MultilinearQueryRef<P>,
432 subcube_vars: usize,
433 subcube_index: usize,
434 partial_high_evals: &mut [P],
435 ) -> Result<(), Error> {
436 validate_subcube_partial_evals_params(
438 self.n_vars(),
439 query,
440 subcube_vars,
441 subcube_index,
442 partial_high_evals,
443 )?;
444
445 let query_n_vars = query.n_vars();
446
447 partial_high_evals.fill(P::zero());
450
451 for query_index in 0..1 << query_n_vars {
452 let query_factor = get_packed_slice(query.expansion(), query_index);
453 let subcube_start =
454 subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
455 for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
456 *packed += P::from_fn(|inner_index| {
457 let index = subcube_start | outer_index << P::LOG_WIDTH | inner_index;
458 query_factor * get_packed_slice(self.0.evals(), index)
459 });
460 }
461 }
462
463 if subcube_vars < P::LOG_WIDTH {
464 for i in 1 << subcube_vars..P::WIDTH {
465 partial_high_evals
466 .first_mut()
467 .expect("at least one")
468 .set(i, P::Scalar::ZERO);
469 }
470 }
471
472 Ok(())
473 }
474
475 fn subcube_evals(
476 &self,
477 subcube_vars: usize,
478 subcube_index: usize,
479 log_embedding_degree: usize,
480 evals: &mut [P],
481 ) -> Result<(), Error> {
482 let n_vars = self.n_vars();
483 if subcube_vars > n_vars {
484 bail!(Error::ArgumentRangeError {
485 arg: "subcube_vars".to_string(),
486 range: 0..n_vars + 1,
487 });
488 }
489
490 if log_embedding_degree != 0 {
491 bail!(Error::LogEmbeddingDegreeTooLarge {
492 log_embedding_degree
493 });
494 }
495
496 let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
497 if evals.len() != correct_len {
498 bail!(Error::ArgumentRangeError {
499 arg: "evals.len()".to_string(),
500 range: correct_len..correct_len + 1,
501 });
502 }
503
504 let max_index = 1 << (n_vars - subcube_vars);
505 if subcube_index >= max_index {
506 bail!(Error::ArgumentRangeError {
507 arg: "subcube_index".to_string(),
508 range: 0..max_index,
509 });
510 }
511
512 if subcube_vars < P::LOG_WIDTH {
513 let subcube_start = subcube_index << subcube_vars;
514 for i in 0..1 << subcube_vars {
515 let scalar =
517 unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
518
519 unsafe {
521 set_packed_slice_unchecked(evals, i, scalar);
522 }
523 }
524 } else {
525 let range = subcube_index << (subcube_vars - P::LOG_WIDTH)
526 ..(subcube_index + 1) << (subcube_vars - P::LOG_WIDTH);
527 evals.copy_from_slice(&self.0.evals()[range]);
528 }
529
530 Ok(())
531 }
532
533 fn packed_evals(&self) -> Option<&[P]> {
534 Some(self.0.evals())
535 }
536}
537
538fn validate_subcube_partial_evals_params<P: PackedField>(
539 n_vars: usize,
540 query: MultilinearQueryRef<P>,
541 subcube_vars: usize,
542 subcube_index: usize,
543 partial_evals: &[P],
544) -> Result<(), Error> {
545 let query_n_vars = query.n_vars();
546 if query_n_vars + subcube_vars > n_vars {
547 bail!(Error::ArgumentRangeError {
548 arg: "query.n_vars() + subcube_vars".into(),
549 range: 0..n_vars,
550 });
551 }
552
553 let max_index = 1 << (n_vars - query_n_vars - subcube_vars);
554 if subcube_index >= max_index {
555 bail!(Error::ArgumentRangeError {
556 arg: "subcube_index".into(),
557 range: 0..max_index,
558 });
559 }
560
561 let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
562 if partial_evals.len() != correct_len {
563 bail!(Error::ArgumentRangeError {
564 arg: "partial_evals.len()".to_string(),
565 range: correct_len..correct_len + 1,
566 });
567 }
568
569 Ok(())
570}
571
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		arch::OptimalUnderlier256b, as_packed_field::PackedType, BinaryField128b, BinaryField16b,
		BinaryField32b, BinaryField8b, PackedBinaryField16x8b, PackedBinaryField1x128b,
		PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b,
		PackedBinaryField8x16b, PackedExtension, PackedField, PackedFieldIndexable,
	};
	use rand::prelude::*;

	use super::*;
	use crate::{tensor_prod_eq_ind, MultilinearQuery};

	type F = BinaryField16b;
	type P = PackedBinaryField8x16b;

	/// Builds a `MultilinearQuery` from the tensor-product expansion of the point `p`.
	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	#[test]
	fn test_evaluate_subcube_and_evaluate_partial_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::from_values(
			repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
				.take(1 << 8)
				.collect(),
		)
		.unwrap()
		.specialize::<PackedBinaryField1x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(6)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let partial_low = poly.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high = poly.evaluate_partial_high(query.to_ref()).unwrap();

		let mut subcube_partial_low = vec![PackedBinaryField1x128b::zero(); 16];
		let mut subcube_partial_high = vec![PackedBinaryField1x128b::zero(); 16];
		poly.subcube_partial_low_evals(query.to_ref(), 4, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 4, 0, &mut subcube_partial_high)
			.unwrap();

		for (idx, subcube_partial_low) in PackedField::iter_slice(&subcube_partial_low).enumerate()
		{
			assert_eq!(subcube_partial_low, partial_low.evaluate_on_hypercube(idx).unwrap(),);
		}

		for (idx, subcube_partial_high) in
			PackedField::iter_slice(&subcube_partial_high).enumerate()
		{
			assert_eq!(subcube_partial_high, partial_high.evaluate_on_hypercube(idx).unwrap(),);
		}
	}

	#[test]
	fn test_evaluate_subcube_smaller_than_packed_width() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::new(
			2,
			vec![PackedBinaryField64x8b::from_scalars(
				[2, 2, 9, 9].map(BinaryField8b::new),
			)],
		)
		.unwrap()
		.specialize::<PackedBinaryField4x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(1)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let mut subcube_partial_low = vec![PackedBinaryField4x128b::zero(); 1];
		let mut subcube_partial_high = vec![PackedBinaryField4x128b::zero(); 1];
		poly.subcube_partial_low_evals(query.to_ref(), 1, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 1, 0, &mut subcube_partial_high)
			.unwrap();

		let expected_partial_high = BinaryField128b::new(2)
			+ (BinaryField128b::new(2) + BinaryField128b::new(9)) * q.first().unwrap();

		assert_eq!(get_packed_slice(&subcube_partial_low, 0), BinaryField128b::new(2));
		assert_eq!(get_packed_slice(&subcube_partial_low, 1), BinaryField128b::new(9));
		assert_eq!(get_packed_slice(&subcube_partial_low, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_low, 3), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 0), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 1), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 3), BinaryField128b::ZERO);
	}

	#[test]
	fn test_subcube_evals_embeds_correctly() {
		let mut rng = StdRng::seed_from_u64(0);

		type P = PackedBinaryField16x8b;
		type PE = PackedBinaryField1x128b;

		let packed_count = 4;
		let values: Vec<_> = repeat_with(|| P::random(&mut rng))
			.take(1 << packed_count)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let mles = MLEEmbeddingAdapter::<P, PE, _>::from(mle);

		let bytes_values = P::unpack_scalars(mles.0.evals());

		let n_vars = packed_count + P::LOG_WIDTH;
		let mut evals = vec![PE::zero(); 1 << n_vars];
		for subcube_vars in 0..n_vars {
			for subcube_index in 0..1 << (n_vars - subcube_vars) {
				for log_embedding_degree in 0..=4 {
					let evals_subcube = &mut evals
						[0..1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH)];

					mles.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals_subcube,
					)
					.unwrap();

					let bytes_evals = P::unpack_scalars(
						<PE as PackedExtension<BinaryField8b>>::cast_bases(evals_subcube),
					);

					// Each 128-bit scalar holds 16 bytes; with embedding degree `2^d`, the
					// base values land on every `2^(4 - d)`-th byte and all others are zero.
					let shift = 4 - log_embedding_degree;
					let skip_mask = (1 << shift) - 1;
					for (i, &b_evals) in bytes_evals.iter().enumerate() {
						let b_values = if i & skip_mask == 0 && i < 1 << (subcube_vars + shift) {
							bytes_values[(subcube_index << subcube_vars) + (i >> shift)]
						} else {
							BinaryField8b::ZERO
						};
						assert_eq!(b_evals, b_values);
					}
				}
			}
		}
	}

	#[test]
	fn test_subcube_partial_and_evaluate_partial_conform() {
		let mut rng = StdRng::seed_from_u64(0);
		let n_vars = 12;
		let evals = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars - P::LOG_WIDTH))
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals).unwrap();
		let mles = MLEDirectAdapter::from(mle);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(6)
			.collect::<Vec<F>>();
		let query = multilinear_query(&q);
		let partial_low_eval = mles.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high_eval = mles.evaluate_partial_high(query.to_ref()).unwrap();

		let subcube_vars = 4;
		let mut subcube_partial_low_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		let mut subcube_partial_high_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		// There are `2^(n_vars - query vars - subcube_vars)` subcubes; check every one of them.
		for subcube_index in 0..1 << (n_vars - query.n_vars() - subcube_vars) {
			mles.subcube_partial_low_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_low_evals,
			)
			.unwrap();
			mles.subcube_partial_high_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_high_evals,
			)
			.unwrap();
			for hypercube_idx in 0..(1 << subcube_vars) {
				assert_eq!(
					get_packed_slice(&subcube_partial_low_evals, hypercube_idx),
					partial_low_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
				assert_eq!(
					get_packed_slice(&subcube_partial_high_evals, hypercube_idx),
					partial_high_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
			}
		}
	}

	#[test]
	fn test_packed_evals_against_subcube_evals() {
		type U = OptimalUnderlier256b;
		type P = PackedType<U, BinaryField32b>;
		type PExt = PackedType<U, BinaryField128b>;

		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(2)
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly = MLEEmbeddingAdapter::from(mle);
		assert_eq!(
			<PExt as PackedExtension<BinaryField32b>>::cast_bases(poly.packed_evals().unwrap()),
			&evals
		);

		let mut evals_out = vec![PExt::zero(); 2];
		poly.subcube_evals(poly.n_vars(), 0, poly.log_extension_degree(), evals_out.as_mut_slice())
			.unwrap();
		assert_eq!(evals_out, poly.packed_evals().unwrap());
	}
}