binius_core/transparent/
eq_ind.rs1use binius_field::{Field, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::MultilinearExtension;
6use binius_utils::bail;
7
8use crate::polynomial::{Error, MultivariatePoly};
9
10#[derive(Debug, Clone)]
30pub struct EqIndPartialEval<F: Field> {
31 r: Vec<F>,
32}
33
34impl<F: Field> EqIndPartialEval<F> {
35 pub fn new(r: impl Into<Vec<F>>) -> Self {
36 Self { r: r.into() }
37 }
38
39 pub fn n_vars(&self) -> usize {
40 self.r.len()
41 }
42
43 pub fn multilinear_extension<P: PackedField<Scalar = F>, Backend: ComputationBackend>(
44 &self,
45 backend: &Backend,
46 ) -> Result<MultilinearExtension<P, Backend::Vec<P>>, Error> {
47 let multilin_query = backend.tensor_product_full_query(&self.r)?;
48 Ok(MultilinearExtension::from_values_generic(multilin_query)?)
49 }
50}
51
52impl<F: TowerField, P: PackedField<Scalar = F>> MultivariatePoly<P> for EqIndPartialEval<F> {
53 fn n_vars(&self) -> usize {
54 self.r.len()
55 }
56
57 fn degree(&self) -> usize {
58 self.r.len()
59 }
60
61 fn evaluate(&self, query: &[P]) -> Result<P, Error> {
62 let n_vars = MultivariatePoly::<P>::n_vars(self);
63 if query.len() != n_vars {
64 bail!(Error::IncorrectQuerySize { expected: n_vars });
65 }
66
67 let mut result = P::one();
68 for (&q_i, &r_i) in query.iter().zip(self.r.iter()) {
69 let term_one = q_i * r_i;
70 let term_two = (P::one() - q_i) * (P::one() - r_i);
71 let factor = term_one + term_two;
72 result *= factor;
73 }
74 Ok(result)
75 }
76
77 fn binary_tower_level(&self) -> usize {
78 F::TOWER_LEVEL
79 }
80}
81
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, PackedBinaryField4x32b, PackedField};
    use binius_hal::{make_portable_backend, ComputationBackendExt};
    use rand::{rngs::StdRng, SeedableRng};

    use super::EqIndPartialEval;
    use crate::polynomial::{Error, MultivariatePoly};

    type F = BinaryField32b;
    type P = PackedBinaryField4x32b;

    /// Samples `n` pseudo-random field elements from `rng`.
    fn random_point(rng: &mut StdRng, n: usize) -> Vec<F> {
        repeat_with(|| F::random(&mut *rng)).take(n).collect()
    }

    /// Checks that direct evaluation agrees with evaluating the materialized MLE.
    fn test_eq_consistency_help(n_vars: usize) {
        let mut rng = StdRng::seed_from_u64(0);
        let r = random_point(&mut rng, n_vars);
        let eval_point = random_point(&mut rng, n_vars);
        let backend = make_portable_backend();

        let eq_r_mvp = EqIndPartialEval::new(r);
        let eval_mvp = eq_r_mvp.evaluate(&eval_point).unwrap();

        let eq_r_mle = eq_r_mvp.multilinear_extension::<P, _>(&backend).unwrap();
        let multilin_query = backend.multilinear_query::<P>(&eval_point).unwrap();
        let eval_mle = eq_r_mle.evaluate(&multilin_query).unwrap();

        assert_eq!(eval_mle, eval_mvp);
    }

    #[test]
    fn test_eq_consistency_schwartz_zippel() {
        for n_vars in 2..=10 {
            test_eq_consistency_help(n_vars);
        }
    }

    #[test]
    fn test_evaluate_rejects_wrong_query_size() {
        let n_vars = 3;
        let mut rng = StdRng::seed_from_u64(0);
        let eq = EqIndPartialEval::new(random_point(&mut rng, n_vars));

        for len in [n_vars - 1, n_vars + 1] {
            let query = random_point(&mut rng, len);
            let result = MultivariatePoly::<F>::evaluate(&eq, &query);
            assert!(matches!(
                result,
                Err(Error::IncorrectQuerySize { expected, .. }) if expected == n_vars
            ));
        }
    }

    #[test]
    fn test_evaluate_zero_vars_is_one() {
        // The empty product must evaluate to the multiplicative identity.
        let eq = EqIndPartialEval::<F>::new(Vec::new());
        assert_eq!(MultivariatePoly::<F>::n_vars(&eq), 0);
        assert_eq!(MultivariatePoly::<F>::evaluate(&eq, &[]).unwrap(), F::one());
    }

    #[test]
    fn test_evaluate_on_hypercube_is_indicator() {
        let n_vars = 3;
        // Maps the low `n_vars` bits of `bits` to a Boolean point in F^n_vars.
        let to_point = |bits: usize| -> Vec<F> {
            (0..n_vars)
                .map(|i| if (bits >> i) & 1 == 1 { F::one() } else { F::zero() })
                .collect()
        };

        for r_bits in 0..(1 << n_vars) {
            let eq = EqIndPartialEval::new(to_point(r_bits));
            for x_bits in 0..(1 << n_vars) {
                let expected = if r_bits == x_bits { F::one() } else { F::zero() };
                let actual = MultivariatePoly::<F>::evaluate(&eq, &to_point(x_bits)).unwrap();
                assert_eq!(actual, expected);
            }
        }
    }

    #[test]
    fn test_evaluate_is_symmetric() {
        // Each factor is symmetric in (q_i, r_i), so eq(r, q) == eq(q, r).
        let mut rng = StdRng::seed_from_u64(1);
        for n_vars in 0..=6 {
            let r = random_point(&mut rng, n_vars);
            let q = random_point(&mut rng, n_vars);
            let eq_r = EqIndPartialEval::new(r.clone());
            let eq_q = EqIndPartialEval::new(q.clone());
            assert_eq!(
                MultivariatePoly::<F>::evaluate(&eq_r, &q).unwrap(),
                MultivariatePoly::<F>::evaluate(&eq_q, &r).unwrap(),
            );
        }
    }
}