binius_core/transparent/
eq_ind.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::MultilinearExtension;
6use binius_utils::bail;
7
8use crate::polynomial::{Error, MultivariatePoly};
9
10/// Represents $\text{eq}(X, r)$, the partial evaluation of the
11/// [equality indicator polynomial](https://www.binius.xyz/blueprint/background/multilinears#the-equality-indicator-polynomial)
12/// at a point $r$.
13///
14/// The $2 \mu$-variate multilinear polynomial $\text{eq}(X, Y)$ is defined as the multilinear
15/// extension of the map
16///
17/// $$
18/// (x, y) \mapsto \begin{cases}
19///   1 &\text{if } x = y \\\\
20///   0 &\text{if } x \ne y
21/// \end{cases}.
22/// $$
23///
24/// The polynomial can be efficiency computed with the following explicit formulation:
25///
26/// $$
27/// \text{eq}(X, Y) = \prod_{i=0}^{\mu - 1} \left(X_i Y_i + (1 - X_i)(1 - Y_i)\right).
28/// $$
29#[derive(Debug, Clone)]
30pub struct EqIndPartialEval<F: Field> {
31	r: Vec<F>,
32}
33
34impl<F: Field> EqIndPartialEval<F> {
35	pub fn new(r: impl Into<Vec<F>>) -> Self {
36		Self { r: r.into() }
37	}
38
39	pub fn n_vars(&self) -> usize {
40		self.r.len()
41	}
42
43	pub fn multilinear_extension<P: PackedField<Scalar = F>, Backend: ComputationBackend>(
44		&self,
45		backend: &Backend,
46	) -> Result<MultilinearExtension<P, Backend::Vec<P>>, Error> {
47		let multilin_query = backend.tensor_product_full_query(&self.r)?;
48		Ok(MultilinearExtension::new(self.n_vars(), multilin_query)?)
49	}
50}
51
52impl<F: TowerField, P: PackedField<Scalar = F>> MultivariatePoly<P> for EqIndPartialEval<F> {
53	fn n_vars(&self) -> usize {
54		self.r.len()
55	}
56
57	fn degree(&self) -> usize {
58		self.r.len()
59	}
60
61	fn evaluate(&self, query: &[P]) -> Result<P, Error> {
62		let n_vars = MultivariatePoly::<P>::n_vars(self);
63		if query.len() != n_vars {
64			bail!(Error::IncorrectQuerySize {
65				expected: n_vars,
66				actual: query.len()
67			});
68		}
69
70		let mut result = P::one();
71		for (&q_i, &r_i) in query.iter().zip(self.r.iter()) {
72			let term_one = q_i * r_i;
73			let term_two = (P::one() - q_i) * (P::one() - r_i);
74			let factor = term_one + term_two;
75			result *= factor;
76		}
77		Ok(result)
78	}
79
80	fn binary_tower_level(&self) -> usize {
81		F::TOWER_LEVEL
82	}
83}
84
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{BinaryField32b, PackedBinaryField4x32b, PackedField};
	use binius_hal::{ComputationBackendExt, make_portable_backend};
	use rand::{SeedableRng, rngs::StdRng};

	use super::EqIndPartialEval;
	use crate::polynomial::MultivariatePoly;

	/// Checks that evaluating `eq(X, r)` directly agrees with evaluating its multilinear
	/// extension, for random `r` and a random evaluation point with `n_vars` coordinates each.
	fn test_eq_consistency_help(n_vars: usize) {
		type F = BinaryField32b;
		type P = PackedBinaryField4x32b;

		// Fixed seed keeps the test deterministic; `r` is drawn before the evaluation point.
		let mut rng = StdRng::seed_from_u64(0);
		let mut sample_point = || -> Vec<F> {
			repeat_with(|| F::random(&mut rng))
				.take(n_vars)
				.collect()
		};
		let r = sample_point();
		let eval_point = sample_point();
		let backend = make_portable_backend();

		// Evaluate through the `MultivariatePoly` implementation.
		let eq_r = EqIndPartialEval::new(r);
		let direct = eq_r.evaluate(&eval_point).unwrap();

		// Evaluate through the `MultilinearExtension` built by the backend.
		let mle = eq_r.multilinear_extension::<P, _>(&backend).unwrap();
		let query = backend.multilinear_query::<P>(&eval_point).unwrap();
		let via_mle = mle.evaluate(&query).unwrap();

		// Both routes must produce the same value.
		assert_eq!(via_mle, direct);
	}

	#[test]
	fn test_eq_consistency_schwartz_zippel() {
		(2..=10).for_each(test_eq_consistency_help);
	}
}