binius_math/
mle_adapters.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc};
4
5use binius_field::{
6	packed::{
7		get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
8	},
9	ExtensionField, Field, PackedField, RepackedExtension,
10};
11use binius_utils::bail;
12
13use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef};
14
/// An adapter for [`MultilinearExtension`] that implements [`MultilinearPoly`] over a packed
/// extension field.
///
/// This struct implements `MultilinearPoly` for an extension field of the base field that the
/// multilinear extension is defined over.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MLEEmbeddingAdapter<P, PE, Data = Vec<P>>(
	/// The wrapped multilinear extension, with evaluations stored in the base packed field `P`.
	MultilinearExtension<P, Data>,
	/// Zero-sized marker tying the adapter to the packed extension field `PE` it evaluates over.
	PhantomData<PE>,
)
where
	P: PackedField,
	PE: PackedField,
	PE::Scalar: ExtensionField<P::Scalar>,
	Data: Deref<Target = [P]>;
30
31impl<'a, P, PE, Data> MLEEmbeddingAdapter<P, PE, Data>
32where
33	P: PackedField,
34	PE: PackedField + RepackedExtension<P>,
35	PE::Scalar: ExtensionField<P::Scalar>,
36	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
37{
38	pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
39		Arc::new(self)
40	}
41}
42
43impl<P, PE, Data> From<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
44where
45	P: PackedField,
46	PE: PackedField,
47	PE::Scalar: ExtensionField<P::Scalar>,
48	Data: Deref<Target = [P]>,
49{
50	fn from(inner: MultilinearExtension<P, Data>) -> Self {
51		Self(inner, PhantomData)
52	}
53}
54
55impl<P, PE, Data> AsRef<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
56where
57	P: PackedField,
58	PE: PackedField,
59	PE::Scalar: ExtensionField<P::Scalar>,
60	Data: Deref<Target = [P]>,
61{
62	fn as_ref(&self) -> &MultilinearExtension<P, Data> {
63		&self.0
64	}
65}
66
impl<P, PE, Data> MultilinearPoly<PE> for MLEEmbeddingAdapter<P, PE, Data>
where
	P: PackedField + Debug,
	PE: PackedField + RepackedExtension<P>,
	PE::Scalar: ExtensionField<P::Scalar>,
	Data: Deref<Target = [P]> + Send + Sync + Debug,
{
	/// Number of variables of the wrapped multilinear extension.
	fn n_vars(&self) -> usize {
		self.0.n_vars()
	}

	/// Log2 of the degree of `PE::Scalar` as an extension of the base scalar `P::Scalar`.
	fn log_extension_degree(&self) -> usize {
		PE::Scalar::LOG_DEGREE
	}

	/// Evaluates at hypercube vertex `index`, embedding the base-field result into `PE::Scalar`.
	fn evaluate_on_hypercube(&self, index: usize) -> Result<PE::Scalar, Error> {
		let eval = self.0.evaluate_on_hypercube(index)?;
		Ok(eval.into())
	}

	/// Evaluates at hypercube vertex `index` and multiplies the (embedded) result by `scalar`.
	fn evaluate_on_hypercube_and_scale(
		&self,
		index: usize,
		scalar: PE::Scalar,
	) -> Result<PE::Scalar, Error> {
		let eval = self.0.evaluate_on_hypercube(index)?;
		Ok(scalar * eval)
	}

	/// Full multilinear evaluation at the point given by the tensor-expanded `query`;
	/// delegates to the inner extension.
	fn evaluate(&self, query: MultilinearQueryRef<PE>) -> Result<PE::Scalar, Error> {
		self.0.evaluate(query)
	}

	/// Partial evaluation over the low-indexed variables; delegates to the inner extension.
	fn evaluate_partial_low(
		&self,
		query: MultilinearQueryRef<PE>,
	) -> Result<MultilinearExtension<PE>, Error> {
		self.0.evaluate_partial_low(query)
	}

	/// Partial evaluation over the high-indexed variables; delegates to the inner extension.
	fn evaluate_partial_high(
		&self,
		query: MultilinearQueryRef<PE>,
	) -> Result<MultilinearExtension<PE>, Error> {
		self.0.evaluate_partial_high(query)
	}

	/// Partial evaluation over variables starting at `start_index`; delegates to the inner
	/// extension.
	fn evaluate_partial(
		&self,
		query: MultilinearQueryRef<PE>,
		start_index: usize,
	) -> Result<MultilinearExtension<PE>, Error> {
		self.0.evaluate_partial(query, start_index)
	}

	/// Computes the partial-low evaluations restricted to the subcube selected by
	/// `subcube_index`, writing `1 << subcube_vars` scalars into `partial_low_evals`.
	fn subcube_partial_low_evals(
		&self,
		query: MultilinearQueryRef<PE>,
		subcube_vars: usize,
		subcube_index: usize,
		partial_low_evals: &mut [PE],
	) -> Result<(), Error> {
		validate_subcube_partial_evals_params(
			self.n_vars(),
			query,
			subcube_vars,
			subcube_index,
			partial_low_evals,
		)?;

		let query_n_vars = query.n_vars();
		// First stored-evaluation index covered by this subcube.
		let subcube_start = subcube_index << (query_n_vars + subcube_vars);

		// REVIEW: not spending effort to optimize this as the future of switchover
		//         is somewhat unclear in light of univariate skip
		for scalar_index in 0..1 << subcube_vars {
			let evals_start = subcube_start + (scalar_index << query_n_vars);
			// Inner product of the query tensor expansion with one stride of evaluations.
			let mut inner_product = PE::Scalar::ZERO;
			for i in 0..1 << query_n_vars {
				inner_product += get_packed_slice(query.expansion(), i)
					* get_packed_slice(self.0.evals(), evals_start + i);
			}

			set_packed_slice(partial_low_evals, scalar_index, inner_product);
		}

		Ok(())
	}

	/// Computes the partial-high evaluations restricted to the subcube selected by
	/// `subcube_index`, writing `1 << subcube_vars` scalars into `partial_high_evals`.
	fn subcube_partial_high_evals(
		&self,
		query: MultilinearQueryRef<PE>,
		subcube_vars: usize,
		subcube_index: usize,
		partial_high_evals: &mut [PE],
	) -> Result<(), Error> {
		validate_subcube_partial_evals_params(
			self.n_vars(),
			query,
			subcube_vars,
			subcube_index,
			partial_high_evals,
		)?;

		let query_n_vars = query.n_vars();

		// REVIEW: not spending effort to optimize this as the future of switchover
		//         is somewhat unclear in light of univariate skip
		partial_high_evals.fill(PE::zero());

		// Accumulate, for each query coordinate, its tensor factor times the matching
		// stride of stored evaluations.
		for query_index in 0..1 << query_n_vars {
			let query_factor = get_packed_slice(query.expansion(), query_index);
			let subcube_start =
				subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
			for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
				*packed += PE::from_fn(|inner_index| {
					let index = subcube_start | outer_index << PE::LOG_WIDTH | inner_index;
					query_factor * get_packed_slice(self.0.evals(), index)
				});
			}
		}

		// When the subcube is smaller than a single packed element, zero the unused
		// tail lanes of the first (and only) packed output.
		if subcube_vars < PE::LOG_WIDTH {
			for i in 1 << subcube_vars..PE::WIDTH {
				partial_high_evals
					.first_mut()
					.expect("at least one")
					.set(i, PE::Scalar::ZERO);
			}
		}

		Ok(())
	}

	/// Copies the evaluations of one subcube into `evals`, re-embedding base-field scalars
	/// into `PE::Scalar` with the requested `log_embedding_degree`.
	fn subcube_evals(
		&self,
		subcube_vars: usize,
		subcube_index: usize,
		log_embedding_degree: usize,
		evals: &mut [PE],
	) -> Result<(), Error> {
		let log_extension_degree = PE::Scalar::LOG_DEGREE;

		if subcube_vars > self.n_vars() {
			bail!(Error::ArgumentRangeError {
				arg: "subcube_vars".to_string(),
				range: 0..self.n_vars() + 1,
			});
		}

		// Check that chosen embedding subfield is large enough.
		// We also use a stack allocated array of bases, which imposes
		// a maximum tower height restriction.
		const MAX_TOWER_HEIGHT: usize = 7;
		if log_embedding_degree > log_extension_degree.min(MAX_TOWER_HEIGHT) {
			bail!(Error::LogEmbeddingDegreeTooLarge {
				log_embedding_degree
			});
		}

		let correct_len = 1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH);
		if evals.len() != correct_len {
			bail!(Error::ArgumentRangeError {
				arg: "evals.len()".to_string(),
				range: correct_len..correct_len + 1,
			});
		}

		let max_index = 1 << (self.n_vars() - subcube_vars);
		if subcube_index >= max_index {
			bail!(Error::ArgumentRangeError {
				arg: "subcube_index".to_string(),
				range: 0..max_index,
			});
		}

		let subcube_start = subcube_index << subcube_vars;

		if log_embedding_degree == 0 {
			// One-to-one embedding can bypass the extension field construction overhead.
			for i in 0..1 << subcube_vars {
				// SAFETY: `subcube_index < max_index` was checked above, so
				// `subcube_start + i < 1 << n_vars`, within the stored evaluations.
				let scalar =
					unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };

				let extension_scalar = scalar.into();

				// SAFETY: `i < 1 << subcube_vars` and, with `log_embedding_degree == 0`,
				// `1 << subcube_vars <= correct_len * PE::WIDTH` (length checked above).
				unsafe {
					set_packed_slice_unchecked(evals, i, extension_scalar);
				}
			}
		} else {
			// For many-to-one embedding, use ExtensionField::from_bases_sparse
			let mut bases = [P::Scalar::default(); 1 << MAX_TOWER_HEIGHT];
			let bases = &mut bases[0..1 << log_embedding_degree];

			let bases_count = 1 << log_embedding_degree.min(subcube_vars);
			for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) {
				for (j, base) in bases[..bases_count].iter_mut().enumerate() {
					// SAFETY: `(i << log_embedding_degree) + j < 1 << subcube_vars`, and
					// `subcube_index < max_index` was checked above, so the read index is
					// below `1 << n_vars` and in bounds.
					*base = unsafe {
						get_packed_slice_unchecked(
							self.0.evals(),
							subcube_start + (i << log_embedding_degree) + j,
						)
					};
				}

				let extension_scalar = PE::Scalar::from_bases_sparse(
					bases.iter().copied(),
					log_extension_degree - log_embedding_degree,
				)?;

				// SAFETY: `i < 1 << subcube_vars.saturating_sub(log_embedding_degree)
				// <= correct_len * PE::WIDTH` (length checked above).
				unsafe {
					set_packed_slice_unchecked(evals, i, extension_scalar);
				}
			}
		}

		Ok(())
	}

	/// Exposes the stored evaluations reinterpreted as packed extension-field elements.
	fn packed_evals(&self) -> Option<&[PE]> {
		Some(PE::cast_exts(self.0.evals()))
	}
}
295
296impl<P, Data> MultilinearExtension<P, Data>
297where
298	P: PackedField,
299	Data: Deref<Target = [P]>,
300{
301	pub fn specialize<PE>(self) -> MLEEmbeddingAdapter<P, PE, Data>
302	where
303		PE: PackedField,
304		PE::Scalar: ExtensionField<P::Scalar>,
305	{
306		MLEEmbeddingAdapter::from(self)
307	}
308}
309
310impl<'a, P, Data> MultilinearExtension<P, Data>
311where
312	P: PackedField,
313	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
314{
315	pub fn specialize_arc_dyn<PE: RepackedExtension<P>>(
316		self,
317	) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
318		self.specialize().upcast_arc_dyn()
319	}
320}
321
/// An adapter for [`MultilinearExtension`] that implements [`MultilinearPoly`] over the same
/// packed field that the [`MultilinearExtension`] stores evaluations in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MLEDirectAdapter<P, Data = Vec<P>>(
	// The wrapped multilinear extension; no field embedding is performed.
	MultilinearExtension<P, Data>,
)
where
	P: PackedField,
	Data: Deref<Target = [P]>;
