1use auto_impl::auto_impl;
5use binius_field::{
6 BinaryField, ExtensionField, Field, PackedExtension, PackedField,
7 packed::mul_by_subfield_scalar,
8};
9use binius_utils::bail;
10use itertools::{Either, izip};
11
12use super::{binary_subspace::BinarySubspace, error::Error};
13use crate::Matrix;
14
/// A univariate evaluation domain: a list of finite points, optionally followed by the
/// "point at infinity".
///
/// A polynomial with fewer than [`size`](Self::size) coefficients is determined by its values
/// on the domain. The value "at infinity" stands for the coefficient of degree
/// `finite_points.len()` (the leading coefficient); this is how
/// [`extrapolate`](Self::extrapolate) interprets the last value when `with_infinity` is set.
#[derive(Debug, Clone)]
pub struct EvaluationDomain<F: Field> {
    /// The finite points of the domain; `from_points` rejects duplicates.
    finite_points: Vec<F>,
    /// Barycentric weights, one per finite point:
    /// `weights[i] = 1 / ∏_{j ≠ i} (finite_points[i] - finite_points[j])`.
    weights: Vec<F>,
    /// Whether the domain also contains the point at infinity (counted as one extra point).
    with_infinity: bool,
}
31
/// An [`EvaluationDomain`] together with a precomputed matrix that recovers polynomial
/// coefficients from evaluations. Built via `From<EvaluationDomain<F>>`.
#[derive(Debug, Clone)]
pub struct InterpolationDomain<F: Field> {
    /// The underlying points, barycentric weights and infinity flag.
    evaluation_domain: EvaluationDomain<F>,
    /// Inverse of the domain's Vandermonde matrix; multiplying it by a vector of evaluations
    /// yields the coefficients, lowest degree first.
    interpolation_matrix: Matrix<F>,
}
39
/// Constructs [`EvaluationDomain`]s of a requested size.
///
/// `#[auto_impl(&)]` additionally implements this trait for shared references to
/// implementors, so `&factory` can be used wherever a factory is expected.
#[auto_impl(&)]
pub trait EvaluationDomainFactory<DomainField: Field>: Clone + Sync {
    /// Creates an evaluation domain for the requested `size`.
    ///
    /// NOTE(review): both implementations in this module return a domain whose `size()`
    /// equals `size` and include the point at infinity once `size >= 3`; confirm whether
    /// that is meant to be part of the trait contract for other implementors.
    fn create(&self, size: usize) -> Result<EvaluationDomain<DomainField>, Error>;
}
56
/// An [`EvaluationDomainFactory`] that draws the finite points of each domain from a
/// [`BinarySubspace`] of `F`, in the subspace's iteration order.
#[derive(Default, Clone)]
pub struct DefaultEvaluationDomainFactory<F: BinaryField> {
    /// Source of the finite domain points.
    subspace: BinarySubspace<F>,
}
61
/// An [`EvaluationDomainFactory`] that draws points from a [`BinarySubspace`] of a source
/// field `F` and converts them into the requested target field via `From`.
///
/// This lets a domain over another field representation (e.g. an AES tower field) use the
/// images of the same points a [`DefaultEvaluationDomainFactory`] over `F` would choose.
#[derive(Default, Clone)]
pub struct IsomorphicEvaluationDomainFactory<F: BinaryField> {
    /// Source of the finite domain points, expressed in the source field `F`.
    subspace: BinarySubspace<F>,
}
66
67impl<F: BinaryField> EvaluationDomainFactory<F> for DefaultEvaluationDomainFactory<F> {
68 fn create(&self, size: usize) -> Result<EvaluationDomain<F>, Error> {
69 let with_infinity = size >= 3;
70 EvaluationDomain::from_points(
71 make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?,
72 with_infinity,
73 )
74 }
75}
76
77impl<FSrc, FTgt> EvaluationDomainFactory<FTgt> for IsomorphicEvaluationDomainFactory<FSrc>
78where
79 FSrc: BinaryField,
80 FTgt: Field + From<FSrc> + BinaryField,
81{
82 fn create(&self, size: usize) -> Result<EvaluationDomain<FTgt>, Error> {
83 let with_infinity = size >= 3;
84 let points =
85 make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?;
86 EvaluationDomain::from_points(points.into_iter().map(Into::into).collect(), with_infinity)
87 }
88}
89
90fn make_evaluation_points<F: BinaryField>(
91 subspace: &BinarySubspace<F>,
92 size: usize,
93) -> Result<Vec<F>, Error> {
94 let points = subspace.iter().take(size).collect::<Vec<F>>();
95 if points.len() != size {
96 bail!(Error::DomainSizeTooLarge);
97 }
98 Ok(points)
99}
100
101impl<F: Field> From<EvaluationDomain<F>> for InterpolationDomain<F> {
102 fn from(evaluation_domain: EvaluationDomain<F>) -> Self {
103 let n = evaluation_domain.size();
104 let evaluation_matrix =
105 vandermonde(evaluation_domain.finite_points(), evaluation_domain.with_infinity());
106 let mut interpolation_matrix = Matrix::zeros(n, n);
107 evaluation_matrix
108 .inverse_into(&mut interpolation_matrix)
109 .expect(
110 "matrix is square; \
111 there are no duplicate points because that would have been caught when computing \
112 weights; \
113 matrix is non-singular because it is Vandermonde with no duplicate points",
114 );
115
116 Self {
117 evaluation_domain,
118 interpolation_matrix,
119 }
120 }
121}
122
123impl<F: Field> EvaluationDomain<F> {
124 pub fn from_points(finite_points: Vec<F>, with_infinity: bool) -> Result<Self, Error> {
125 let weights = compute_barycentric_weights(&finite_points)?;
126 Ok(Self {
127 finite_points,
128 weights,
129 with_infinity,
130 })
131 }
132
133 pub fn size(&self) -> usize {
134 self.finite_points.len() + if self.with_infinity { 1 } else { 0 }
135 }
136
137 pub fn finite_points(&self) -> &[F] {
138 self.finite_points.as_slice()
139 }
140
141 pub const fn with_infinity(&self) -> bool {
142 self.with_infinity
143 }
144
145 pub fn lagrange_evals<FE: ExtensionField<F>>(&self, x: FE) -> Vec<FE> {
151 let num_evals = self.finite_points().len();
152
153 let mut result: Vec<FE> = vec![FE::ONE; num_evals];
154
155 for i in (1..num_evals).rev() {
157 result[i - 1] = result[i] * (x - self.finite_points[i]);
158 }
159
160 let mut prefix = FE::ONE;
161
162 for ((r, &point), &weight) in result
164 .iter_mut()
165 .zip(&self.finite_points)
166 .zip(&self.weights)
167 {
168 *r *= prefix * weight;
169 prefix *= x - point;
170 }
171
172 result
173 }
174
175 pub fn extrapolate<PE>(&self, values: &[PE], x: PE::Scalar) -> Result<PE, Error>
178 where
179 PE: PackedField<Scalar: ExtensionField<F>>,
180 {
181 if values.len() != self.size() {
182 bail!(Error::ExtrapolateNumberOfEvaluations);
183 }
184
185 let (values_iter, infinity_term) = if self.with_infinity {
186 let (&value_at_infinity, finite_values) =
187 values.split_last().expect("values length checked above");
188 let highest_degree = finite_values.len() as u64;
189 let iter = izip!(&self.finite_points, finite_values).map(move |(&point, &value)| {
190 value - value_at_infinity * PE::Scalar::from(point).pow(highest_degree)
191 });
192 (Either::Left(iter), value_at_infinity * x.pow(highest_degree))
193 } else {
194 (Either::Right(values.iter().copied()), PE::zero())
195 };
196
197 let result = izip!(self.lagrange_evals(x), values_iter)
198 .map(|(lagrange_at_x, value)| value * lagrange_at_x)
199 .sum::<PE>()
200 + infinity_term;
201
202 Ok(result)
203 }
204}
205
206impl<F: Field> InterpolationDomain<F> {
207 pub fn size(&self) -> usize {
208 self.evaluation_domain.size()
209 }
210
211 pub fn finite_points(&self) -> &[F] {
212 self.evaluation_domain.finite_points()
213 }
214
215 pub const fn with_infinity(&self) -> bool {
216 self.evaluation_domain.with_infinity()
217 }
218
219 pub fn extrapolate<PE: PackedExtension<F>>(
220 &self,
221 values: &[PE],
222 x: PE::Scalar,
223 ) -> Result<PE, Error> {
224 self.evaluation_domain.extrapolate(values, x)
225 }
226
227 pub fn interpolate<FE: ExtensionField<F>>(&self, values: &[FE]) -> Result<Vec<FE>, Error> {
228 if values.len() != self.evaluation_domain.size() {
229 bail!(Error::ExtrapolateNumberOfEvaluations);
230 }
231
232 let mut coeffs = vec![FE::ZERO; values.len()];
233 self.interpolation_matrix.mul_vec_into(values, &mut coeffs);
234 Ok(coeffs)
235 }
236}
237
238#[inline]
240pub fn extrapolate_line<P: PackedExtension<FS>, FS: Field>(x0: P, x1: P, z: FS) -> P {
241 x0 + mul_by_subfield_scalar(x1 - x0, z)
242}
243
244#[inline]
246pub fn extrapolate_lines<P>(x0: P, x1: P, z: P) -> P
247where
248 P: PackedField,
249{
250 x0 + (x1 - x0) * z
251}
252
253#[inline]
255pub fn extrapolate_line_scalar<F, FS>(x0: F, x1: F, z: FS) -> F
256where
257 F: ExtensionField<FS>,
258 FS: Field,
259{
260 x0 + (x1 - x0) * z
261}
262
263pub fn evaluate_univariate<F: Field>(coeffs: &[F], x: F) -> F {
265 coeffs
267 .iter()
268 .rfold(F::ZERO, |eval, &coeff| eval * x + coeff)
269}
270
271fn compute_barycentric_weights<F: Field>(points: &[F]) -> Result<Vec<F>, Error> {
272 let n = points.len();
273 (0..n)
274 .map(|i| {
275 let product = (0..n)
276 .filter(|&j| j != i)
277 .map(|j| points[i] - points[j])
278 .product::<F>();
279 product.invert().ok_or(Error::DuplicateDomainPoint)
280 })
281 .collect()
282}
283
284fn vandermonde<F: Field>(xs: &[F], with_infinity: bool) -> Matrix<F> {
285 let n = xs.len() + if with_infinity { 1 } else { 0 };
286
287 let mut mat = Matrix::zeros(n, n);
288 for (i, x_i) in xs.iter().copied().enumerate() {
289 let mut acc = F::ONE;
290 mat[(i, 0)] = acc;
291
292 for j in 1..n {
293 acc *= x_i;
294 mat[(i, j)] = acc;
295 }
296 }
297
298 if with_infinity {
299 mat[(n - 1, n - 1)] = F::ONE;
300 }
301
302 mat
303}
304
#[cfg(test)]
mod tests {
    use std::{iter::repeat_with, slice};

