1use auto_impl::auto_impl;
5use binius_field::{
6 packed::mul_by_subfield_scalar, BinaryField, ExtensionField, Field, PackedExtension,
7 PackedField,
8};
9use binius_utils::bail;
10use itertools::{izip, Either};
11
12use super::{binary_subspace::BinarySubspace, error::Error};
13use crate::Matrix;
14
15#[derive(Debug, Clone)]
26pub struct EvaluationDomain<F: Field> {
27 finite_points: Vec<F>,
28 weights: Vec<F>,
29 with_infinity: bool,
30}
31
32#[derive(Debug, Clone)]
35pub struct InterpolationDomain<F: Field> {
36 evaluation_domain: EvaluationDomain<F>,
37 interpolation_matrix: Matrix<F>,
38}
39
40#[auto_impl(&)]
42pub trait EvaluationDomainFactory<DomainField: Field>: Clone + Sync {
43 fn create(&self, size: usize) -> Result<EvaluationDomain<DomainField>, Error>;
55}
56
57#[derive(Default, Clone)]
58pub struct DefaultEvaluationDomainFactory<F: BinaryField> {
59 subspace: BinarySubspace<F>,
60}
61
62#[derive(Default, Clone)]
63pub struct IsomorphicEvaluationDomainFactory<F: BinaryField> {
64 subspace: BinarySubspace<F>,
65}
66
67impl<F: BinaryField> EvaluationDomainFactory<F> for DefaultEvaluationDomainFactory<F> {
68 fn create(&self, size: usize) -> Result<EvaluationDomain<F>, Error> {
69 let with_infinity = size >= 3;
70 EvaluationDomain::from_points(
71 make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?,
72 with_infinity,
73 )
74 }
75}
76
77impl<FSrc, FTgt> EvaluationDomainFactory<FTgt> for IsomorphicEvaluationDomainFactory<FSrc>
78where
79 FSrc: BinaryField,
80 FTgt: Field + From<FSrc> + BinaryField,
81{
82 fn create(&self, size: usize) -> Result<EvaluationDomain<FTgt>, Error> {
83 let with_infinity = size >= 3;
84 let points =
85 make_evaluation_points(&self.subspace, size - if with_infinity { 1 } else { 0 })?;
86 EvaluationDomain::from_points(points.into_iter().map(Into::into).collect(), with_infinity)
87 }
88}
89
90fn make_evaluation_points<F: BinaryField>(
91 subspace: &BinarySubspace<F>,
92 size: usize,
93) -> Result<Vec<F>, Error> {
94 let points = subspace.iter().take(size).collect::<Vec<F>>();
95 if points.len() != size {
96 bail!(Error::DomainSizeTooLarge);
97 }
98 Ok(points)
99}
100
101impl<F: Field> From<EvaluationDomain<F>> for InterpolationDomain<F> {
102 fn from(evaluation_domain: EvaluationDomain<F>) -> Self {
103 let n = evaluation_domain.size();
104 let evaluation_matrix =
105 vandermonde(evaluation_domain.finite_points(), evaluation_domain.with_infinity());
106 let mut interpolation_matrix = Matrix::zeros(n, n);
107 evaluation_matrix
108 .inverse_into(&mut interpolation_matrix)
109 .expect(
110 "matrix is square; \
111 there are no duplicate points because that would have been caught when computing \
112 weights; \
113 matrix is non-singular because it is Vandermonde with no duplicate points",
114 );
115
116 Self {
117 evaluation_domain,
118 interpolation_matrix,
119 }
120 }
121}
122
123impl<F: Field> EvaluationDomain<F> {
124 pub fn from_points(finite_points: Vec<F>, with_infinity: bool) -> Result<Self, Error> {
125 let weights = compute_barycentric_weights(&finite_points)?;
126 Ok(Self {
127 finite_points,
128 weights,
129 with_infinity,
130 })
131 }
132
133 pub fn size(&self) -> usize {
134 self.finite_points.len() + if self.with_infinity { 1 } else { 0 }
135 }
136
137 pub fn finite_points(&self) -> &[F] {
138 self.finite_points.as_slice()
139 }
140
141 pub const fn with_infinity(&self) -> bool {
142 self.with_infinity
143 }
144
145 pub fn lagrange_evals<FE: ExtensionField<F>>(&self, x: FE) -> Vec<FE> {
151 let num_evals = self.finite_points().len();
152
153 let mut result: Vec<FE> = vec![FE::ONE; num_evals];
154
155 for i in (1..num_evals).rev() {
157 result[i - 1] = result[i] * (x - self.finite_points[i]);
158 }
159
160 let mut prefix = FE::ONE;
161
162 for ((r, &point), &weight) in result
164 .iter_mut()
165 .zip(&self.finite_points)
166 .zip(&self.weights)
167 {
168 *r *= prefix * weight;
169 prefix *= x - point;
170 }
171
172 result
173 }
174
175 pub fn extrapolate<PE>(&self, values: &[PE], x: PE::Scalar) -> Result<PE, Error>
177 where
178 PE: PackedField<Scalar: ExtensionField<F>>,
179 {
180 if values.len() != self.size() {
181 bail!(Error::ExtrapolateNumberOfEvaluations);
182 }
183
184 let (values_iter, infinity_term) = if self.with_infinity {
185 let (&value_at_infinity, finite_values) =
186 values.split_last().expect("values length checked above");
187 let highest_degree = finite_values.len() as u64;
188 let iter = izip!(&self.finite_points, finite_values).map(move |(&point, &value)| {
189 value - value_at_infinity * PE::Scalar::from(point).pow(highest_degree)
190 });
191 (Either::Left(iter), value_at_infinity * x.pow(highest_degree))
192 } else {
193 (Either::Right(values.iter().copied()), PE::zero())
194 };
195
196 let result = izip!(self.lagrange_evals(x), values_iter)
197 .map(|(lagrange_at_x, value)| value * lagrange_at_x)
198 .sum::<PE>()
199 + infinity_term;
200
201 Ok(result)
202 }
203}
204
205impl<F: Field> InterpolationDomain<F> {
206 pub fn size(&self) -> usize {
207 self.evaluation_domain.size()
208 }
209
210 pub fn finite_points(&self) -> &[F] {
211 self.evaluation_domain.finite_points()
212 }
213
214 pub const fn with_infinity(&self) -> bool {
215 self.evaluation_domain.with_infinity()
216 }
217
218 pub fn extrapolate<PE: PackedExtension<F>>(
219 &self,
220 values: &[PE],
221 x: PE::Scalar,
222 ) -> Result<PE, Error> {
223 self.evaluation_domain.extrapolate(values, x)
224 }
225
226 pub fn interpolate<FE: ExtensionField<F>>(&self, values: &[FE]) -> Result<Vec<FE>, Error> {
227 if values.len() != self.evaluation_domain.size() {
228 bail!(Error::ExtrapolateNumberOfEvaluations);
229 }
230
231 let mut coeffs = vec![FE::ZERO; values.len()];
232 self.interpolation_matrix.mul_vec_into(values, &mut coeffs);
233 Ok(coeffs)
234 }
235}
236
237#[inline]
239pub fn extrapolate_line<P: PackedExtension<FS>, FS: Field>(x0: P, x1: P, z: FS) -> P {
240 x0 + mul_by_subfield_scalar(x1 - x0, z)
241}
242
243#[inline]
245pub fn extrapolate_lines<P>(x0: P, x1: P, z: P) -> P
246where
247 P: PackedField,
248{
249 x0 + (x1 - x0) * z
250}
251
252#[inline]
254pub fn extrapolate_line_scalar<F, FS>(x0: F, x1: F, z: FS) -> F
255where
256 F: ExtensionField<FS>,
257 FS: Field,
258{
259 x0 + (x1 - x0) * z
260}
261
262pub fn evaluate_univariate<F: Field>(coeffs: &[F], x: F) -> F {
264 coeffs
266 .iter()
267 .rfold(F::ZERO, |eval, &coeff| eval * x + coeff)
268}
269
270fn compute_barycentric_weights<F: Field>(points: &[F]) -> Result<Vec<F>, Error> {
271 let n = points.len();
272 (0..n)
273 .map(|i| {
274 let product = (0..n)
275 .filter(|&j| j != i)
276 .map(|j| points[i] - points[j])
277 .product::<F>();
278 product.invert().ok_or(Error::DuplicateDomainPoint)
279 })
280 .collect()
281}
282
283fn vandermonde<F: Field>(xs: &[F], with_infinity: bool) -> Matrix<F> {
284 let n = xs.len() + if with_infinity { 1 } else { 0 };
285
286 let mut mat = Matrix::zeros(n, n);
287 for (i, x_i) in xs.iter().copied().enumerate() {
288 let mut acc = F::ONE;
289 mat[(i, 0)] = acc;
290
291 for j in 1..n {
292 acc *= x_i;
293 mat[(i, j)] = acc;
294 }
295 }
296
297 if with_infinity {
298 mat[(n - 1, n - 1)] = F::ONE;
299 }
300
301 mat
302}
303
#[cfg(test)]
mod tests {
    use std::{iter::repeat_with, slice};