329
330impl<'a, P, Data> MLEDirectAdapter<P, Data>
331where
332	P: PackedField,
333	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
334{
335	pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<P> + Send + Sync + 'a> {
336		Arc::new(self)
337	}
338}
339
340impl<P, Data> From<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
341where
342	P: PackedField,
343	Data: Deref<Target = [P]>,
344{
345	fn from(inner: MultilinearExtension<P, Data>) -> Self {
346		Self(inner)
347	}
348}
349
350impl<P, Data> AsRef<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
351where
352	P: PackedField,
353	Data: Deref<Target = [P]>,
354{
355	fn as_ref(&self) -> &MultilinearExtension<P, Data> {
356		&self.0
357	}
358}
359
impl<F, P, Data> MultilinearPoly<P> for MLEDirectAdapter<P, Data>
where
	F: Field,
	P: PackedField<Scalar = F>,
	Data: Deref<Target = [P]> + Send + Sync + Debug,
{
	/// Number of variables of the wrapped multilinear extension.
	#[inline]
	fn n_vars(&self) -> usize {
		self.0.n_vars()
	}

	/// Always 0: the adapter evaluates over the same field the data is stored in.
	#[inline]
	fn log_extension_degree(&self) -> usize {
		0
	}

	/// Evaluates at hypercube vertex `index`; delegates to the inner extension.
	fn evaluate_on_hypercube(&self, index: usize) -> Result<F, Error> {
		self.0.evaluate_on_hypercube(index)
	}

	/// Evaluates at hypercube vertex `index` and multiplies the result by `scalar`.
	fn evaluate_on_hypercube_and_scale(&self, index: usize, scalar: F) -> Result<F, Error> {
		let eval = self.0.evaluate_on_hypercube(index)?;
		Ok(scalar * eval)
	}

	/// Full multilinear evaluation at the point given by the tensor-expanded `query`.
	fn evaluate(&self, query: MultilinearQueryRef<P>) -> Result<F, Error> {
		self.0.evaluate(query)
	}

	/// Partial evaluation over the low-indexed variables; delegates to the inner extension.
	fn evaluate_partial_low(
		&self,
		query: MultilinearQueryRef<P>,
	) -> Result<MultilinearExtension<P>, Error> {
		self.0.evaluate_partial_low(query)
	}

	/// Partial evaluation over the high-indexed variables; delegates to the inner extension.
	fn evaluate_partial_high(
		&self,
		query: MultilinearQueryRef<P>,
	) -> Result<MultilinearExtension<P>, Error> {
		self.0.evaluate_partial_high(query)
	}

	/// Partial evaluation over variables starting at `start_index`; delegates to the inner
	/// extension.
	fn evaluate_partial(
		&self,
		query: MultilinearQueryRef<P>,
		start_index: usize,
	) -> Result<MultilinearExtension<P>, Error> {
		self.0.evaluate_partial(query, start_index)
	}

	/// Computes the partial-low evaluations restricted to the subcube selected by
	/// `subcube_index`, writing `1 << subcube_vars` scalars into `partial_low_evals`.
	fn subcube_partial_low_evals(
		&self,
		query: MultilinearQueryRef<P>,
		subcube_vars: usize,
		subcube_index: usize,
		partial_low_evals: &mut [P],
	) -> Result<(), Error> {
		// TODO: think of a way to factor out duplicated implementation in direct & embedded adapters.
		validate_subcube_partial_evals_params(
			self.n_vars(),
			query,
			subcube_vars,
			subcube_index,
			partial_low_evals,
		)?;

		let query_n_vars = query.n_vars();
		// First stored-evaluation index covered by this subcube.
		let subcube_start = subcube_index << (query_n_vars + subcube_vars);

		// TODO: Maybe optimize me
		for scalar_index in 0..1 << subcube_vars {
			let evals_start = subcube_start + (scalar_index << query_n_vars);
			// Inner product of the query tensor expansion with one stride of evaluations.
			let mut inner_product = F::ZERO;
			for i in 0..1 << query_n_vars {
				inner_product += get_packed_slice(query.expansion(), i)
					* get_packed_slice(self.0.evals(), evals_start + i);
			}

			set_packed_slice(partial_low_evals, scalar_index, inner_product);
		}

		Ok(())
	}

	/// Computes the partial-high evaluations restricted to the subcube selected by
	/// `subcube_index`, writing `1 << subcube_vars` scalars into `partial_high_evals`.
	fn subcube_partial_high_evals(
		&self,
		query: MultilinearQueryRef<P>,
		subcube_vars: usize,
		subcube_index: usize,
		partial_high_evals: &mut [P],
	) -> Result<(), Error> {
		// TODO: think of a way to factor out duplicated implementation in direct & embedded adapters.
		validate_subcube_partial_evals_params(
			self.n_vars(),
			query,
			subcube_vars,
			subcube_index,
			partial_high_evals,
		)?;

		let query_n_vars = query.n_vars();

		// REVIEW: not spending effort to optimize this as the future of switchover
		//         is somewhat unclear in light of univariate skip
		partial_high_evals.fill(P::zero());

		// Accumulate, for each query coordinate, its tensor factor times the matching
		// stride of stored evaluations.
		for query_index in 0..1 << query_n_vars {
			let query_factor = get_packed_slice(query.expansion(), query_index);
			let subcube_start =
				subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
			for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
				*packed += P::from_fn(|inner_index| {
					let index = subcube_start | outer_index << P::LOG_WIDTH | inner_index;
					query_factor * get_packed_slice(self.0.evals(), index)
				});
			}
		}

		// When the subcube is smaller than a single packed element, zero the unused
		// tail lanes of the first (and only) packed output.
		if subcube_vars < P::LOG_WIDTH {
			for i in 1 << subcube_vars..P::WIDTH {
				partial_high_evals
					.first_mut()
					.expect("at least one")
					.set(i, P::Scalar::ZERO);
			}
		}

		Ok(())
	}

	/// Copies the evaluations of one subcube into `evals`. Since no embedding is possible
	/// here, `log_embedding_degree` must be 0.
	fn subcube_evals(
		&self,
		subcube_vars: usize,
		subcube_index: usize,
		log_embedding_degree: usize,
		evals: &mut [P],
	) -> Result<(), Error> {
		let n_vars = self.n_vars();
		if subcube_vars > n_vars {
			bail!(Error::ArgumentRangeError {
				arg: "subcube_vars".to_string(),
				range: 0..n_vars + 1,
			});
		}

		// The direct adapter stores evaluations in the target field already, so any
		// non-trivial embedding degree is rejected.
		if log_embedding_degree != 0 {
			bail!(Error::LogEmbeddingDegreeTooLarge {
				log_embedding_degree
			});
		}

		let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
		if evals.len() != correct_len {
			bail!(Error::ArgumentRangeError {
				arg: "evals.len()".to_string(),
				range: correct_len..correct_len + 1,
			});
		}

		let max_index = 1 << (n_vars - subcube_vars);
		if subcube_index >= max_index {
			bail!(Error::ArgumentRangeError {
				arg: "subcube_index".to_string(),
				range: 0..max_index,
			});
		}

		if subcube_vars < P::LOG_WIDTH {
			// Subcube is smaller than one packed element: copy scalar by scalar.
			let subcube_start = subcube_index << subcube_vars;
			for i in 0..1 << subcube_vars {
				// SAFETY: `subcube_index < max_index` was checked above, so
				// `subcube_start + i < 1 << n_vars`, within the stored evaluations.
				let scalar =
					unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };

				// SAFETY: `i < 1 << subcube_vars <= correct_len * P::WIDTH`
				// (here `correct_len == 1` since `subcube_vars < P::LOG_WIDTH`).
				unsafe {
					set_packed_slice_unchecked(evals, i, scalar);
				}
			}
		} else {
			// Subcube is packed-element aligned: a plain slice copy suffices.
			let range = subcube_index << (subcube_vars - P::LOG_WIDTH)
				..(subcube_index + 1) << (subcube_vars - P::LOG_WIDTH);
			evals.copy_from_slice(&self.0.evals()[range]);
		}

		Ok(())
	}

	/// Exposes the stored packed evaluations directly (no reinterpretation needed).
	fn packed_evals(&self) -> Option<&[P]> {
		Some(self.0.evals())
	}
}
553
554fn validate_subcube_partial_evals_params<P: PackedField>(
555	n_vars: usize,
556	query: MultilinearQueryRef<P>,
557	subcube_vars: usize,
558	subcube_index: usize,
559	partial_evals: &[P],
560) -> Result<(), Error> {
561	let query_n_vars = query.n_vars();
562	if query_n_vars + subcube_vars > n_vars {
563		bail!(Error::ArgumentRangeError {
564			arg: "query.n_vars() + subcube_vars".into(),
565			range: 0..n_vars,
566		});
567	}
568
569	let max_index = 1 << (n_vars - query_n_vars - subcube_vars);
570	if subcube_index >= max_index {
571		bail!(Error::ArgumentRangeError {
572			arg: "subcube_index".into(),
573			range: 0..max_index,
574		});
575	}
576
577	let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
578	if partial_evals.len() != correct_len {
579		bail!(Error::ArgumentRangeError {
580			arg: "partial_evals.len()".to_string(),
581			range: correct_len..correct_len + 1,
582		});
583	}
584
585	Ok(())
586}
587
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		arch::OptimalUnderlier256b, as_packed_field::PackedType, BinaryField128b, BinaryField16b,
		BinaryField32b, BinaryField8b, PackedBinaryField16x8b, PackedBinaryField1x128b,
		PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b,
		PackedBinaryField8x16b, PackedExtension, PackedField, PackedFieldIndexable,
	};
	use rand::prelude::*;

	use super::*;
	use crate::{tensor_prod_eq_ind, MultilinearQuery};

	type F = BinaryField16b;
	type P = PackedBinaryField8x16b;

	/// Builds the multilinear query whose expansion is the tensor product of `(1 - p_i, p_i)`.
	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	/// Subcube partial evals over the first subcube must agree with full partial evaluation.
	#[test]
	fn test_evaluate_subcube_and_evaluate_partial_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::from_values(
			repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
				.take(1 << 8)
				.collect(),
		)
		.unwrap()
		.specialize::<PackedBinaryField1x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(6)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let partial_low = poly.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high = poly.evaluate_partial_high(query.to_ref()).unwrap();

		let mut subcube_partial_low = vec![PackedBinaryField1x128b::zero(); 16];
		let mut subcube_partial_high = vec![PackedBinaryField1x128b::zero(); 16];
		poly.subcube_partial_low_evals(query.to_ref(), 4, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 4, 0, &mut subcube_partial_high)
			.unwrap();

		for (idx, subcube_partial_low) in PackedField::iter_slice(&subcube_partial_low).enumerate()
		{
			assert_eq!(subcube_partial_low, partial_low.evaluate_on_hypercube(idx).unwrap(),);
		}

		for (idx, subcube_partial_high) in
			PackedField::iter_slice(&subcube_partial_high).enumerate()
		{
			assert_eq!(subcube_partial_high, partial_high.evaluate_on_hypercube(idx).unwrap(),);
		}
	}

	/// When `subcube_vars` is below the packed width, trailing lanes must be zero-padded.
	#[test]
	fn test_evaluate_subcube_smaller_than_packed_width() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::new(
			2,
			vec![PackedBinaryField64x8b::from_scalars(
				[2, 2, 9, 9].map(BinaryField8b::new),
			)],
		)
		.unwrap()
		.specialize::<PackedBinaryField4x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(1)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let mut subcube_partial_low = vec![PackedBinaryField4x128b::zero(); 1];
		let mut subcube_partial_high = vec![PackedBinaryField4x128b::zero(); 1];
		poly.subcube_partial_low_evals(query.to_ref(), 1, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 1, 0, &mut subcube_partial_high)
			.unwrap();

		// Linear interpolation between the two high-half values at the query point.
		let expected_partial_high = BinaryField128b::new(2)
			+ (BinaryField128b::new(2) + BinaryField128b::new(9)) * q.first().unwrap();

		assert_eq!(get_packed_slice(&subcube_partial_low, 0), BinaryField128b::new(2));
		assert_eq!(get_packed_slice(&subcube_partial_low, 1), BinaryField128b::new(9));
		assert_eq!(get_packed_slice(&subcube_partial_low, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_low, 3), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 0), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 1), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 3), BinaryField128b::ZERO);
	}

	/// `subcube_evals` must place base-field scalars at the sparse positions determined by
	/// the embedding degree and zero-fill everything else.
	#[test]
	fn test_subcube_evals_embeds_correctly() {
		let mut rng = StdRng::seed_from_u64(0);

		type P = PackedBinaryField16x8b;
		type PE = PackedBinaryField1x128b;

		let packed_count = 4;
		let values: Vec<_> = repeat_with(|| P::random(&mut rng))
			.take(1 << packed_count)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let mles = MLEEmbeddingAdapter::<P, PE, _>::from(mle);

		let bytes_values = P::unpack_scalars(mles.0.evals());

		let n_vars = packed_count + P::LOG_WIDTH;
		let mut evals = vec![PE::zero(); 1 << n_vars];
		for subcube_vars in 0..n_vars {
			for subcube_index in 0..1 << (n_vars - subcube_vars) {
				for log_embedding_degree in 0..=4 {
					let evals_subcube = &mut evals
						[0..1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH)];

					mles.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals_subcube,
					)
					.unwrap();

					let bytes_evals = P::unpack_scalars(
						<PE as PackedExtension<BinaryField8b>>::cast_bases(evals_subcube),
					);

					// Each embedded scalar occupies one of `1 << shift` byte slots.
					let shift = 4 - log_embedding_degree;
					let skip_mask = (1 << shift) - 1;
					for (i, &b_evals) in bytes_evals.iter().enumerate() {
						let b_values = if i & skip_mask == 0 && i < 1 << (subcube_vars + shift) {
							bytes_values[(subcube_index << subcube_vars) + (i >> shift)]
						} else {
							BinaryField8b::ZERO
						};
						assert_eq!(b_evals, b_values);
					}
				}
			}
		}
	}

	/// Subcube partial evals must agree with full partial evaluation for EVERY subcube index.
	#[test]
	fn test_subcube_partial_and_evaluate_partial_conform() {
		let mut rng = StdRng::seed_from_u64(0);
		let n_vars = 12;
		let evals = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars - P::LOG_WIDTH))
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals).unwrap();
		let mles = MLEDirectAdapter::from(mle);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(6)
			.collect::<Vec<F>>();
		let query = multilinear_query(&q);
		let partial_low_eval = mles.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high_eval = mles.evaluate_partial_high(query.to_ref()).unwrap();

		let subcube_vars = 4;
		let mut subcube_partial_low_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		let mut subcube_partial_high_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		// The number of subcubes is `1 << remaining_vars`, not `remaining_vars` itself;
		// the previous loop bound `0..(n_vars - query.n_vars() - subcube_vars)` under-covered
		// the index space (and covered nothing when the difference was zero).
		for subcube_index in 0..1 << (n_vars - query.n_vars() - subcube_vars) {
			mles.subcube_partial_low_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_low_evals,
			)
			.unwrap();
			mles.subcube_partial_high_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_high_evals,
			)
			.unwrap();
			for hypercube_idx in 0..(1 << subcube_vars) {
				assert_eq!(
					get_packed_slice(&subcube_partial_low_evals, hypercube_idx),
					partial_low_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
				assert_eq!(
					get_packed_slice(&subcube_partial_high_evals, hypercube_idx),
					partial_high_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
			}
		}
	}

	/// `packed_evals` must expose the same data that `subcube_evals` over the whole cube yields.
	#[test]
	fn test_packed_evals_against_subcube_evals() {
		type U = OptimalUnderlier256b;
		type P = PackedType<U, BinaryField32b>;
		type PExt = PackedType<U, BinaryField128b>;

		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(2)
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly = MLEEmbeddingAdapter::from(mle);
		assert_eq!(
			<PExt as PackedExtension<BinaryField32b>>::cast_bases(poly.packed_evals().unwrap()),
			&evals
		);

		let mut evals_out = vec![PExt::zero(); 2];
		poly.subcube_evals(poly.n_vars(), 0, poly.log_extension_degree(), evals_out.as_mut_slice())
			.unwrap();
		assert_eq!(evals_out, poly.packed_evals().unwrap());
	}
}
813}