1use std::{fmt::Debug, ops::Deref};
4
5use binius_field::{
6 as_packed_field::{AsSinglePacked, PackScalar, PackedType},
7 underlier::UnderlierType,
8 util::inner_product_par,
9 ExtensionField, Field, PackedField,
10};
11use binius_utils::bail;
12use bytemuck::zeroed_vec;
13use tracing::instrument;
14
15use crate::{fold::fold_left, fold_right, Error, MultilinearQueryRef, PackingDeref};
16
17#[derive(Debug, Clone, PartialEq, Eq)]
24pub struct MultilinearExtension<P: PackedField, Data: Deref<Target = [P]> = Vec<P>> {
25 mu: usize,
27 evals: Data,
29}
30
31impl<P: PackedField> MultilinearExtension<P> {
32 pub fn zeros(n_vars: usize) -> Result<Self, Error> {
33 if n_vars < P::LOG_WIDTH {
34 bail!(Error::ArgumentRangeError {
35 arg: "n_vars".into(),
36 range: P::LOG_WIDTH..32,
37 });
38 }
39 Ok(Self {
40 mu: n_vars,
41 evals: vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)],
42 })
43 }
44
45 pub fn from_values(v: Vec<P>) -> Result<Self, Error> {
46 Self::from_values_generic(v)
47 }
48
49 pub fn into_evals(self) -> Vec<P> {
50 self.evals
51 }
52}
53
54impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
55 pub fn from_values_generic(v: Data) -> Result<Self, Error> {
56 if !v.len().is_power_of_two() {
57 bail!(Error::PowerOfTwoLengthRequired);
58 }
59 let mu = log2(v.len()) + P::LOG_WIDTH;
60
61 Ok(Self { mu, evals: v })
62 }
63
64 pub fn new(n_vars: usize, v: Data) -> Result<Self, Error> {
65 if !v.len().is_power_of_two() {
66 bail!(Error::PowerOfTwoLengthRequired);
67 }
68
69 if n_vars < P::LOG_WIDTH {
70 if v.len() != 1 {
71 bail!(Error::IncorrectNumberOfVariables {
72 expected: n_vars,
73 actual: P::LOG_WIDTH + log2(v.len())
74 });
75 }
76 } else if P::LOG_WIDTH + log2(v.len()) != n_vars {
77 bail!(Error::IncorrectNumberOfVariables {
78 expected: n_vars,
79 actual: P::LOG_WIDTH + log2(v.len())
80 });
81 }
82
83 Ok(Self {
84 mu: n_vars,
85 evals: v,
86 })
87 }
88}
89
90impl<U, F, Data> MultilinearExtension<PackedType<U, F>, PackingDeref<U, F, Data>>
91where
92 U: UnderlierType + PackScalar<F>,
94 F: Field,
95 Data: Deref<Target = [U]>,
96{
97 pub fn from_underliers(v: Data) -> Result<Self, Error> {
98 Self::from_values_generic(PackingDeref::new(v))
99 }
100}
101
102impl<'a, P: PackedField> MultilinearExtension<P, &'a [P]> {
103 pub fn from_values_slice(v: &'a [P]) -> Result<Self, Error> {
104 if !v.len().is_power_of_two() {
105 bail!(Error::PowerOfTwoLengthRequired);
106 }
107 let mu = log2(v.len() * P::WIDTH);
108 Ok(Self { mu, evals: v })
109 }
110}
111
112impl<P: PackedField, Data: Deref<Target = [P]>> MultilinearExtension<P, Data> {
113 pub const fn n_vars(&self) -> usize {
114 self.mu
115 }
116
117 pub const fn size(&self) -> usize {
118 1 << self.mu
119 }
120
121 pub fn evals(&self) -> &[P] {
122 &self.evals
123 }
124
125 pub fn to_ref(&self) -> MultilinearExtension<P, &[P]> {
126 MultilinearExtension {
127 mu: self.mu,
128 evals: self.evals(),
129 }
130 }
131
132 pub fn packed_evaluate_on_hypercube(&self, index: usize) -> Result<P, Error> {
139 self.evals()
140 .get(index)
141 .ok_or(Error::HypercubeIndexOutOfRange { index })
142 .copied()
143 }
144
145 pub fn evaluate_on_hypercube(&self, index: usize) -> Result<P::Scalar, Error> {
146 if self.size() <= index {
147 bail!(Error::HypercubeIndexOutOfRange { index })
148 }
149
150 let subcube_eval = self.packed_evaluate_on_hypercube(index / P::WIDTH)?;
151 Ok(subcube_eval.get(index % P::WIDTH))
152 }
153}
154
155impl<P, Data> MultilinearExtension<P, Data>
156where
157 P: PackedField,
158 Data: Deref<Target = [P]> + Send + Sync,
159{
160 pub fn evaluate<'a, FE, PE>(
161 &self,
162 query: impl Into<MultilinearQueryRef<'a, PE>>,
163 ) -> Result<FE, Error>
164 where
165 FE: ExtensionField<P::Scalar>,
166 PE: PackedField<Scalar = FE>,
167 {
168 let query = query.into();
169 if self.mu != query.n_vars() {
170 bail!(Error::IncorrectQuerySize { expected: self.mu });
171 }
172
173 if self.mu < P::LOG_WIDTH || query.n_vars() < PE::LOG_WIDTH {
174 let evals = PackedField::iter_slice(self.evals())
175 .take(self.size())
176 .collect::<Vec<P::Scalar>>();
177 let querys = PackedField::iter_slice(query.expansion())
178 .take(1 << query.n_vars())
179 .collect::<Vec<PE::Scalar>>();
180 Ok(inner_product_par(&querys, &evals))
181 } else {
182 Ok(inner_product_par(query.expansion(), &self.evals))
183 }
184 }
185
186 #[instrument(
196 "MultilinearExtension::evaluate_partial_high",
197 skip_all,
198 level = "debug"
199 )]
200 pub fn evaluate_partial_high<'a, PE>(
201 &self,
202 query: impl Into<MultilinearQueryRef<'a, PE>>,
203 ) -> Result<MultilinearExtension<PE>, Error>
204 where
205 PE: PackedField,
206 PE::Scalar: ExtensionField<P::Scalar>,
207 {
208 let query = query.into();
209
210 let new_n_vars = self.mu.saturating_sub(query.n_vars());
211 let result_evals_len = 1 << (new_n_vars.saturating_sub(PE::LOG_WIDTH));
212 let mut result_evals = Vec::with_capacity(result_evals_len);
213
214 fold_left(
215 self.evals(),
216 self.mu,
217 query.expansion(),
218 query.n_vars(),
219 result_evals.spare_capacity_mut(),
220 )?;
221 unsafe {
222 result_evals.set_len(result_evals_len);
223 }
224
225 MultilinearExtension::new(new_n_vars, result_evals)
226 }
227
228 #[instrument(
238 "MultilinearExtension::evaluate_partial_low",
239 skip_all,
240 level = "trace"
241 )]
242 pub fn evaluate_partial_low<'a, PE>(
243 &self,
244 query: impl Into<MultilinearQueryRef<'a, PE>>,
245 ) -> Result<MultilinearExtension<PE>, Error>
246 where
247 PE: PackedField,
248 PE::Scalar: ExtensionField<P::Scalar>,
249 {
250 let query = query.into();
251
252 if self.mu < query.n_vars() {
253 bail!(Error::IncorrectQuerySize { expected: self.mu });
254 }
255
256 let new_n_vars = self.mu - query.n_vars();
257
258 let mut result =
259 zeroed_vec(1 << ((self.mu - query.n_vars()).saturating_sub(PE::LOG_WIDTH)));
260 self.evaluate_partial_low_into(query, &mut result)?;
261 MultilinearExtension::new(new_n_vars, result)
262 }
263
264 pub fn evaluate_partial_low_into<PE>(
274 &self,
275 query: MultilinearQueryRef<PE>,
276 out: &mut [PE],
277 ) -> Result<(), Error>
278 where
279 PE: PackedField,
280 PE::Scalar: ExtensionField<P::Scalar>,
281 {
282 fold_right(&self.evals, self.mu, query.expansion(), query.n_vars(), out)
285 }
286}
287
288impl<F: Field + AsSinglePacked, Data: Deref<Target = [F]>> MultilinearExtension<F, Data> {
289 pub fn to_single_packed(self) -> MultilinearExtension<F::Packed> {
291 let packed_evals = self
292 .evals
293 .iter()
294 .map(|eval| eval.to_single_packed())
295 .collect();
296 MultilinearExtension {
297 mu: self.mu,
298 evals: packed_evals,
299 }
300 }
301}
302
/// Floor of the base-2 logarithm of `v`.
///
/// Every caller in this module first checks `is_power_of_two`, so `v > 0` and the result is
/// exact. `v == 0` is unsupported: the subtraction underflows (a panic in debug builds).
const fn log2(v: usize) -> usize {
    let top_bit_index = u64::BITS as usize - 1;
    top_bit_index - (v as u64).leading_zeros() as usize
}
306
/// A [`MultilinearExtension`] whose evaluations are borrowed as a `&'a [P]` slice rather than
/// owned; the type returned by `MultilinearExtension::to_ref`.
pub type MultilinearExtensionBorrowed<'a, P> = MultilinearExtension<P, &'a [P]>;
309
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{
        arch::OptimalUnderlier256b, BinaryField128b, BinaryField16b as F, BinaryField32b,
        BinaryField8b, PackedBinaryField16x8b, PackedBinaryField4x32b, PackedBinaryField8x16b as P,
    };
    use itertools::Itertools;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;
    use crate::{tensor_prod_eq_ind, MultilinearQuery};

    /// Reference (unoptimized) tensor expansion of `query`: entry `i` is `eval_basis(query, i)`
    /// for every `i` in `0..2^query.len()`. Always returns `Ok`.
    fn expand_query_naive<F: Field>(query: &[F]) -> Result<Vec<F>, Error> {
        let result = (0..1 << query.len())
            .map(|i| eval_basis(query, i))
            .collect();
        Ok(result)
    }

    /// Product over `j` of `query[j]` if bit `j` of `i` is set, else `1 - query[j]`.
    /// Bit 0 of `i` corresponds to `query[0]`.
    fn eval_basis<F: Field>(query: &[F], i: usize) -> F {
        query
            .iter()
            .enumerate()
            .map(|(j, &v)| if i & (1 << j) == 0 { F::ONE - v } else { v })
            .product()
    }

    /// Builds a `MultilinearQuery` for the point `p` by seeding a packed buffer with a single
    /// `ONE` and expanding it in place with `tensor_prod_eq_ind`.
    fn multilinear_query<P: PackedField>(p: &[P::Scalar]) -> MultilinearQuery<P, Vec<P>> {
        let mut result = vec![P::default(); 1 << p.len().saturating_sub(P::LOG_WIDTH)];
        result[0] = P::set_single(P::Scalar::ONE);
        tensor_prod_eq_ind(0, &mut result, p).unwrap();
        MultilinearQuery::with_expansion(p.len(), result).unwrap()
    }

    // The packed expansion from `multilinear_query` must match the naive scalar expansion.
    #[test]
    fn test_expand_query_impls_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<F>>();
        let result1 = multilinear_query::<P>(&q);
        let result2 = expand_query_naive(&q).unwrap();
        assert_eq!(PackedField::iter_slice(result1.expansion()).collect_vec(), result2);
    }

    // `from_values` (variable count inferred) and `new` (explicit count) must agree.
    #[test]
    fn test_new_from_values_correspondence() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| Field::random(&mut rng))
            .take(256)
            .collect::<Vec<F>>();
        let poly1 = MultilinearExtension::from_values(evals.clone()).unwrap();
        let poly2 = MultilinearExtension::new(8, evals).unwrap();

        assert_eq!(poly1, poly2)
    }

    // Evaluating at a boolean point (bit j of i as coordinate j) must return the stored
    // hypercube value at index i.
    #[test]
    fn test_evaluate_on_hypercube() {
        let mut values = vec![F::ZERO; 64];
        values
            .iter_mut()
            .enumerate()
            .for_each(|(i, val)| *val = F::new(i as u16));

        let poly = MultilinearExtension::from_values(values).unwrap();
        for i in 0..64 {
            let q = (0..6)
                .map(|j| if (i >> j) & 1 != 0 { F::ONE } else { F::ZERO })
                .collect::<Vec<_>>();
            let multilin_query = multilinear_query::<P>(&q);
            let result = poly.evaluate(multilin_query.to_ref()).unwrap();
            assert_eq!(result, F::new(i));
        }
    }

    /// Evaluates `poly` at `q` in stages: for each entry of `splits` except the last, fixes
    /// that many of the highest remaining coordinates with `evaluate_partial_high`, then
    /// evaluates the leftover low coordinates with `evaluate`. `splits` must sum to `n_vars`.
    fn evaluate_split<P>(
        poly: MultilinearExtension<P>,
        q: &[P::Scalar],
        splits: &[usize],
    ) -> P::Scalar
    where
        P: PackedField + 'static,
    {
        assert_eq!(splits.iter().sum::<usize>(), poly.n_vars());

        let mut partial_result = poly;
        let mut index = q.len();
        for split_vars in &splits[0..splits.len() - 1] {
            let split_vars = *split_vars;
            let query = multilinear_query(&q[index - split_vars..index]);
            partial_result = partial_result
                .evaluate_partial_high(query.to_ref())
                .unwrap();
            index -= split_vars;
        }
        let multilin_query = multilinear_query::<P>(&q[..index]);
        partial_result.evaluate(multilin_query.to_ref()).unwrap()
    }

    // Staged high-variable partial evaluation must equal a single full evaluation.
    #[test]
    fn test_evaluate_split_is_correct() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| Field::random(&mut rng))
            .take(256)
            .collect::<Vec<F>>();
        let poly = MultilinearExtension::from_values(evals).unwrap();
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<F>>();
        let multilin_query = multilinear_query::<P>(&q);
        let result1 = poly.evaluate(multilin_query.to_ref()).unwrap();
        let result2 = evaluate_split(poly, &q, &[2, 3, 3]);
        assert_eq!(result1, result2);
    }

    // Partial-high evaluation into a scalar extension field, leaving fewer variables than the
    // input's packing width, must still compose with a final `evaluate`.
    #[test]
    fn test_evaluate_partial_high_packed() {
        let mut rng = StdRng::seed_from_u64(0);
        let evals = repeat_with(|| P::random(&mut rng))
            .take(256 >> P::LOG_WIDTH)
            .collect::<Vec<_>>();
        let poly = MultilinearExtension::from_values(evals).unwrap();
        let q = repeat_with(|| Field::random(&mut rng))
            .take(8)
            .collect::<Vec<BinaryField128b>>();
        let multilin_query = multilinear_query::<BinaryField128b>(&q);

        let expected = poly.evaluate(multilin_query.to_ref()).unwrap();

        let query_hi = multilinear_query::<BinaryField128b>(&q[1..]);
        let partial_eval = poly.evaluate_partial_high(query_hi.to_ref()).unwrap();
        assert!(partial_eval.n_vars() < P::LOG_WIDTH);

        let query_lo = multilinear_query::<BinaryField128b>(&q[..1]);
        let eval = partial_eval.evaluate(query_lo.to_ref()).unwrap();
        assert_eq!(eval, expected);
    }

    // With a polynomial smaller than one packed element, partial-high followed by partial-low
    // must match a full evaluation.
    #[test]
    fn test_evaluate_partial_low_high_smaller_than_packed_width() {
        type P = PackedBinaryField16x8b;

        type F = BinaryField8b;

        let n_vars = 3;

        let mut rng = StdRng::seed_from_u64(0);

        let values = repeat_with(|| Field::random(&mut rng))
            .take(1 << n_vars)
            .collect::<Vec<F>>();

        let q = repeat_with(|| Field::random(&mut rng))
            .take(n_vars)
            .collect::<Vec<F>>();

        let query = multilinear_query::<P>(&q);

        let packed = P::from_scalars(values);
        let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

        let eval = me.evaluate(&query).unwrap();

        let query_low = multilinear_query::<P>(&q[..n_vars - 1]);
        let query_high = multilinear_query::<P>(&q[n_vars - 1..]);

        let eval_l_h = me
            .evaluate_partial_high(&query_high)
            .unwrap()
            .evaluate_partial_low(&query_low)
            .unwrap()
            .evals()[0]
            .get(0);

        assert_eq!(eval, eval_l_h);
    }

    // Hypercube lookups on a sub-width polynomial return the stored scalar and reject
    // indices at or beyond `2^n_vars`.
    #[test]
    fn test_evaluate_on_hypercube_small_than_packed_width() {
        type P = PackedBinaryField16x8b;

        type F = BinaryField8b;

        let n_vars = 3;

        let mut rng = StdRng::seed_from_u64(0);

        let values = repeat_with(|| Field::random(&mut rng))
            .take(1 << n_vars)
            .collect::<Vec<F>>();

        let packed = P::from_scalars(values.clone());

        let me = MultilinearExtension::new(n_vars, vec![packed]).unwrap();

        assert_eq!(me.evaluate_on_hypercube(1).unwrap(), values[1]);

        assert!(me.evaluate_on_hypercube(1 << n_vars).is_err());
    }

    // A full-size query passed to either partial evaluation must reproduce `evaluate`.
    #[test]
    fn test_evaluate_partial_high_low_evaluate_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
            .take(1 << 8)
            .collect();

        let me = MultilinearExtension::from_values(values).unwrap();

        let q = repeat_with(|| <BinaryField32b as PackedField>::random(&mut rng))
            .take(me.n_vars())
            .collect::<Vec<_>>();

        let query = multilinear_query(&q);

        let eval = me
            .evaluate::<<PackedBinaryField4x32b as PackedField>::Scalar, PackedBinaryField4x32b>(
                query.to_ref(),
            )
            .unwrap();

        assert_eq!(
            me.evaluate_partial_low::<PackedBinaryField4x32b>(query.to_ref())
                .unwrap()
                .evals[0]
                .get(0),
            eval
        );
        assert_eq!(
            me.evaluate_partial_high::<PackedBinaryField4x32b>(query.to_ref())
                .unwrap()
                .evals[0]
                .get(0),
            eval
        );
    }

    // Two single-variable partial-low evaluations must equal one two-variable evaluation.
    #[test]
    fn test_evaluate_partial_low_single_and_multiple_var_consistent() {
        let mut rng = StdRng::seed_from_u64(0);
        let values: Vec<_> = repeat_with(|| PackedBinaryField4x32b::random(&mut rng))
            .take(1 << 8)
            .collect();

        let mle = MultilinearExtension::from_values(values).unwrap();
        let r1 = <BinaryField32b as PackedField>::random(&mut rng);
        let r2 = <BinaryField32b as PackedField>::random(&mut rng);

        let eval_1: MultilinearExtension<PackedBinaryField4x32b> = mle
            .evaluate_partial_low::<PackedBinaryField4x32b>(multilinear_query(&[r1]).to_ref())
            .unwrap()
            .evaluate_partial_low(multilinear_query(&[r2]).to_ref())
            .unwrap();
        let eval_2 = mle
            .evaluate_partial_low(multilinear_query(&[r1, r2]).to_ref())
            .unwrap();
        assert_eq!(eval_1, eval_2);
    }

    // `new` must accept a single packed element for a 1-variable polynomial (presumably fewer
    // variables than this packed type's LOG_WIDTH).
    #[test]
    fn test_new_mle_with_tiny_nvars() {
        MultilinearExtension::new(
            1,
            vec![PackedType::<OptimalUnderlier256b, BinaryField32b>::one()],
        )
        .unwrap();
    }
}