1use auto_impl::auto_impl;
5use binius_field::{
6 packed::mul_by_subfield_scalar, BinaryField, ExtensionField, Field, PackedExtension,
7 PackedField,
8};
9use binius_utils::bail;
10use itertools::{izip, Either};
11
12use super::{binary_subspace::BinarySubspace, error::Error};
13use crate::Matrix;
14
15#[derive(Debug, Clone)]
26pub struct EvaluationDomain<F: Field> {
27 finite_points: Vec<F>,
28 weights: Vec<F>,
29 with_infinity: bool,
30}
31
32#[derive(Debug, Clone)]
35pub struct InterpolationDomain<F: Field> {
36 evaluation_domain: EvaluationDomain<F>,
37 interpolation_matrix: Matrix<F>,
38}
39
40#[auto_impl(&)]
42pub trait EvaluationDomainFactory<DomainField: Field>: Clone + Sync {
43 fn create(&self, size: usize) -> Result<EvaluationDomain<DomainField>, Error> {
46 self.create_with_infinity(size, false)
47 }
48
49 fn create_with_infinity(
53 &self,
54 size: usize,
55 with_infinity: bool,
56 ) -> Result<EvaluationDomain<DomainField>, Error>;
57}
58
59#[derive(Default, Clone)]
60pub struct DefaultEvaluationDomainFactory<F: BinaryField> {
61 subspace: BinarySubspace<F>,
62}
63
64#[derive(Default, Clone)]
65pub struct IsomorphicEvaluationDomainFactory<F: BinaryField> {
66 subspace: BinarySubspace<F>,
67}
68
69impl<F: BinaryField> EvaluationDomainFactory<F> for DefaultEvaluationDomainFactory<F> {
70 fn create_with_infinity(
71 &self,
72 size: usize,
73 with_infinity: bool,
74 ) -> Result<EvaluationDomain<F>, Error> {
75 if size == 0 && with_infinity {
76 bail!(Error::DomainSizeAtLeastOne);
77 }
78 EvaluationDomain::from_points(
79 make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?,
80 with_infinity,
81 )
82 }
83}
84
85impl<FSrc, FTgt> EvaluationDomainFactory<FTgt> for IsomorphicEvaluationDomainFactory<FSrc>
86where
87 FSrc: BinaryField,
88 FTgt: Field + From<FSrc> + BinaryField,
89{
90 fn create_with_infinity(
91 &self,
92 size: usize,
93 with_infinity: bool,
94 ) -> Result<EvaluationDomain<FTgt>, Error> {
95 if size == 0 && with_infinity {
96 bail!(Error::DomainSizeAtLeastOne);
97 }
98 let points =
99 make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?;
100 EvaluationDomain::from_points(points.into_iter().map(Into::into).collect(), with_infinity)
101 }
102}
103
104fn make_evaluation_points<F: BinaryField>(
105 subspace: &BinarySubspace<F>,
106 size: usize,
107) -> Result<Vec<F>, Error> {
108 let points = subspace.iter().take(size).collect::<Vec<F>>();
109 if points.len() != size {
110 bail!(Error::DomainSizeTooLarge);
111 }
112 Ok(points)
113}
114
115impl<F: Field> From<EvaluationDomain<F>> for InterpolationDomain<F> {
116 fn from(evaluation_domain: EvaluationDomain<F>) -> Self {
117 let n = evaluation_domain.size();
118 let evaluation_matrix =
119 vandermonde(evaluation_domain.finite_points(), evaluation_domain.with_infinity());
120 let mut interpolation_matrix = Matrix::zeros(n, n);
121 evaluation_matrix
122 .inverse_into(&mut interpolation_matrix)
123 .expect(
124 "matrix is square; \
125 there are no duplicate points because that would have been caught when computing \
126 weights; \
127 matrix is non-singular because it is Vandermonde with no duplicate points",
128 );
129
130 Self {
131 evaluation_domain,
132 interpolation_matrix,
133 }
134 }
135}
136
137impl<F: Field> EvaluationDomain<F> {
138 pub fn from_points(finite_points: Vec<F>, with_infinity: bool) -> Result<Self, Error> {
139 let weights = compute_barycentric_weights(&finite_points)?;
140 Ok(Self {
141 finite_points,
142 weights,
143 with_infinity,
144 })
145 }
146
147 pub fn size(&self) -> usize {
148 self.finite_points.len() + if self.with_infinity { 1 } else { 0 }
149 }
150
151 pub fn finite_points(&self) -> &[F] {
152 self.finite_points.as_slice()
153 }
154
155 pub const fn with_infinity(&self) -> bool {
156 self.with_infinity
157 }
158
159 pub fn lagrange_evals<FE: ExtensionField<F>>(&self, x: FE) -> Vec<FE> {
165 let num_evals = self.finite_points().len();
166
167 let mut result: Vec<FE> = vec![FE::ONE; num_evals];
168
169 for i in (1..num_evals).rev() {
171 result[i - 1] = result[i] * (x - self.finite_points[i]);
172 }
173
174 let mut prefix = FE::ONE;
175
176 for ((r, &point), &weight) in result
178 .iter_mut()
179 .zip(&self.finite_points)
180 .zip(&self.weights)
181 {
182 *r *= prefix * weight;
183 prefix *= x - point;
184 }
185
186 result
187 }
188
189 pub fn extrapolate<PE>(&self, values: &[PE], x: PE::Scalar) -> Result<PE, Error>
191 where
192 PE: PackedField<Scalar: ExtensionField<F>>,
193 {
194 if values.len() != self.size() {
195 bail!(Error::ExtrapolateNumberOfEvaluations);
196 }
197
198 let (values_iter, infinity_term) = if self.with_infinity {
199 let (&value_at_infinity, finite_values) =
200 values.split_last().expect("values length checked above");
201 let highest_degree = finite_values.len() as u64;
202 let iter = izip!(&self.finite_points, finite_values).map(move |(&point, &value)| {
203 value - value_at_infinity * PE::Scalar::from(point).pow(highest_degree)
204 });
205 (Either::Left(iter), value_at_infinity * x.pow(highest_degree))
206 } else {
207 (Either::Right(values.iter().copied()), PE::zero())
208 };
209
210 let result = izip!(self.lagrange_evals(x), values_iter)
211 .map(|(lagrange_at_x, value)| value * lagrange_at_x)
212 .sum::<PE>()
213 + infinity_term;
214
215 Ok(result)
216 }
217}
218
219impl<F: Field> InterpolationDomain<F> {
220 pub fn size(&self) -> usize {
221 self.evaluation_domain.size()
222 }
223
224 pub fn finite_points(&self) -> &[F] {
225 self.evaluation_domain.finite_points()
226 }
227
228 pub const fn with_infinity(&self) -> bool {
229 self.evaluation_domain.with_infinity()
230 }
231
232 pub fn extrapolate<PE: PackedExtension<F>>(
233 &self,
234 values: &[PE],
235 x: PE::Scalar,
236 ) -> Result<PE, Error> {
237 self.evaluation_domain.extrapolate(values, x)
238 }
239
240 pub fn interpolate<FE: ExtensionField<F>>(&self, values: &[FE]) -> Result<Vec<FE>, Error> {
241 if values.len() != self.evaluation_domain.size() {
242 bail!(Error::ExtrapolateNumberOfEvaluations);
243 }
244
245 let mut coeffs = vec![FE::ZERO; values.len()];
246 self.interpolation_matrix.mul_vec_into(values, &mut coeffs);
247 Ok(coeffs)
248 }
249}
250
251#[inline]
253pub fn extrapolate_line<P: PackedExtension<FS>, FS: Field>(x0: P, x1: P, z: FS) -> P {
254 x0 + mul_by_subfield_scalar(x1 - x0, z)
255}
256
257#[inline]
259pub fn extrapolate_lines<P>(x0: P, x1: P, z: P) -> P
260where
261 P: PackedField,
262{
263 x0 + (x1 - x0) * z
264}
265
266#[inline]
268pub fn extrapolate_line_scalar<F, FS>(x0: F, x1: F, z: FS) -> F
269where
270 F: ExtensionField<FS>,
271 FS: Field,
272{
273 x0 + (x1 - x0) * z
274}
275
276pub fn evaluate_univariate<F: Field>(coeffs: &[F], x: F) -> F {
278 coeffs
280 .iter()
281 .rfold(F::ZERO, |eval, &coeff| eval * x + coeff)
282}
283
284fn compute_barycentric_weights<F: Field>(points: &[F]) -> Result<Vec<F>, Error> {
285 let n = points.len();
286 (0..n)
287 .map(|i| {
288 let product = (0..n)
289 .filter(|&j| j != i)
290 .map(|j| points[i] - points[j])
291 .product::<F>();
292 product.invert().ok_or(Error::DuplicateDomainPoint)
293 })
294 .collect()
295}
296
297fn vandermonde<F: Field>(xs: &[F], with_infinity: bool) -> Matrix<F> {
298 let n = xs.len() + if with_infinity { 1 } else { 0 };
299
300 let mut mat = Matrix::zeros(n, n);
301 for (i, x_i) in xs.iter().copied().enumerate() {
302 let mut acc = F::ONE;
303 mat[(i, 0)] = acc;
304
305 for j in 1..n {
306 acc *= x_i;
307 mat[(i, j)] = acc;
308 }
309 }
310
311 if with_infinity {
312 mat[(n - 1, n - 1)] = F::ONE;
313 }
314
315 mat
316}
317
#[cfg(test)]
mod tests {
    use std::{iter::repeat_with, slice};