    use assert_matches::assert_matches;
    use binius_field::{
        util::inner_product_unchecked, AESTowerField32b, BinaryField32b, BinaryField8b,
    };
    use itertools::assert_equal;
    use proptest::{collection::vec, proptest};
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    /// Reference evaluation: sums `coeff * x^i` term by term, for comparison against
    /// `evaluate_univariate`.
    fn evaluate_univariate_naive<F: Field>(coeffs: &[F], x: F) -> F {
        coeffs
            .iter()
            .enumerate()
            .map(|(i, &coeff)| coeff * Field::pow(&x, slice::from_ref(&(i as u64))))
            .sum()
    }

    /// A size-3 domain has two finite points; the third point is infinity.
    #[test]
    fn test_new_domain() {
        let domain_factory = DefaultEvaluationDomainFactory::<BinaryField8b>::default();
        assert_eq!(
            domain_factory.create(3).unwrap().finite_points,
            &[BinaryField8b::new(0), BinaryField8b::new(1),]
        );
    }

    /// With identical source and target fields both factories produce the same points.
    #[test]
    fn test_domain_factory_binary_field() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField32b>::default();
        let iso_domain_factory = IsomorphicEvaluationDomainFactory::<BinaryField32b>::default();
        let domain_1: EvaluationDomain<BinaryField32b> = default_domain_factory.create(10).unwrap();
        let domain_2: EvaluationDomain<BinaryField32b> = iso_domain_factory.create(10).unwrap();
        assert_eq!(domain_1.finite_points, domain_2.finite_points);
    }

    /// The isomorphic factory's points equal the default factory's points after conversion.
    #[test]
    fn test_domain_factory_aes() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField32b>::default();
        let iso_domain_factory = IsomorphicEvaluationDomainFactory::<BinaryField32b>::default();
        let domain_1: EvaluationDomain<BinaryField32b> = default_domain_factory.create(10).unwrap();
        let domain_2: EvaluationDomain<AESTowerField32b> = iso_domain_factory.create(10).unwrap();
        assert_eq!(
            domain_1
                .finite_points
                .into_iter()
                .map(AESTowerField32b::from)
                .collect::<Vec<_>>(),
            domain_2.finite_points
        );
    }

    /// Requesting more points than the subspace provides is an error.
    #[test]
    fn test_new_oversized_domain() {
        let default_domain_factory = DefaultEvaluationDomainFactory::<BinaryField8b>::default();
        assert_matches!(default_domain_factory.create(300), Err(Error::DomainSizeTooLarge));
    }

    /// Horner evaluation agrees with the term-by-term reference.
    #[test]
    fn test_evaluate_univariate() {
        let mut rng = StdRng::seed_from_u64(0);
        let coeffs = repeat_with(|| <BinaryField8b as Field>::random(&mut rng))
            .take(6)
            .collect::<Vec<_>>();
        let x = <BinaryField8b as Field>::random(&mut rng);
        assert_eq!(evaluate_univariate(&coeffs, x), evaluate_univariate_naive(&coeffs, x));
    }

    /// The empty polynomial evaluates to zero.
    #[test]
    fn test_evaluate_univariate_no_coeffs() {
        let mut rng = StdRng::seed_from_u64(0);
        let x = <BinaryField32b as Field>::random(&mut rng);
        assert_eq!(evaluate_univariate(&[], x), BinaryField32b::ZERO);
    }

    /// Extrapolating from evaluations on a finite-only domain reproduces the polynomial's value.
    #[test]
    fn test_random_extrapolate() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree + 1)
                .collect(),
            false,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();

        let x = <BinaryField32b as Field>::random(&mut rng);
        let expected_y = evaluate_univariate(&coeffs, x);
        assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y);
    }

    /// Interpolating evaluations on a finite-only domain recovers the coefficients.
    #[test]
    fn test_interpolation() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree + 1)
                .collect(),
            false,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();

        let interpolated = InterpolationDomain::from(domain)
            .interpolate(&values)
            .unwrap();
        assert_eq!(interpolated, coeffs);
    }

    /// With a point at infinity, the last value is the leading coefficient; both extrapolation
    /// and interpolation must honor it.
    #[test]
    fn test_infinity() {
        let mut rng = StdRng::seed_from_u64(0);
        let degree = 6;

        let domain = EvaluationDomain::from_points(
            repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
                .take(degree)
                .collect(),
            true,
        )
        .unwrap();

        let coeffs = repeat_with(|| <BinaryField32b as Field>::random(&mut rng))
            .take(degree + 1)
            .collect::<Vec<_>>();

        let mut values = domain
            .finite_points()
            .iter()
            .map(|&x| evaluate_univariate(&coeffs, x))
            .collect::<Vec<_>>();
        values.push(coeffs.last().copied().unwrap());

        let x = <BinaryField32b as Field>::random(&mut rng);
        let expected_y = evaluate_univariate(&coeffs, x);
        assert_eq!(domain.extrapolate(&values, x).unwrap(), expected_y);

        let interpolated = InterpolationDomain::from(domain)
            .interpolate(&values)
            .unwrap();
        assert_eq!(interpolated, coeffs);
    }

    proptest! {
        #[test]
        fn test_extrapolate_line(x0 in 0u32.., x1 in 0u32.., z in 0u8..) {
            let x0 = BinaryField32b::from(x0);
            let x1 = BinaryField32b::from(x1);
            let z = BinaryField8b::from(z);
            assert_eq!(extrapolate_line(x0, x1, z), x0 + (x1 - x0) * z);
            assert_eq!(extrapolate_line_scalar(x0, x1, z), x0 + (x1 - x0) * z);
        }

        #[test]
        fn test_lagrange_evals(values in vec(0u32.., 0..100), z in 0u32..) {
            let field_values = values.into_iter().map(BinaryField32b::from).collect::<Vec<_>>();
            let subspace = BinarySubspace::<BinaryField32b>::with_dim(8).unwrap();
            let domain_points = subspace.iter().take(field_values.len()).collect::<Vec<_>>();
            let evaluation_domain = EvaluationDomain::from_points(domain_points, false).unwrap();

            let z = BinaryField32b::new(z);

            let extrapolated = evaluation_domain.extrapolate(field_values.as_slice(), z).unwrap();
            let lagrange_coeffs = evaluation_domain.lagrange_evals(z);
            let lagrange_eval = inner_product_unchecked(lagrange_coeffs.into_iter(), field_values.into_iter());
            assert_eq!(lagrange_eval, extrapolated);
        }
    }

    /// Weights for three points match the hand-written inverse difference products.
    #[test]
    fn test_barycentric_weights_simple() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];
        let weights = compute_barycentric_weights(&points).unwrap();

        let w1 = ((p1 - p2) * (p1 - p3)).invert().unwrap();
        let w2 = ((p2 - p1) * (p2 - p3)).invert().unwrap();
        let w3 = ((p3 - p1) * (p3 - p2)).invert().unwrap();

        assert_eq!(weights, vec![w1, w2, w3]);
    }

    /// Weights for four points match the hand-written inverse difference products.
    #[test]
    fn test_barycentric_weights_four_points() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);
        let p4 = BinaryField32b::from(4);

        let points = vec![p1, p2, p3, p4];

        let weights = compute_barycentric_weights(&points).unwrap();

        let w1 = ((p1 - p2) * (p1 - p3) * (p1 - p4)).invert().unwrap();
        let w2 = ((p2 - p1) * (p2 - p3) * (p2 - p4)).invert().unwrap();
        let w3 = ((p3 - p1) * (p3 - p2) * (p3 - p4)).invert().unwrap();
        let w4 = ((p4 - p1) * (p4 - p2) * (p4 - p3)).invert().unwrap();

        assert_eq!(weights, vec![w1, w2, w3, w4]);
    }

    /// A single point has weight one (the empty product inverted).
    #[test]
    fn test_barycentric_weights_single_point() {
        let p1 = BinaryField32b::from(5);

        let points = vec![p1];
        let result = compute_barycentric_weights(&points).unwrap();

        assert_equal(result, vec![BinaryField32b::from(1)]);
    }

    /// Two equal points make a difference product zero, which must surface as
    /// `Error::DuplicateDomainPoint` specifically, not just as any error.
    #[test]
    fn test_barycentric_weights_duplicate_points() {
        let p1 = BinaryField32b::from(7);
        let p2 = BinaryField32b::from(7);
        let points = vec![p1, p2];
        let result = compute_barycentric_weights(&points);

        assert_matches!(result, Err(Error::DuplicateDomainPoint));
    }

    /// Without infinity the matrix rows are `[1, p, p^2]` for each point `p`.
    #[test]
    fn test_vandermonde_basic() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];

        let matrix = vandermonde(&points, false);

        let expected = Matrix::new(
            3,
            3,
            &[
                BinaryField32b::from(1),
                p1,
                p1.pow(2),
                BinaryField32b::from(1),
                p2,
                p2.pow(2),
                BinaryField32b::from(1),
                p3,
                p3.pow(2),
            ],
        )
        .unwrap();

        assert_eq!(matrix, expected);
    }

    /// With infinity the matrix gains a cubic column and a final row `[0, 0, 0, 1]`.
    #[test]
    fn test_vandermonde_with_infinity() {
        let p1 = BinaryField32b::from(1);
        let p2 = BinaryField32b::from(2);
        let p3 = BinaryField32b::from(3);

        let points = vec![p1, p2, p3];
        let matrix = vandermonde(&points, true);

        let expected = Matrix::new(
            4,
            4,
            &[
                BinaryField32b::from(1),
                p1,
                p1.pow(2),
                p1.pow(3),
                BinaryField32b::from(1),
                p2,
                p2.pow(2),
                p2.pow(3),
                BinaryField32b::from(1),
                p3,
                p3.pow(2),
                p3.pow(3),
                BinaryField32b::from(0),
                BinaryField32b::from(0),
                BinaryField32b::from(0),
                BinaryField32b::from(1),
            ],
        )
        .unwrap();

        assert_eq!(matrix, expected);
    }
}