    use assert_matches::assert_matches;
    use binius_field::{
        AESTowerField32b, BinaryField8b, BinaryField32b, util::inner_product_unchecked,
    };
    use itertools::assert_equal;
    use proptest::{collection::vec, proptest};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Reference evaluator that sums `coeffs[i] * x^i` term by term; used to cross-check the
    /// Horner-based `evaluate_univariate`.
    fn evaluate_univariate_naive<F: Field>(coeffs: &[F], x: F) -> F {
        coeffs
            .iter()
            .enumerate()
            .map(|(i, &coeff)| coeff * Field::pow(&x, slice::from_ref(&(i as u64))))
            .sum()
    }

    /// A size-3 domain includes infinity, so it has only two finite points: 0 and 1.
    #[test]
    fn test_new_domain() {
        let domain_factory = DefaultEvaluationDomainFactory::<BinaryField8b>::default();
        assert_eq!(
            domain_factory.create(3).unwrap().finite_points,
            &[BinaryField8b::new(0), BinaryField8b::new(1),]
        );
    }

    /// With the same source and target field, both factories must choose identical points.
    #[test]
    fn test_domain_factory_binary_field() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField32b>::default();
        let iso_domain_factory = IsomorphicEvaluationDomainFactory::<BinaryField32b>::default();
        let domain_1: EvaluationDomain<BinaryField32b> = default_domain_factory.create(10).unwrap();
        let domain_2: EvaluationDomain<BinaryField32b> = iso_domain_factory.create(10).unwrap();
        assert_eq!(domain_1.finite_points, domain_2.finite_points);
    }

    /// Points produced in the AES tower field must be the images of the binary-tower points.
    #[test]
    fn test_domain_factory_aes() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField32b>::default();
        let iso_domain_factory = IsomorphicEvaluationDomainFactory::<BinaryField32b>::default();
        let domain_1: EvaluationDomain<BinaryField32b> = default_domain_factory.create(10).unwrap();
        let domain_2: EvaluationDomain<AESTowerField32b> = iso_domain_factory.create(10).unwrap();
        assert_eq!(
            domain_1
                .finite_points
                .into_iter()
                .map(AESTowerField32b::from)
                .collect::<Vec<_>>(),
            domain_2.finite_points
        );
    }

    /// A size-300 domain needs 299 finite points, more than an 8-bit field can hold.
    #[test]
    fn test_new_oversized_domain() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField8b>::default();
        assert_matches!(default_domain_factory.create(300), Err(Error::DomainSizeTooLarge));
    }

    /// Horner evaluation agrees with the naive power-sum evaluation.
    #[test]
    fn test_evaluate_univariate() {
        let mut rng = StdRng::seed_from_u64(0);
        let coeffs = repeat_with(|| <BinaryField8b as Field>::random(&mut rng))
            .take(6)
            .collect::<Vec<_>>();
        let x = <BinaryField8b as Field>::random(&mut rng);
        assert_eq!(evaluate_univariate(&coeffs, x), evaluate_univariate_naive(&coeffs, x));
    }

    /// The empty polynomial evaluates to zero.
    #[test]
    fn test_evaluate_univariate_no_coeffs() {
        let mut rng = StdRng::seed_from_u64(0);
        let x = <BinaryField32b as Field>::random(&mut rng);
        assert_eq!(evaluate_univariate(&[], x), BinaryField32b::ZERO);
    }

    /// Extrapolating from degree + 1 finite evaluations reproduces the polynomial at a random
    /// point. The fixed seed is assumed to yield distinct domain points; otherwise
    /// `from_points` would return an error.
    #[test]
    fn test_random_extrapolate() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree + 1)
                .collect(),
            false,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();

        let x = <BinaryField32b as Field>::random(&mut rng);
        let expected_y = evaluate_univariate(&coeffs, x);
        assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y);
    }

    /// Interpolating evaluations on a finite-only domain recovers the original coefficients.
    #[test]
    fn test_interpolation() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree + 1)
                .collect(),
            false,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();

        let interpolated = InterpolationDomain::from(domain)
            .interpolate(&values)
            .unwrap();
        assert_eq!(interpolated, coeffs);
    }

    /// With `degree` finite points plus infinity, the value at infinity is the leading
    /// coefficient. Both extrapolation and interpolation must honour that convention.
    #[test]
    fn test_infinity() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree)
                .collect(),
            true,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let mut values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();
        // The "evaluation at infinity" is the leading coefficient.
        values.push(coeffs.last().copied().unwrap());

        let x = <BinaryField32b as Field>::random(&mut rng);
        let expected_y = evaluate_univariate(&coeffs, x);
        assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y);

        let interpolated = InterpolationDomain::from(domain)
            .interpolate(&values)
            .unwrap();
        assert_eq!(interpolated, coeffs);
    }

    proptest! {
        // Packed and scalar line extrapolation both agree with the textbook formula.
        #[test]
        fn test_extrapolate_line(x0 in 0u32.., x1 in 0u32.., z in 0u8..) {
            let x0 = BinaryField32b::from(x0);
            let x1 = BinaryField32b::from(x1);
            let z = BinaryField8b::from(z);
            assert_eq!(extrapolate_line(x0, x1, z), x0 + (x1 - x0) * z);
            assert_eq!(extrapolate_line_scalar(x0, x1, z), x0 + (x1 - x0) * z);
        }

        // The inner product of the Lagrange basis evaluations with the values equals the
        // extrapolated value, for domains of 0..100 points taken from an 8-dim subspace.
        #[test]
        fn test_lagrange_evals(values in vec(0u32.., 0..100), z in 0u32..) {
            let field_values = values.into_iter().map(BinaryField32b::from).collect::<Vec<_>>();
            let subspace = BinarySubspace::<BinaryField32b>::with_dim(8).unwrap();
            let domain_points = subspace.iter().take(field_values.len()).collect::<Vec<_>>();
            let evaluation_domain = EvaluationDomain::from_points(domain_points, false).unwrap();

            let z = BinaryField32b::new(z);

            let extrapolated = evaluation_domain.extrapolate(field_values.as_slice(), z).unwrap();
            let lagrange_coeffs = evaluation_domain.lagrange_evals(z);
            let lagrange_eval = inner_product_unchecked(lagrange_coeffs.into_iter(), field_values.into_iter());
            assert_eq!(lagrange_eval, extrapolated);
        }
    }

    /// Weights for three points match the closed form 1 / ∏_{j ≠ i} (p_i - p_j).
    #[test]
    fn test_barycentric_weights_simple() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];
        let weights = compute_barycentric_weights(&points).unwrap();

        let w1 = ((p1 - p2) * (p1 - p3)).invert().unwrap();
        let w2 = ((p2 - p1) * (p2 - p3)).invert().unwrap();
        let w3 = ((p3 - p1) * (p3 - p2)).invert().unwrap();

        assert_eq!(weights, vec![w1, w2, w3]);
    }

    /// Same closed-form check with four points.
    #[test]
    fn test_barycentric_weights_four_points() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);
        let p4 = BinaryField32b::from(4);

        let points = vec![p1, p2, p3, p4];

        let weights = compute_barycentric_weights(&points).unwrap();

        let w1 = ((p1 - p2) * (p1 - p3) * (p1 - p4)).invert().unwrap();
        let w2 = ((p2 - p1) * (p2 - p3) * (p2 - p4)).invert().unwrap();
        let w3 = ((p3 - p1) * (p3 - p2) * (p3 - p4)).invert().unwrap();
        let w4 = ((p4 - p1) * (p4 - p2) * (p4 - p3)).invert().unwrap();

        assert_eq!(weights, vec![w1, w2, w3, w4]);
    }

    /// A single point has an empty product, so its weight is 1.
    #[test]
    fn test_barycentric_weights_single_point() {
        let p1 = BinaryField32b::from(5);

        let points = vec![p1];
        let result = compute_barycentric_weights(&points).unwrap();

        assert_equal(result, vec![BinaryField32b::from(1)]);
    }

    /// Duplicate points make a denominator zero, which must be reported as an error.
    #[test]
    fn test_barycentric_weights_duplicate_points() {
        let p1 = BinaryField32b::from(7);
        let p2 = BinaryField32b::from(7);
        let points = vec![p1, p2];
        let result = compute_barycentric_weights(&points);

        assert!(result.is_err());
    }

    /// Without infinity, row i is [1, p_i, p_i^2].
    #[test]
    fn test_vandermonde_basic() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];

        let matrix = vandermonde(&points, false);

        let expected = Matrix::new(
            3,
            3,
            &[
                BinaryField32b::from(1),
                p1,
                p1.pow(2),
                BinaryField32b::from(1),
                p2,
                p2.pow(2),
                BinaryField32b::from(1),
                p3,
                p3.pow(2),
            ],
        )
        .unwrap();

        assert_eq!(matrix, expected);
    }

    /// With infinity, the matrix grows by one and the extra row is [0, 0, 0, 1].
    #[test]
    fn test_vandermonde_with_infinity() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];
        let matrix = vandermonde(&points, true);

        let expected = Matrix::new(
            4,
            4,
            &[
                BinaryField32b::from(1),
                p1,
                p1.pow(2),
                p1.pow(3),
                BinaryField32b::from(1),
                p2,
                p2.pow(2),
                p2.pow(3),
                BinaryField32b::from(1),
                p3,
                p3.pow(2),
                p3.pow(3),
                BinaryField32b::from(0),
                BinaryField32b::from(0),
                BinaryField32b::from(0),
                BinaryField32b::from(1),
            ],
        )
        .unwrap();

        assert_eq!(matrix, expected);
    }
}
638}