binius_math/
mle_adapters.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc};
4
5use binius_field::{
6	ExtensionField, Field, PackedField, RepackedExtension,
7	packed::{
8		get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
9	},
10};
11use binius_utils::bail;
12
13use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef};
14
15/// An adapter for [`MultilinearExtension`] that implements [`MultilinearPoly`] over a packed
16/// extension field.
17///
18/// This struct implements `MultilinearPoly` for an extension field of the base field that the
19/// multilinear extension is defined over.
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct MLEEmbeddingAdapter<P, PE, Data = Vec<P>>(
22	MultilinearExtension<P, Data>,
23	PhantomData<PE>,
24)
25where
26	P: PackedField,
27	PE: PackedField,
28	PE::Scalar: ExtensionField<P::Scalar>,
29	Data: Deref<Target = [P]>;
30
31impl<'a, P, PE, Data> MLEEmbeddingAdapter<P, PE, Data>
32where
33	P: PackedField,
34	PE: PackedField + RepackedExtension<P>,
35	PE::Scalar: ExtensionField<P::Scalar>,
36	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
37{
38	pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
39		Arc::new(self)
40	}
41}
42
43impl<P, PE, Data> From<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
44where
45	P: PackedField,
46	PE: PackedField,
47	PE::Scalar: ExtensionField<P::Scalar>,
48	Data: Deref<Target = [P]>,
49{
50	fn from(inner: MultilinearExtension<P, Data>) -> Self {
51		Self(inner, PhantomData)
52	}
53}
54
55impl<P, PE, Data> AsRef<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
56where
57	P: PackedField,
58	PE: PackedField,
59	PE::Scalar: ExtensionField<P::Scalar>,
60	Data: Deref<Target = [P]>,
61{
62	fn as_ref(&self) -> &MultilinearExtension<P, Data> {
63		&self.0
64	}
65}
66
67impl<P, PE, Data> MultilinearPoly<PE> for MLEEmbeddingAdapter<P, PE, Data>
68where
69	P: PackedField + Debug,
70	PE: PackedField + RepackedExtension<P>,
71	PE::Scalar: ExtensionField<P::Scalar>,
72	Data: Deref<Target = [P]> + Send + Sync + Debug,
73{
74	fn n_vars(&self) -> usize {
75		self.0.n_vars()
76	}
77
78	fn log_extension_degree(&self) -> usize {
79		PE::Scalar::LOG_DEGREE
80	}
81
82	fn evaluate_on_hypercube(&self, index: usize) -> Result<PE::Scalar, Error> {
83		let eval = self.0.evaluate_on_hypercube(index)?;
84		Ok(eval.into())
85	}
86
87	fn evaluate_on_hypercube_and_scale(
88		&self,
89		index: usize,
90		scalar: PE::Scalar,
91	) -> Result<PE::Scalar, Error> {
92		let eval = self.0.evaluate_on_hypercube(index)?;
93		Ok(scalar * eval)
94	}
95
96	fn evaluate(&self, query: MultilinearQueryRef<PE>) -> Result<PE::Scalar, Error> {
97		self.0.evaluate(query)
98	}
99
100	fn evaluate_partial_low(
101		&self,
102		query: MultilinearQueryRef<PE>,
103	) -> Result<MultilinearExtension<PE>, Error> {
104		self.0.evaluate_partial_low(query)
105	}
106
107	fn evaluate_partial_high(
108		&self,
109		query: MultilinearQueryRef<PE>,
110	) -> Result<MultilinearExtension<PE>, Error> {
111		self.0.evaluate_partial_high(query)
112	}
113
114	fn evaluate_partial(
115		&self,
116		query: MultilinearQueryRef<PE>,
117		start_index: usize,
118	) -> Result<MultilinearExtension<PE>, Error> {
119		self.0.evaluate_partial(query, start_index)
120	}
121
122	fn zero_pad(
123		&self,
124		n_pad_vars: usize,
125		start_index: usize,
126		nonzero_index: usize,
127	) -> Result<MultilinearExtension<PE>, Error> {
128		self.0.zero_pad(n_pad_vars, start_index, nonzero_index)
129	}
130
131	fn subcube_partial_low_evals(
132		&self,
133		query: MultilinearQueryRef<PE>,
134		subcube_vars: usize,
135		subcube_index: usize,
136		partial_low_evals: &mut [PE],
137	) -> Result<(), Error> {
138		validate_subcube_partial_evals_params(
139			self.n_vars(),
140			query,
141			subcube_vars,
142			subcube_index,
143			partial_low_evals,
144		)?;
145
146		let query_n_vars = query.n_vars();
147		let subcube_start = subcube_index << (query_n_vars + subcube_vars);
148
149		// REVIEW: not spending effort to optimize this as the future of switchover
150		//         is somewhat unclear in light of univariate skip
151		for scalar_index in 0..1 << subcube_vars {
152			let evals_start = subcube_start + (scalar_index << query_n_vars);
153			let mut inner_product = PE::Scalar::ZERO;
154			for i in 0..1 << query_n_vars {
155				inner_product += get_packed_slice(query.expansion(), i)
156					* get_packed_slice(self.0.evals(), evals_start + i);
157			}
158
159			set_packed_slice(partial_low_evals, scalar_index, inner_product);
160		}
161
162		Ok(())
163	}
164
165	fn subcube_partial_high_evals(
166		&self,
167		query: MultilinearQueryRef<PE>,
168		subcube_vars: usize,
169		subcube_index: usize,
170		partial_high_evals: &mut [PE],
171	) -> Result<(), Error> {
172		validate_subcube_partial_evals_params(
173			self.n_vars(),
174			query,
175			subcube_vars,
176			subcube_index,
177			partial_high_evals,
178		)?;
179
180		let query_n_vars = query.n_vars();
181
182		// REVIEW: not spending effort to optimize this as the future of switchover
183		//         is somewhat unclear in light of univariate skip
184		partial_high_evals.fill(PE::zero());
185
186		for query_index in 0..1 << query_n_vars {
187			let query_factor = get_packed_slice(query.expansion(), query_index);
188			let subcube_start =
189				subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
190			for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
191				*packed += PE::from_fn(|inner_index| {
192					let index = subcube_start | outer_index << PE::LOG_WIDTH | inner_index;
193					query_factor * get_packed_slice(self.0.evals(), index)
194				});
195			}
196		}
197
198		if subcube_vars < PE::LOG_WIDTH {
199			for i in 1 << subcube_vars..PE::WIDTH {
200				partial_high_evals
201					.first_mut()
202					.expect("at least one")
203					.set(i, PE::Scalar::ZERO);
204			}
205		}
206
207		Ok(())
208	}
209
210	fn subcube_evals(
211		&self,
212		subcube_vars: usize,
213		subcube_index: usize,
214		log_embedding_degree: usize,
215		evals: &mut [PE],
216	) -> Result<(), Error> {
217		let log_extension_degree = PE::Scalar::LOG_DEGREE;
218
219		if subcube_vars > self.n_vars() {
220			bail!(Error::ArgumentRangeError {
221				arg: "subcube_vars".to_string(),
222				range: 0..self.n_vars() + 1,
223			});
224		}
225
226		// Check that chosen embedding subfield is large enough.
227		// We also use a stack allocated array of bases, which imposes
228		// a maximum tower height restriction.
229		const MAX_TOWER_HEIGHT: usize = 7;
230		if log_embedding_degree > log_extension_degree.min(MAX_TOWER_HEIGHT) {
231			bail!(Error::LogEmbeddingDegreeTooLarge {
232				log_embedding_degree
233			});
234		}
235
236		let correct_len = 1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH);
237		if evals.len() != correct_len {
238			bail!(Error::ArgumentRangeError {
239				arg: "evals.len()".to_string(),
240				range: correct_len..correct_len + 1,
241			});
242		}
243
244		let max_index = 1 << (self.n_vars() - subcube_vars);
245		if subcube_index >= max_index {
246			bail!(Error::ArgumentRangeError {
247				arg: "subcube_index".to_string(),
248				range: 0..max_index,
249			});
250		}
251
252		let subcube_start = subcube_index << subcube_vars;
253
254		if log_embedding_degree == 0 {
255			// One-to-one embedding can bypass the extension field construction overhead.
256			for i in 0..1 << subcube_vars {
257				// Safety: subcube_index < max_index check
258				let scalar =
259					unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
260
261				let extension_scalar = scalar.into();
262
263				// Safety: i < 1 << min(0, subcube_vars) <= correct_len * PE::WIDTH
264				unsafe {
265					set_packed_slice_unchecked(evals, i, extension_scalar);
266				}
267			}
268		} else {
269			// For many-to-one embedding, use ExtensionField::from_bases_sparse
270			let mut bases = [P::Scalar::default(); 1 << MAX_TOWER_HEIGHT];
271			let bases = &mut bases[0..1 << log_embedding_degree];
272
273			let bases_count = 1 << log_embedding_degree.min(subcube_vars);
274			for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) {
275				for (j, base) in bases[..bases_count].iter_mut().enumerate() {
276					// Safety: i > 0 iff log_embedding_degree < subcube_vars and subcube_index <
277					// max_index check
278					*base = unsafe {
279						get_packed_slice_unchecked(
280							self.0.evals(),
281							subcube_start + (i << log_embedding_degree) + j,
282						)
283					};
284				}
285
286				let extension_scalar = PE::Scalar::from_bases_sparse(
287					bases.iter().copied(),
288					log_extension_degree - log_embedding_degree,
289				)?;
290
291				// Safety: i < 1 << min(0, subcube_vars - log_embedding_degree) <= correct_len *
292				// PE::WIDTH
293				unsafe {
294					set_packed_slice_unchecked(evals, i, extension_scalar);
295				}
296			}
297		}
298
299		Ok(())
300	}
301
302	fn packed_evals(&self) -> Option<&[PE]> {
303		Some(PE::cast_exts(self.0.evals()))
304	}
305}
306
307impl<P, Data> MultilinearExtension<P, Data>
308where
309	P: PackedField,
310	Data: Deref<Target = [P]>,
311{
312	pub fn specialize<PE>(self) -> MLEEmbeddingAdapter<P, PE, Data>
313	where
314		PE: PackedField,
315		PE::Scalar: ExtensionField<P::Scalar>,
316	{
317		MLEEmbeddingAdapter::from(self)
318	}
319}
320
321impl<'a, P, Data> MultilinearExtension<P, Data>
322where
323	P: PackedField,
324	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
325{
326	pub fn specialize_arc_dyn<PE: RepackedExtension<P>>(
327		self,
328	) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
329		self.specialize().upcast_arc_dyn()
330	}
331}
332
333/// An adapter for [`MultilinearExtension`] that implements [`MultilinearPoly`] over the same
334/// packed field that the [`MultilinearExtension`] stores evaluations in.
335#[derive(Debug, Clone, PartialEq, Eq)]
336pub struct MLEDirectAdapter<P, Data = Vec<P>>(MultilinearExtension<P, Data>)
337where
338	P: PackedField,
339	Data: Deref<Target = [P]>;
340
341impl<'a, P, Data> MLEDirectAdapter<P, Data>
342where
343	P: PackedField,
344	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
345{
346	pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<P> + Send + Sync + 'a> {
347		Arc::new(self)
348	}
349}
350
351impl<P, Data> From<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
352where
353	P: PackedField,
354	Data: Deref<Target = [P]>,
355{
356	fn from(inner: MultilinearExtension<P, Data>) -> Self {
357		Self(inner)
358	}
359}
360
361impl<P, Data> AsRef<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
362where
363	P: PackedField,
364	Data: Deref<Target = [P]>,
365{
366	fn as_ref(&self) -> &MultilinearExtension<P, Data> {
367		&self.0
368	}
369}
370
371impl<F, P, Data> MultilinearPoly<P> for MLEDirectAdapter<P, Data>
372where
373	F: Field,
374	P: PackedField<Scalar = F>,
375	Data: Deref<Target = [P]> + Send + Sync + Debug,
376{
377	#[inline]
378	fn n_vars(&self) -> usize {
379		self.0.n_vars()
380	}
381
382	#[inline]
383	fn log_extension_degree(&self) -> usize {
384		0
385	}
386
387	fn evaluate_on_hypercube(&self, index: usize) -> Result<F, Error> {
388		self.0.evaluate_on_hypercube(index)
389	}
390
391	fn evaluate_on_hypercube_and_scale(&self, index: usize, scalar: F) -> Result<F, Error> {
392		let eval = self.0.evaluate_on_hypercube(index)?;
393		Ok(scalar * eval)
394	}
395
396	fn evaluate(&self, query: MultilinearQueryRef<P>) -> Result<F, Error> {
397		self.0.evaluate(query)
398	}
399
400	fn evaluate_partial_low(
401		&self,
402		query: MultilinearQueryRef<P>,
403	) -> Result<MultilinearExtension<P>, Error> {
404		self.0.evaluate_partial_low(query)
405	}
406
407	fn evaluate_partial_high(
408		&self,
409		query: MultilinearQueryRef<P>,
410	) -> Result<MultilinearExtension<P>, Error> {
411		self.0.evaluate_partial_high(query)
412	}
413
414	fn evaluate_partial(
415		&self,
416		query: MultilinearQueryRef<P>,
417		start_index: usize,
418	) -> Result<MultilinearExtension<P>, Error> {
419		self.0.evaluate_partial(query, start_index)
420	}
421
422	fn zero_pad(
423		&self,
424		n_pad_vars: usize,
425		start_index: usize,
426		nonzero_padding: usize,
427	) -> Result<MultilinearExtension<P>, Error> {
428		self.0.zero_pad(n_pad_vars, start_index, nonzero_padding)
429	}
430
431	fn subcube_partial_low_evals(
432		&self,
433		query: MultilinearQueryRef<P>,
434		subcube_vars: usize,
435		subcube_index: usize,
436		partial_low_evals: &mut [P],
437	) -> Result<(), Error> {
438		// TODO: think of a way to factor out duplicated implementation in direct & embedded
439		// adapters.
440		validate_subcube_partial_evals_params(
441			self.n_vars(),
442			query,
443			subcube_vars,
444			subcube_index,
445			partial_low_evals,
446		)?;
447
448		let query_n_vars = query.n_vars();
449		let subcube_start = subcube_index << (query_n_vars + subcube_vars);
450
451		// TODO: Maybe optimize me
452		for scalar_index in 0..1 << subcube_vars {
453			let evals_start = subcube_start + (scalar_index << query_n_vars);
454			let mut inner_product = F::ZERO;
455			for i in 0..1 << query_n_vars {
456				inner_product += get_packed_slice(query.expansion(), i)
457					* get_packed_slice(self.0.evals(), evals_start + i);
458			}
459
460			set_packed_slice(partial_low_evals, scalar_index, inner_product);
461		}
462
463		Ok(())
464	}
465
466	fn subcube_partial_high_evals(
467		&self,
468		query: MultilinearQueryRef<P>,
469		subcube_vars: usize,
470		subcube_index: usize,
471		partial_high_evals: &mut [P],
472	) -> Result<(), Error> {
473		// TODO: think of a way to factor out duplicated implementation in direct & embedded
474		// adapters.
475		validate_subcube_partial_evals_params(
476			self.n_vars(),
477			query,
478			subcube_vars,
479			subcube_index,
480			partial_high_evals,
481		)?;
482
483		let query_n_vars = query.n_vars();
484
485		// REVIEW: not spending effort to optimize this as the future of switchover
486		//         is somewhat unclear in light of univariate skip
487		partial_high_evals.fill(P::zero());
488
489		for query_index in 0..1 << query_n_vars {
490			let query_factor = get_packed_slice(query.expansion(), query_index);
491			let subcube_start =
492				subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
493			for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
494				*packed += P::from_fn(|inner_index| {
495					let index = subcube_start | outer_index << P::LOG_WIDTH | inner_index;
496					query_factor * get_packed_slice(self.0.evals(), index)
497				});
498			}
499		}
500
501		if subcube_vars < P::LOG_WIDTH {
502			for i in 1 << subcube_vars..P::WIDTH {
503				partial_high_evals
504					.first_mut()
505					.expect("at least one")
506					.set(i, P::Scalar::ZERO);
507			}
508		}
509
510		Ok(())
511	}
512
513	fn subcube_evals(
514		&self,
515		subcube_vars: usize,
516		subcube_index: usize,
517		log_embedding_degree: usize,
518		evals: &mut [P],
519	) -> Result<(), Error> {
520		let n_vars = self.n_vars();
521		if subcube_vars > n_vars {
522			bail!(Error::ArgumentRangeError {
523				arg: "subcube_vars".to_string(),
524				range: 0..n_vars + 1,
525			});
526		}
527
528		if log_embedding_degree != 0 {
529			bail!(Error::LogEmbeddingDegreeTooLarge {
530				log_embedding_degree
531			});
532		}
533
534		let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
535		if evals.len() != correct_len {
536			bail!(Error::ArgumentRangeError {
537				arg: "evals.len()".to_string(),
538				range: correct_len..correct_len + 1,
539			});
540		}
541
542		let max_index = 1 << (n_vars - subcube_vars);
543		if subcube_index >= max_index {
544			bail!(Error::ArgumentRangeError {
545				arg: "subcube_index".to_string(),
546				range: 0..max_index,
547			});
548		}
549
550		if subcube_vars < P::LOG_WIDTH {
551			let subcube_start = subcube_index << subcube_vars;
552			for i in 0..1 << subcube_vars {
553				// Safety: subcube_index < max_index check
554				let scalar =
555					unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
556
557				// Safety: i < 1 << min(0, subcube_vars) <= correct_len * P::WIDTH
558				unsafe {
559					set_packed_slice_unchecked(evals, i, scalar);
560				}
561			}
562		} else {
563			let range = subcube_index << (subcube_vars - P::LOG_WIDTH)
564				..(subcube_index + 1) << (subcube_vars - P::LOG_WIDTH);
565			evals.copy_from_slice(&self.0.evals()[range]);
566		}
567
568		Ok(())
569	}
570
571	fn packed_evals(&self) -> Option<&[P]> {
572		Some(self.0.evals())
573	}
574}
575
576fn validate_subcube_partial_evals_params<P: PackedField>(
577	n_vars: usize,
578	query: MultilinearQueryRef<P>,
579	subcube_vars: usize,
580	subcube_index: usize,
581	partial_evals: &[P],
582) -> Result<(), Error> {
583	let query_n_vars = query.n_vars();
584	if query_n_vars + subcube_vars > n_vars {
585		bail!(Error::ArgumentRangeError {
586			arg: "query.n_vars() + subcube_vars".into(),
587			range: 0..n_vars,
588		});
589	}
590
591	let max_index = 1 << (n_vars - query_n_vars - subcube_vars);
592	if subcube_index >= max_index {
593		bail!(Error::ArgumentRangeError {
594			arg: "subcube_index".into(),
595			range: 0..max_index,
596		});
597	}
598
599	let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
600	if partial_evals.len() != correct_len {
601		bail!(Error::ArgumentRangeError {
602			arg: "partial_evals.len()".to_string(),
603			range: correct_len..correct_len + 1,
604		});
605	}
606
607	Ok(())
608}
609
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		BinaryField8b, BinaryField16b, BinaryField32b, BinaryField128b, PackedBinaryField1x128b,
		PackedBinaryField4x32b, PackedBinaryField4x128b, PackedBinaryField8x16b,
		PackedBinaryField16x8b, PackedBinaryField64x8b, PackedExtension, PackedField,
		PackedFieldIndexable, arch::OptimalUnderlier256b, as_packed_field::PackedType,
	};
	use rand::prelude::*;

	use super::*;
	use crate::{MultilinearQuery, tensor_prod_eq_ind};

	type F = BinaryField16b;
	type P = PackedBinaryField8x16b;

	/// Builds the tensor expansion of the equality indicator at point `p` as a query.
	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	#[test]
	fn test_evaluate_subcube_and_evaluate_partial_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::from_values(
			repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
				.take(1 << 8)
				.collect(),
		)
		.unwrap()
		.specialize::<PackedBinaryField1x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(6)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let partial_low = poly.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high = poly.evaluate_partial_high(query.to_ref()).unwrap();

		let mut subcube_partial_low = vec![PackedBinaryField1x128b::zero(); 16];
		let mut subcube_partial_high = vec![PackedBinaryField1x128b::zero(); 16];
		poly.subcube_partial_low_evals(query.to_ref(), 4, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 4, 0, &mut subcube_partial_high)
			.unwrap();

		for (idx, subcube_partial_low) in PackedField::iter_slice(&subcube_partial_low).enumerate()
		{
			assert_eq!(subcube_partial_low, partial_low.evaluate_on_hypercube(idx).unwrap(),);
		}

		for (idx, subcube_partial_high) in
			PackedField::iter_slice(&subcube_partial_high).enumerate()
		{
			assert_eq!(subcube_partial_high, partial_high.evaluate_on_hypercube(idx).unwrap(),);
		}
	}

	#[test]
	fn test_evaluate_subcube_smaller_than_packed_width() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::new(
			2,
			vec![PackedBinaryField64x8b::from_scalars(
				[2, 2, 9, 9].map(BinaryField8b::new),
			)],
		)
		.unwrap()
		.specialize::<PackedBinaryField4x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(1)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let mut subcube_partial_low = vec![PackedBinaryField4x128b::zero(); 1];
		let mut subcube_partial_high = vec![PackedBinaryField4x128b::zero(); 1];
		poly.subcube_partial_low_evals(query.to_ref(), 1, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 1, 0, &mut subcube_partial_high)
			.unwrap();

		let expected_partial_high = BinaryField128b::new(2)
			+ (BinaryField128b::new(2) + BinaryField128b::new(9)) * q.first().unwrap();

		assert_eq!(get_packed_slice(&subcube_partial_low, 0), BinaryField128b::new(2));
		assert_eq!(get_packed_slice(&subcube_partial_low, 1), BinaryField128b::new(9));
		assert_eq!(get_packed_slice(&subcube_partial_low, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_low, 3), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 0), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 1), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 3), BinaryField128b::ZERO);
	}

	#[test]
	fn test_subcube_evals_embeds_correctly() {
		let mut rng = StdRng::seed_from_u64(0);

		type P = PackedBinaryField16x8b;
		type PE = PackedBinaryField1x128b;

		let packed_count = 4;
		let values: Vec<_> = repeat_with(|| P::random(&mut rng))
			.take(1 << packed_count)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let mles = MLEEmbeddingAdapter::<P, PE, _>::from(mle);

		let bytes_values = P::unpack_scalars(mles.0.evals());

		let n_vars = packed_count + P::LOG_WIDTH;
		let mut evals = vec![PE::zero(); 1 << n_vars];
		for subcube_vars in 0..n_vars {
			for subcube_index in 0..1 << (n_vars - subcube_vars) {
				for log_embedding_degree in 0..=4 {
					let evals_subcube = &mut evals
						[0..1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH)];

					mles.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals_subcube,
					)
					.unwrap();

					let bytes_evals = P::unpack_scalars(
						<PE as PackedExtension<BinaryField8b>>::cast_bases(evals_subcube),
					);

					let shift = 4 - log_embedding_degree;
					let skip_mask = (1 << shift) - 1;
					for (i, &b_evals) in bytes_evals.iter().enumerate() {
						let b_values = if i & skip_mask == 0 && i < 1 << (subcube_vars + shift) {
							bytes_values[(subcube_index << subcube_vars) + (i >> shift)]
						} else {
							BinaryField8b::ZERO
						};
						assert_eq!(b_evals, b_values);
					}
				}
			}
		}
	}

	#[test]
	fn test_subcube_partial_and_evaluate_partial_conform() {
		let mut rng = StdRng::seed_from_u64(0);
		let n_vars = 12;
		let evals = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars - P::LOG_WIDTH))
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals).unwrap();
		let mles = MLEDirectAdapter::from(mle);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(6)
			.collect::<Vec<F>>();
		let query = multilinear_query(&q);
		let partial_low_eval = mles.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high_eval = mles.evaluate_partial_high(query.to_ref()).unwrap();

		let subcube_vars = 4;
		let mut subcube_partial_low_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		let mut subcube_partial_high_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		// Visit every subcube: there are `1 << (n_vars - query.n_vars() - subcube_vars)` of them.
		let n_subcubes = 1 << (n_vars - query.n_vars() - subcube_vars);
		for subcube_index in 0..n_subcubes {
			mles.subcube_partial_low_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_low_evals,
			)
			.unwrap();
			mles.subcube_partial_high_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_high_evals,
			)
			.unwrap();
			for hypercube_idx in 0..(1 << subcube_vars) {
				assert_eq!(
					get_packed_slice(&subcube_partial_low_evals, hypercube_idx),
					partial_low_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
				assert_eq!(
					get_packed_slice(&subcube_partial_high_evals, hypercube_idx),
					partial_high_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
			}
		}
	}

	#[test]
	fn test_packed_evals_against_subcube_evals() {
		type U = OptimalUnderlier256b;
		type P = PackedType<U, BinaryField32b>;
		type PExt = PackedType<U, BinaryField128b>;

		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(2)
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly = MLEEmbeddingAdapter::from(mle);
		assert_eq!(
			<PExt as PackedExtension<BinaryField32b>>::cast_bases(poly.packed_evals().unwrap()),
			&evals
		);

		let mut evals_out = vec![PExt::zero(); 2];
		poly.subcube_evals(poly.n_vars(), 0, poly.log_extension_degree(), evals_out.as_mut_slice())
			.unwrap();
		assert_eq!(evals_out, poly.packed_evals().unwrap());
	}
}