    use assert_matches::assert_matches;
    use binius_field::{
        util::inner_product_unchecked, AESTowerField32b, BinaryField32b, BinaryField8b,
    };
    use itertools::assert_equal;
    use proptest::{collection::vec, proptest};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    /// Reference evaluation: sums `coeff_i * x^i` term by term.
    fn evaluate_univariate_naive<F: Field>(coeffs: &[F], x: F) -> F {
        coeffs
            .iter()
            .enumerate()
            .map(|(i, &coeff)| coeff * Field::pow(&x, slice::from_ref(&(i as u64))))
            .sum()
    }

    #[test]
    fn test_new_domain() {
        let domain_factory = DefaultEvaluationDomainFactory::<BinaryField8b>::default();
        assert_eq!(
            domain_factory.create(3).unwrap().finite_points,
            &[
                BinaryField8b::new(0),
                BinaryField8b::new(1),
                BinaryField8b::new(2)
            ]
        );
    }

    #[test]
    fn test_domain_factory_binary_field() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField32b>::default();
        let iso_domain_factory = IsomorphicEvaluationDomainFactory::<BinaryField32b>::default();
        let domain_1: EvaluationDomain<BinaryField32b> = default_domain_factory.create(10).unwrap();
        let domain_2: EvaluationDomain<BinaryField32b> = iso_domain_factory.create(10).unwrap();
        assert_eq!(domain_1.finite_points, domain_2.finite_points);
    }

    #[test]
    fn test_domain_factory_aes() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField32b>::default();
        let iso_domain_factory = IsomorphicEvaluationDomainFactory::<BinaryField32b>::default();
        let domain_1: EvaluationDomain<BinaryField32b> = default_domain_factory.create(10).unwrap();
        let domain_2: EvaluationDomain<AESTowerField32b> = iso_domain_factory.create(10).unwrap();
        assert_eq!(
            domain_1
                .finite_points
                .into_iter()
                .map(AESTowerField32b::from)
                .collect::<Vec<_>>(),
            domain_2.finite_points
        );
    }

    #[test]
    fn test_new_oversized_domain() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField8b>::default();
        assert_matches!(default_domain_factory.create(300), Err(Error::DomainSizeTooLarge));
    }

    #[test]
    fn test_empty_domain_with_infinity_is_rejected() {
        // A domain that includes the point at infinity needs at least one slot for it.
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField8b>::default();
        assert_matches!(
            default_domain_factory.create_with_infinity(0, true),
            Err(Error::DomainSizeAtLeastOne)
        );
    }

    #[test]
    fn test_evaluate_univariate() {
        let mut rng = StdRng::seed_from_u64(0);
        let coeffs = repeat_with(|| <BinaryField8b as Field>::random(&mut rng))
            .take(6)
            .collect::<Vec<_>>();
        let x = <BinaryField8b as Field>::random(&mut rng);
        assert_eq!(evaluate_univariate(&coeffs, x), evaluate_univariate_naive(&coeffs, x));
    }

    #[test]
    fn test_evaluate_univariate_no_coeffs() {
        let mut rng = StdRng::seed_from_u64(0);
        let x = <BinaryField32b as Field>::random(&mut rng);
        assert_eq!(evaluate_univariate(&[], x), BinaryField32b::ZERO);
    }

    #[test]
    fn test_random_extrapolate() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree + 1)
                .collect(),
            false,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();

        let x = <BinaryField32b as Field>::random(&mut rng);
        let expected_y = evaluate_univariate(&coeffs, x);
        assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y);
    }

    #[test]
    fn test_interpolation() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree + 1)
                .collect(),
            false,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();

        let interpolated = InterpolationDomain::from(domain)
            .interpolate(&values)
            .unwrap();
        assert_eq!(interpolated, coeffs);
    }

    #[test]
    fn test_infinity() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree)
                .collect(),
            true,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let mut values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();
        // The value at infinity is the leading coefficient of the polynomial.
        values.push(coeffs.last().copied().unwrap());

        let x = <BinaryField32b as Field>::random(&mut rng);
        let expected_y = evaluate_univariate(&coeffs, x);
        assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y);

        let interpolated = InterpolationDomain::from(domain)
            .interpolate(&values)
            .unwrap();
        assert_eq!(interpolated, coeffs);
    }

    proptest! {
        #[test]
        fn test_extrapolate_line(x0 in 0u32.., x1 in 0u32.., z in 0u8..) {
            let x0 = BinaryField32b::from(x0);
            let x1 = BinaryField32b::from(x1);
            let z = BinaryField8b::from(z);
            assert_eq!(extrapolate_line(x0, x1, z), x0 + (x1 - x0) * z);
            assert_eq!(extrapolate_line_scalar(x0, x1, z), x0 + (x1 - x0) * z);
        }

        #[test]
        fn test_lagrange_evals(values in vec(0u32.., 0..100), z in 0u32..) {
            let field_values = values.into_iter().map(BinaryField32b::from).collect::<Vec<_>>();
            let factory = DefaultEvaluationDomainFactory::<BinaryField32b>::default();
            let evaluation_domain = factory.create(field_values.len()).unwrap();

            let z = BinaryField32b::new(z);

            let extrapolated = evaluation_domain.extrapolate(field_values.as_slice(), z).unwrap();
            let lagrange_coeffs = evaluation_domain.lagrange_evals(z);
            let lagrange_eval =
                inner_product_unchecked(lagrange_coeffs.into_iter(), field_values.into_iter());
            assert_eq!(lagrange_eval, extrapolated);
        }
    }

    #[test]
    fn test_barycentric_weights_simple() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];
        let weights = compute_barycentric_weights(&points).unwrap();

        // Expected weights: inverse of the product of differences to every other point.
        let w1 = ((p1 - p2) * (p1 - p3)).invert().unwrap();
        let w2 = ((p2 - p1) * (p2 - p3)).invert().unwrap();
        let w3 = ((p3 - p1) * (p3 - p2)).invert().unwrap();

        assert_eq!(weights, vec![w1, w2, w3]);
    }

    #[test]
    fn test_barycentric_weights_four_points() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);
        let p4 = BinaryField32b::from(4);

        let points = vec![p1, p2, p3, p4];

        let weights = compute_barycentric_weights(&points).unwrap();

        // Expected weights: inverse of the product of differences to every other point.
        let w1 = ((p1 - p2) * (p1 - p3) * (p1 - p4)).invert().unwrap();
        let w2 = ((p2 - p1) * (p2 - p3) * (p2 - p4)).invert().unwrap();
        let w3 = ((p3 - p1) * (p3 - p2) * (p3 - p4)).invert().unwrap();
        let w4 = ((p4 - p1) * (p4 - p2) * (p4 - p3)).invert().unwrap();

        assert_eq!(weights, vec![w1, w2, w3, w4]);
    }

    #[test]
    fn test_barycentric_weights_single_point() {
        let p1 = BinaryField32b::from(5);

        let points = vec![p1];
        let result = compute_barycentric_weights(&points).unwrap();

        // With a single point the product of differences is empty, so the weight is one.
        assert_equal(result, vec![BinaryField32b::from(1)]);
    }

    #[test]
    fn test_barycentric_weights_duplicate_points() {
        let p1 = BinaryField32b::from(7);
        let p2 = BinaryField32b::from(7);

        let points = vec![p1, p2];
        let result = compute_barycentric_weights(&points);

        // Pin the exact variant so an unrelated failure cannot satisfy this test.
        assert_matches!(result, Err(Error::DuplicateDomainPoint));
    }

    #[test]
    fn test_vandermonde_basic() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];

        let matrix = vandermonde(&points, false);

        // Row i is [1, p_i, p_i^2].
        let expected = Matrix::new(
            3,
            3,
            &[
                BinaryField32b::from(1),
                p1,
                p1.pow(2),
                BinaryField32b::from(1),
                p2,
                p2.pow(2),
                BinaryField32b::from(1),
                p3,
                p3.pow(2),
            ],
        )
        .unwrap();

        assert_eq!(matrix, expected);
    }

    #[test]
    fn test_vandermonde_with_infinity() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];
        let matrix = vandermonde(&points, true);

        // Finite rows are [1, p_i, p_i^2, p_i^3]; the row for infinity is [0, 0, 0, 1].
        let expected = Matrix::new(
            4,
            4,
            &[
                BinaryField32b::from(1),
                p1,
                p1.pow(2),
                p1.pow(3),
                BinaryField32b::from(1),
                p2,
                p2.pow(2),
                p2.pow(3),
                BinaryField32b::from(1),
                p3,
                p3.pow(2),
                p3.pow(3),
                BinaryField32b::from(0),
                BinaryField32b::from(0),
                BinaryField32b::from(0),
                BinaryField32b::from(1),
            ],
        )
        .unwrap();

        assert_eq!(matrix, expected);
    }
}