binius_core/transparent/
eq_ind.rs1use binius_field::{Field, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::MultilinearExtension;
6use binius_utils::bail;
7
8use crate::polynomial::{Error, MultivariatePoly};
9
10#[derive(Debug, Clone)]
30pub struct EqIndPartialEval<F: Field> {
31 r: Vec<F>,
32}
33
34impl<F: Field> EqIndPartialEval<F> {
35 pub fn new(r: impl Into<Vec<F>>) -> Self {
36 Self { r: r.into() }
37 }
38
39 pub fn n_vars(&self) -> usize {
40 self.r.len()
41 }
42
43 pub fn multilinear_extension<P: PackedField<Scalar = F>, Backend: ComputationBackend>(
44 &self,
45 backend: &Backend,
46 ) -> Result<MultilinearExtension<P, Backend::Vec<P>>, Error> {
47 let multilin_query = backend.tensor_product_full_query(&self.r)?;
48 Ok(MultilinearExtension::new(self.n_vars(), multilin_query)?)
49 }
50}
51
52impl<F: TowerField, P: PackedField<Scalar = F>> MultivariatePoly<P> for EqIndPartialEval<F> {
53 fn n_vars(&self) -> usize {
54 self.r.len()
55 }
56
57 fn degree(&self) -> usize {
58 self.r.len()
59 }
60
61 fn evaluate(&self, query: &[P]) -> Result<P, Error> {
62 let n_vars = MultivariatePoly::<P>::n_vars(self);
63 if query.len() != n_vars {
64 bail!(Error::IncorrectQuerySize {
65 expected: n_vars,
66 actual: query.len()
67 });
68 }
69
70 let mut result = P::one();
71 for (&q_i, &r_i) in query.iter().zip(self.r.iter()) {
72 let term_one = q_i * r_i;
73 let term_two = (P::one() - q_i) * (P::one() - r_i);
74 let factor = term_one + term_two;
75 result *= factor;
76 }
77 Ok(result)
78 }
79
80 fn binary_tower_level(&self) -> usize {
81 F::TOWER_LEVEL
82 }
83}
84
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_field::{BinaryField32b, PackedBinaryField4x32b, PackedField};
    use binius_hal::{ComputationBackendExt, make_portable_backend};
    use rand::{SeedableRng, rngs::StdRng};

    use super::EqIndPartialEval;
    use crate::polynomial::MultivariatePoly;

    /// Checks that evaluating the polynomial directly agrees with evaluating its
    /// materialized multilinear extension at the same random point.
    fn test_eq_consistency_help(n_vars: usize) {
        type F = BinaryField32b;
        type P = PackedBinaryField4x32b;

        let mut rng = StdRng::seed_from_u64(0);
        // Draws `len` random scalars. The fixed point is sampled first, then the
        // evaluation point.
        let mut sample_point = |len: usize| -> Vec<F> {
            repeat_with(|| F::random(&mut rng)).take(len).collect()
        };
        let r = sample_point(n_vars);
        let eval_point = sample_point(n_vars);
        let backend = make_portable_backend();

        let eq_r = EqIndPartialEval::new(r);
        let direct_eval = eq_r.evaluate(&eval_point).unwrap();

        let eq_mle = eq_r.multilinear_extension::<P, _>(&backend).unwrap();
        let query = backend.multilinear_query::<P>(&eval_point).unwrap();
        let mle_eval = eq_mle.evaluate(&query).unwrap();

        assert_eq!(mle_eval, direct_eval);
    }

    #[test]
    fn test_eq_consistency_schwartz_zippel() {
        (2..=10).for_each(test_eq_consistency_help);
    }
}