binius_math/
mle_adapters.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{fmt::Debug, marker::PhantomData, ops::Deref, sync::Arc};
4
5use binius_field::{
6	packed::{
7		get_packed_slice, get_packed_slice_unchecked, set_packed_slice, set_packed_slice_unchecked,
8	},
9	ExtensionField, Field, PackedField, RepackedExtension,
10};
11use binius_utils::bail;
12
13use super::{Error, MultilinearExtension, MultilinearPoly, MultilinearQueryRef};
14
15/// An adapter for [`MultilinearExtension`] that implements [`MultilinearPoly`] over a packed
16/// extension field.
17///
18/// This struct implements `MultilinearPoly` for an extension field of the base field that the
19/// multilinear extension is defined over.
20#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct MLEEmbeddingAdapter<P, PE, Data = Vec<P>>(
22	MultilinearExtension<P, Data>,
23	PhantomData<PE>,
24)
25where
26	P: PackedField,
27	PE: PackedField,
28	PE::Scalar: ExtensionField<P::Scalar>,
29	Data: Deref<Target = [P]>;
30
31impl<'a, P, PE, Data> MLEEmbeddingAdapter<P, PE, Data>
32where
33	P: PackedField,
34	PE: PackedField + RepackedExtension<P>,
35	PE::Scalar: ExtensionField<P::Scalar>,
36	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
37{
38	pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
39		Arc::new(self)
40	}
41}
42
43impl<P, PE, Data> From<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
44where
45	P: PackedField,
46	PE: PackedField,
47	PE::Scalar: ExtensionField<P::Scalar>,
48	Data: Deref<Target = [P]>,
49{
50	fn from(inner: MultilinearExtension<P, Data>) -> Self {
51		Self(inner, PhantomData)
52	}
53}
54
55impl<P, PE, Data> AsRef<MultilinearExtension<P, Data>> for MLEEmbeddingAdapter<P, PE, Data>
56where
57	P: PackedField,
58	PE: PackedField,
59	PE::Scalar: ExtensionField<P::Scalar>,
60	Data: Deref<Target = [P]>,
61{
62	fn as_ref(&self) -> &MultilinearExtension<P, Data> {
63		&self.0
64	}
65}
66
67impl<P, PE, Data> MultilinearPoly<PE> for MLEEmbeddingAdapter<P, PE, Data>
68where
69	P: PackedField + Debug,
70	PE: PackedField + RepackedExtension<P>,
71	PE::Scalar: ExtensionField<P::Scalar>,
72	Data: Deref<Target = [P]> + Send + Sync + Debug,
73{
74	fn n_vars(&self) -> usize {
75		self.0.n_vars()
76	}
77
78	fn log_extension_degree(&self) -> usize {
79		PE::Scalar::LOG_DEGREE
80	}
81
82	fn evaluate_on_hypercube(&self, index: usize) -> Result<PE::Scalar, Error> {
83		let eval = self.0.evaluate_on_hypercube(index)?;
84		Ok(eval.into())
85	}
86
87	fn evaluate_on_hypercube_and_scale(
88		&self,
89		index: usize,
90		scalar: PE::Scalar,
91	) -> Result<PE::Scalar, Error> {
92		let eval = self.0.evaluate_on_hypercube(index)?;
93		Ok(scalar * eval)
94	}
95
96	fn evaluate(&self, query: MultilinearQueryRef<PE>) -> Result<PE::Scalar, Error> {
97		self.0.evaluate(query)
98	}
99
100	fn evaluate_partial_low(
101		&self,
102		query: MultilinearQueryRef<PE>,
103	) -> Result<MultilinearExtension<PE>, Error> {
104		self.0.evaluate_partial_low(query)
105	}
106
107	fn evaluate_partial_high(
108		&self,
109		query: MultilinearQueryRef<PE>,
110	) -> Result<MultilinearExtension<PE>, Error> {
111		self.0.evaluate_partial_high(query)
112	}
113
114	fn subcube_partial_low_evals(
115		&self,
116		query: MultilinearQueryRef<PE>,
117		subcube_vars: usize,
118		subcube_index: usize,
119		partial_low_evals: &mut [PE],
120	) -> Result<(), Error> {
121		validate_subcube_partial_evals_params(
122			self.n_vars(),
123			query,
124			subcube_vars,
125			subcube_index,
126			partial_low_evals,
127		)?;
128
129		let query_n_vars = query.n_vars();
130		let subcube_start = subcube_index << (query_n_vars + subcube_vars);
131
132		// REVIEW: not spending effort to optimize this as the future of switchover
133		//         is somewhat unclear in light of univariate skip
134		for scalar_index in 0..1 << subcube_vars {
135			let evals_start = subcube_start + (scalar_index << query_n_vars);
136			let mut inner_product = PE::Scalar::ZERO;
137			for i in 0..1 << query_n_vars {
138				inner_product += get_packed_slice(query.expansion(), i)
139					* get_packed_slice(self.0.evals(), evals_start + i);
140			}
141
142			set_packed_slice(partial_low_evals, scalar_index, inner_product);
143		}
144
145		Ok(())
146	}
147
148	fn subcube_partial_high_evals(
149		&self,
150		query: MultilinearQueryRef<PE>,
151		subcube_vars: usize,
152		subcube_index: usize,
153		partial_high_evals: &mut [PE],
154	) -> Result<(), Error> {
155		validate_subcube_partial_evals_params(
156			self.n_vars(),
157			query,
158			subcube_vars,
159			subcube_index,
160			partial_high_evals,
161		)?;
162
163		let query_n_vars = query.n_vars();
164
165		// REVIEW: not spending effort to optimize this as the future of switchover
166		//         is somewhat unclear in light of univariate skip
167		partial_high_evals.fill(PE::zero());
168
169		for query_index in 0..1 << query_n_vars {
170			let query_factor = get_packed_slice(query.expansion(), query_index);
171			let subcube_start =
172				subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
173			for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
174				*packed += PE::from_fn(|inner_index| {
175					let index = subcube_start | outer_index << PE::LOG_WIDTH | inner_index;
176					query_factor * get_packed_slice(self.0.evals(), index)
177				});
178			}
179		}
180
181		if subcube_vars < PE::LOG_WIDTH {
182			for i in 1 << subcube_vars..PE::WIDTH {
183				partial_high_evals
184					.first_mut()
185					.expect("at least one")
186					.set(i, PE::Scalar::ZERO);
187			}
188		}
189
190		Ok(())
191	}
192
193	fn subcube_evals(
194		&self,
195		subcube_vars: usize,
196		subcube_index: usize,
197		log_embedding_degree: usize,
198		evals: &mut [PE],
199	) -> Result<(), Error> {
200		let log_extension_degree = PE::Scalar::LOG_DEGREE;
201
202		if subcube_vars > self.n_vars() {
203			bail!(Error::ArgumentRangeError {
204				arg: "subcube_vars".to_string(),
205				range: 0..self.n_vars() + 1,
206			});
207		}
208
209		// Check that chosen embedding subfield is large enough.
210		// We also use a stack allocated array of bases, which imposes
211		// a maximum tower height restriction.
212		const MAX_TOWER_HEIGHT: usize = 7;
213		if log_embedding_degree > log_extension_degree.min(MAX_TOWER_HEIGHT) {
214			bail!(Error::LogEmbeddingDegreeTooLarge {
215				log_embedding_degree
216			});
217		}
218
219		let correct_len = 1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH);
220		if evals.len() != correct_len {
221			bail!(Error::ArgumentRangeError {
222				arg: "evals.len()".to_string(),
223				range: correct_len..correct_len + 1,
224			});
225		}
226
227		let max_index = 1 << (self.n_vars() - subcube_vars);
228		if subcube_index >= max_index {
229			bail!(Error::ArgumentRangeError {
230				arg: "subcube_index".to_string(),
231				range: 0..max_index,
232			});
233		}
234
235		let subcube_start = subcube_index << subcube_vars;
236
237		if log_embedding_degree == 0 {
238			// One-to-one embedding can bypass the extension field construction overhead.
239			for i in 0..1 << subcube_vars {
240				// Safety: subcube_index < max_index check
241				let scalar =
242					unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
243
244				let extension_scalar = scalar.into();
245
246				// Safety: i < 1 << min(0, subcube_vars) <= correct_len * PE::WIDTH
247				unsafe {
248					set_packed_slice_unchecked(evals, i, extension_scalar);
249				}
250			}
251		} else {
252			// For many-to-one embedding, use ExtensionField::from_bases_sparse
253			let mut bases = [P::Scalar::default(); 1 << MAX_TOWER_HEIGHT];
254			let bases = &mut bases[0..1 << log_embedding_degree];
255
256			let bases_count = 1 << log_embedding_degree.min(subcube_vars);
257			for i in 0..1 << subcube_vars.saturating_sub(log_embedding_degree) {
258				for (j, base) in bases[..bases_count].iter_mut().enumerate() {
259					// Safety: i > 0 iff log_embedding_degree < subcube_vars and subcube_index < max_index check
260					*base = unsafe {
261						get_packed_slice_unchecked(
262							self.0.evals(),
263							subcube_start + (i << log_embedding_degree) + j,
264						)
265					};
266				}
267
268				let extension_scalar = PE::Scalar::from_bases_sparse(
269					bases,
270					log_extension_degree - log_embedding_degree,
271				)?;
272
273				// Safety: i < 1 << min(0, subcube_vars - log_embedding_degree) <= correct_len * PE::WIDTH
274				unsafe {
275					set_packed_slice_unchecked(evals, i, extension_scalar);
276				}
277			}
278		}
279
280		Ok(())
281	}
282
283	fn packed_evals(&self) -> Option<&[PE]> {
284		Some(PE::cast_exts(self.0.evals()))
285	}
286}
287
288impl<P, Data> MultilinearExtension<P, Data>
289where
290	P: PackedField,
291	Data: Deref<Target = [P]>,
292{
293	pub fn specialize<PE>(self) -> MLEEmbeddingAdapter<P, PE, Data>
294	where
295		PE: PackedField,
296		PE::Scalar: ExtensionField<P::Scalar>,
297	{
298		MLEEmbeddingAdapter::from(self)
299	}
300}
301
302impl<'a, P, Data> MultilinearExtension<P, Data>
303where
304	P: PackedField,
305	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
306{
307	pub fn specialize_arc_dyn<PE: RepackedExtension<P>>(
308		self,
309	) -> Arc<dyn MultilinearPoly<PE> + Send + Sync + 'a> {
310		self.specialize().upcast_arc_dyn()
311	}
312}
313
314/// An adapter for [`MultilinearExtension`] that implements [`MultilinearPoly`] over the same
315/// packed field that the [`MultilinearExtension`] stores evaluations in.
316#[derive(Debug, Clone, PartialEq, Eq)]
317pub struct MLEDirectAdapter<P, Data = Vec<P>>(MultilinearExtension<P, Data>)
318where
319	P: PackedField,
320	Data: Deref<Target = [P]>;
321
322impl<'a, P, Data> MLEDirectAdapter<P, Data>
323where
324	P: PackedField,
325	Data: Deref<Target = [P]> + Send + Sync + Debug + 'a,
326{
327	pub fn upcast_arc_dyn(self) -> Arc<dyn MultilinearPoly<P> + Send + Sync + 'a> {
328		Arc::new(self)
329	}
330}
331
332impl<P, Data> From<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
333where
334	P: PackedField,
335	Data: Deref<Target = [P]>,
336{
337	fn from(inner: MultilinearExtension<P, Data>) -> Self {
338		Self(inner)
339	}
340}
341
342impl<P, Data> AsRef<MultilinearExtension<P, Data>> for MLEDirectAdapter<P, Data>
343where
344	P: PackedField,
345	Data: Deref<Target = [P]>,
346{
347	fn as_ref(&self) -> &MultilinearExtension<P, Data> {
348		&self.0
349	}
350}
351
352impl<F, P, Data> MultilinearPoly<P> for MLEDirectAdapter<P, Data>
353where
354	F: Field,
355	P: PackedField<Scalar = F>,
356	Data: Deref<Target = [P]> + Send + Sync + Debug,
357{
358	#[inline]
359	fn n_vars(&self) -> usize {
360		self.0.n_vars()
361	}
362
363	#[inline]
364	fn log_extension_degree(&self) -> usize {
365		0
366	}
367
368	fn evaluate_on_hypercube(&self, index: usize) -> Result<F, Error> {
369		self.0.evaluate_on_hypercube(index)
370	}
371
372	fn evaluate_on_hypercube_and_scale(&self, index: usize, scalar: F) -> Result<F, Error> {
373		let eval = self.0.evaluate_on_hypercube(index)?;
374		Ok(scalar * eval)
375	}
376
377	fn evaluate(&self, query: MultilinearQueryRef<P>) -> Result<F, Error> {
378		self.0.evaluate(query)
379	}
380
381	fn evaluate_partial_low(
382		&self,
383		query: MultilinearQueryRef<P>,
384	) -> Result<MultilinearExtension<P>, Error> {
385		self.0.evaluate_partial_low(query)
386	}
387
388	fn evaluate_partial_high(
389		&self,
390		query: MultilinearQueryRef<P>,
391	) -> Result<MultilinearExtension<P>, Error> {
392		self.0.evaluate_partial_high(query)
393	}
394
395	fn subcube_partial_low_evals(
396		&self,
397		query: MultilinearQueryRef<P>,
398		subcube_vars: usize,
399		subcube_index: usize,
400		partial_low_evals: &mut [P],
401	) -> Result<(), Error> {
402		// TODO: think of a way to factor out duplicated implementation in direct & embedded adapters.
403		validate_subcube_partial_evals_params(
404			self.n_vars(),
405			query,
406			subcube_vars,
407			subcube_index,
408			partial_low_evals,
409		)?;
410
411		let query_n_vars = query.n_vars();
412		let subcube_start = subcube_index << (query_n_vars + subcube_vars);
413
414		// TODO: Maybe optimize me
415		for scalar_index in 0..1 << subcube_vars {
416			let evals_start = subcube_start + (scalar_index << query_n_vars);
417			let mut inner_product = F::ZERO;
418			for i in 0..1 << query_n_vars {
419				inner_product += get_packed_slice(query.expansion(), i)
420					* get_packed_slice(self.0.evals(), evals_start + i);
421			}
422
423			set_packed_slice(partial_low_evals, scalar_index, inner_product);
424		}
425
426		Ok(())
427	}
428
429	fn subcube_partial_high_evals(
430		&self,
431		query: MultilinearQueryRef<P>,
432		subcube_vars: usize,
433		subcube_index: usize,
434		partial_high_evals: &mut [P],
435	) -> Result<(), Error> {
436		// TODO: think of a way to factor out duplicated implementation in direct & embedded adapters.
437		validate_subcube_partial_evals_params(
438			self.n_vars(),
439			query,
440			subcube_vars,
441			subcube_index,
442			partial_high_evals,
443		)?;
444
445		let query_n_vars = query.n_vars();
446
447		// REVIEW: not spending effort to optimize this as the future of switchover
448		//         is somewhat unclear in light of univariate skip
449		partial_high_evals.fill(P::zero());
450
451		for query_index in 0..1 << query_n_vars {
452			let query_factor = get_packed_slice(query.expansion(), query_index);
453			let subcube_start =
454				subcube_index << subcube_vars | query_index << (self.n_vars() - query_n_vars);
455			for (outer_index, packed) in partial_high_evals.iter_mut().enumerate() {
456				*packed += P::from_fn(|inner_index| {
457					let index = subcube_start | outer_index << P::LOG_WIDTH | inner_index;
458					query_factor * get_packed_slice(self.0.evals(), index)
459				});
460			}
461		}
462
463		if subcube_vars < P::LOG_WIDTH {
464			for i in 1 << subcube_vars..P::WIDTH {
465				partial_high_evals
466					.first_mut()
467					.expect("at least one")
468					.set(i, P::Scalar::ZERO);
469			}
470		}
471
472		Ok(())
473	}
474
475	fn subcube_evals(
476		&self,
477		subcube_vars: usize,
478		subcube_index: usize,
479		log_embedding_degree: usize,
480		evals: &mut [P],
481	) -> Result<(), Error> {
482		let n_vars = self.n_vars();
483		if subcube_vars > n_vars {
484			bail!(Error::ArgumentRangeError {
485				arg: "subcube_vars".to_string(),
486				range: 0..n_vars + 1,
487			});
488		}
489
490		if log_embedding_degree != 0 {
491			bail!(Error::LogEmbeddingDegreeTooLarge {
492				log_embedding_degree
493			});
494		}
495
496		let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
497		if evals.len() != correct_len {
498			bail!(Error::ArgumentRangeError {
499				arg: "evals.len()".to_string(),
500				range: correct_len..correct_len + 1,
501			});
502		}
503
504		let max_index = 1 << (n_vars - subcube_vars);
505		if subcube_index >= max_index {
506			bail!(Error::ArgumentRangeError {
507				arg: "subcube_index".to_string(),
508				range: 0..max_index,
509			});
510		}
511
512		if subcube_vars < P::LOG_WIDTH {
513			let subcube_start = subcube_index << subcube_vars;
514			for i in 0..1 << subcube_vars {
515				// Safety: subcube_index < max_index check
516				let scalar =
517					unsafe { get_packed_slice_unchecked(self.0.evals(), subcube_start + i) };
518
519				// Safety: i < 1 << min(0, subcube_vars) <= correct_len * P::WIDTH
520				unsafe {
521					set_packed_slice_unchecked(evals, i, scalar);
522				}
523			}
524		} else {
525			let range = subcube_index << (subcube_vars - P::LOG_WIDTH)
526				..(subcube_index + 1) << (subcube_vars - P::LOG_WIDTH);
527			evals.copy_from_slice(&self.0.evals()[range]);
528		}
529
530		Ok(())
531	}
532
533	fn packed_evals(&self) -> Option<&[P]> {
534		Some(self.0.evals())
535	}
536}
537
538fn validate_subcube_partial_evals_params<P: PackedField>(
539	n_vars: usize,
540	query: MultilinearQueryRef<P>,
541	subcube_vars: usize,
542	subcube_index: usize,
543	partial_evals: &[P],
544) -> Result<(), Error> {
545	let query_n_vars = query.n_vars();
546	if query_n_vars + subcube_vars > n_vars {
547		bail!(Error::ArgumentRangeError {
548			arg: "query.n_vars() + subcube_vars".into(),
549			range: 0..n_vars,
550		});
551	}
552
553	let max_index = 1 << (n_vars - query_n_vars - subcube_vars);
554	if subcube_index >= max_index {
555		bail!(Error::ArgumentRangeError {
556			arg: "subcube_index".into(),
557			range: 0..max_index,
558		});
559	}
560
561	let correct_len = 1 << subcube_vars.saturating_sub(P::LOG_WIDTH);
562	if partial_evals.len() != correct_len {
563		bail!(Error::ArgumentRangeError {
564			arg: "partial_evals.len()".to_string(),
565			range: correct_len..correct_len + 1,
566		});
567	}
568
569	Ok(())
570}
571
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{
		arch::OptimalUnderlier256b, as_packed_field::PackedType, BinaryField128b, BinaryField16b,
		BinaryField32b, BinaryField8b, PackedBinaryField16x8b, PackedBinaryField1x128b,
		PackedBinaryField4x128b, PackedBinaryField4x32b, PackedBinaryField64x8b,
		PackedBinaryField8x16b, PackedExtension, PackedField, PackedFieldIndexable,
	};
	use rand::prelude::*;

	use super::*;
	use crate::{tensor_prod_eq_ind, MultilinearQuery};

	type F = BinaryField16b;
	type P = PackedBinaryField8x16b;

	fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
		let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
		result[0] = P::set_single(P::Scalar::ONE);
		tensor_prod_eq_ind(0, &mut result, p).unwrap();
		MultilinearQuery::with_expansion(p.len(), result).unwrap()
	}

	#[test]
	fn test_evaluate_subcube_and_evaluate_partial_consistent() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::from_values(
			repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
				.take(1 << 8)
				.collect(),
		)
		.unwrap()
		.specialize::<PackedBinaryField1x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(6)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let partial_low = poly.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high = poly.evaluate_partial_high(query.to_ref()).unwrap();

		let mut subcube_partial_low = vec![PackedBinaryField1x128b::zero(); 16];
		let mut subcube_partial_high = vec![PackedBinaryField1x128b::zero(); 16];
		poly.subcube_partial_low_evals(query.to_ref(), 4, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 4, 0, &mut subcube_partial_high)
			.unwrap();

		for (idx, subcube_partial_low) in PackedField::iter_slice(&subcube_partial_low).enumerate()
		{
			assert_eq!(subcube_partial_low, partial_low.evaluate_on_hypercube(idx).unwrap(),);
		}

		for (idx, subcube_partial_high) in
			PackedField::iter_slice(&subcube_partial_high).enumerate()
		{
			assert_eq!(subcube_partial_high, partial_high.evaluate_on_hypercube(idx).unwrap(),);
		}
	}

	#[test]
	fn test_evaluate_subcube_smaller_than_packed_width() {
		let mut rng = StdRng::seed_from_u64(0);
		let poly = MultilinearExtension::new(
			2,
			vec![PackedBinaryField64x8b::from_scalars(
				[2, 2, 9, 9].map(BinaryField8b::new),
			)],
		)
		.unwrap()
		.specialize::<PackedBinaryField4x128b>();

		let q = repeat_with(|| <BinaryField128b as PackedField>::random(&mut rng))
			.take(1)
			.collect::<Vec<_>>();
		let query = multilinear_query(&q);

		let mut subcube_partial_low = vec![PackedBinaryField4x128b::zero(); 1];
		let mut subcube_partial_high = vec![PackedBinaryField4x128b::zero(); 1];
		poly.subcube_partial_low_evals(query.to_ref(), 1, 0, &mut subcube_partial_low)
			.unwrap();
		poly.subcube_partial_high_evals(query.to_ref(), 1, 0, &mut subcube_partial_high)
			.unwrap();

		let expected_partial_high = BinaryField128b::new(2)
			+ (BinaryField128b::new(2) + BinaryField128b::new(9)) * q.first().unwrap();

		assert_eq!(get_packed_slice(&subcube_partial_low, 0), BinaryField128b::new(2));
		assert_eq!(get_packed_slice(&subcube_partial_low, 1), BinaryField128b::new(9));
		assert_eq!(get_packed_slice(&subcube_partial_low, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_low, 3), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 0), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 1), expected_partial_high);
		assert_eq!(get_packed_slice(&subcube_partial_high, 2), BinaryField128b::ZERO);
		assert_eq!(get_packed_slice(&subcube_partial_high, 3), BinaryField128b::ZERO);
	}

	#[test]
	fn test_subcube_evals_embeds_correctly() {
		let mut rng = StdRng::seed_from_u64(0);

		type P = PackedBinaryField16x8b;
		type PE = PackedBinaryField1x128b;

		let packed_count = 4;
		let values: Vec<_> = repeat_with(|| P::random(&mut rng))
			.take(1 << packed_count)
			.collect();

		let mle = MultilinearExtension::from_values(values).unwrap();
		let mles = MLEEmbeddingAdapter::<P, PE, _>::from(mle);

		let bytes_values = P::unpack_scalars(mles.0.evals());

		let n_vars = packed_count + P::LOG_WIDTH;
		let mut evals = vec![PE::zero(); 1 << n_vars];
		for subcube_vars in 0..n_vars {
			for subcube_index in 0..1 << (n_vars - subcube_vars) {
				for log_embedding_degree in 0..=4 {
					let evals_subcube = &mut evals
						[0..1 << subcube_vars.saturating_sub(log_embedding_degree + PE::LOG_WIDTH)];

					mles.subcube_evals(
						subcube_vars,
						subcube_index,
						log_embedding_degree,
						evals_subcube,
					)
					.unwrap();

					let bytes_evals = P::unpack_scalars(
						<PE as PackedExtension<BinaryField8b>>::cast_bases(evals_subcube),
					);

					let shift = 4 - log_embedding_degree;
					let skip_mask = (1 << shift) - 1;
					for (i, &b_evals) in bytes_evals.iter().enumerate() {
						let b_values = if i & skip_mask == 0 && i < 1 << (subcube_vars + shift) {
							bytes_values[(subcube_index << subcube_vars) + (i >> shift)]
						} else {
							BinaryField8b::ZERO
						};
						assert_eq!(b_evals, b_values);
					}
				}
			}
		}
	}

	#[test]
	fn test_subcube_partial_and_evaluate_partial_conform() {
		let mut rng = StdRng::seed_from_u64(0);
		let n_vars = 12;
		let evals = repeat_with(|| P::random(&mut rng))
			.take(1 << (n_vars - P::LOG_WIDTH))
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals).unwrap();
		let mles = MLEDirectAdapter::from(mle);
		let q = repeat_with(|| Field::random(&mut rng))
			.take(6)
			.collect::<Vec<F>>();
		let query = multilinear_query(&q);
		let partial_low_eval = mles.evaluate_partial_low(query.to_ref()).unwrap();
		let partial_high_eval = mles.evaluate_partial_high(query.to_ref()).unwrap();

		let subcube_vars = 4;
		let mut subcube_partial_low_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		let mut subcube_partial_high_evals = vec![P::default(); 1 << (subcube_vars - P::LOG_WIDTH)];
		// There are 2^(n_vars - query_n_vars - subcube_vars) subcubes; check every one of them.
		for subcube_index in 0..1 << (n_vars - query.n_vars() - subcube_vars) {
			mles.subcube_partial_low_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_low_evals,
			)
			.unwrap();
			mles.subcube_partial_high_evals(
				query.to_ref(),
				subcube_vars,
				subcube_index,
				&mut subcube_partial_high_evals,
			)
			.unwrap();
			for hypercube_idx in 0..(1 << subcube_vars) {
				assert_eq!(
					get_packed_slice(&subcube_partial_low_evals, hypercube_idx),
					partial_low_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
				assert_eq!(
					get_packed_slice(&subcube_partial_high_evals, hypercube_idx),
					partial_high_eval
						.evaluate_on_hypercube(hypercube_idx + (subcube_index << subcube_vars))
						.unwrap()
				);
			}
		}
	}

	#[test]
	fn test_packed_evals_against_subcube_evals() {
		type U = OptimalUnderlier256b;
		type P = PackedType<U, BinaryField32b>;
		type PExt = PackedType<U, BinaryField128b>;

		let mut rng = StdRng::seed_from_u64(0);
		let evals = repeat_with(|| P::random(&mut rng))
			.take(2)
			.collect::<Vec<_>>();
		let mle = MultilinearExtension::from_values(evals.clone()).unwrap();
		let poly = MLEEmbeddingAdapter::from(mle);
		assert_eq!(
			<PExt as PackedExtension<BinaryField32b>>::cast_bases(poly.packed_evals().unwrap()),
			&evals
		);

		let mut evals_out = vec![PExt::zero(); 2];
		poly.subcube_evals(poly.n_vars(), 0, poly.log_extension_degree(), evals_out.as_mut_slice())
			.unwrap();
		assert_eq!(evals_out, poly.packed_evals().unwrap());
	}
}
795		assert_eq!(evals_out, poly.packed_evals().unwrap());
796	}
797}