binius_core/transparent/
eq_ind.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{Field, PackedField, TowerField};
4use binius_hal::ComputationBackend;
5use binius_math::MultilinearExtension;
6use binius_utils::bail;
7
8use crate::polynomial::{Error, MultivariatePoly};
9
10/// Represents $\text{eq}(X, r)$, the partial evaluation of the
11/// [equality indicator polynomial](https://www.binius.xyz/blueprint/background/multilinears#the-equality-indicator-polynomial)
12/// at a point $r$.
13///
14/// The $2 \mu$-variate multilinear polynomial $\text{eq}(X, Y)$ is defined as the multilinear
15/// extension of the map
16///
17/// $$
18/// (x, y) \mapsto \begin{cases}
19///   1 &\text{if } x = y \\\\
20///   0 &\text{if } x \ne y
21/// \end{cases}.
22/// $$
23///
24/// The polynomial can be efficiency computed with the following explicit formulation:
25///
26/// $$
27/// \text{eq}(X, Y) = \prod_{i=0}^{\mu - 1} \left(X_i Y_i + (1 - X_i)(1 - Y_i)\right).
28/// $$
29#[derive(Debug, Clone)]
30pub struct EqIndPartialEval<F: Field> {
31	r: Vec<F>,
32}
33
34impl<F: Field> EqIndPartialEval<F> {
35	pub fn new(r: impl Into<Vec<F>>) -> Self {
36		Self { r: r.into() }
37	}
38
39	pub fn n_vars(&self) -> usize {
40		self.r.len()
41	}
42
43	pub fn multilinear_extension<P: PackedField<Scalar = F>, Backend: ComputationBackend>(
44		&self,
45		backend: &Backend,
46	) -> Result<MultilinearExtension<P, Backend::Vec<P>>, Error> {
47		let multilin_query = backend.tensor_product_full_query(&self.r)?;
48		Ok(MultilinearExtension::from_values_generic(multilin_query)?)
49	}
50}
51
52impl<F: TowerField, P: PackedField<Scalar = F>> MultivariatePoly<P> for EqIndPartialEval<F> {
53	fn n_vars(&self) -> usize {
54		self.r.len()
55	}
56
57	fn degree(&self) -> usize {
58		self.r.len()
59	}
60
61	fn evaluate(&self, query: &[P]) -> Result<P, Error> {
62		let n_vars = MultivariatePoly::<P>::n_vars(self);
63		if query.len() != n_vars {
64			bail!(Error::IncorrectQuerySize { expected: n_vars });
65		}
66
67		let mut result = P::one();
68		for (&q_i, &r_i) in query.iter().zip(self.r.iter()) {
69			let term_one = q_i * r_i;
70			let term_two = (P::one() - q_i) * (P::one() - r_i);
71			let factor = term_one + term_two;
72			result *= factor;
73		}
74		Ok(result)
75	}
76
77	fn binary_tower_level(&self) -> usize {
78		F::TOWER_LEVEL
79	}
80}
81
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_field::{BinaryField32b, PackedBinaryField4x32b, PackedField};
	use binius_hal::{make_portable_backend, ComputationBackendExt};
	use rand::{rngs::StdRng, SeedableRng};

	use super::EqIndPartialEval;
	use crate::polynomial::MultivariatePoly;

	/// Checks that the closed-form evaluation and the materialized multilinear extension of
	/// `eq(X, r)` agree at one random point, for a random `r` with `n_vars` coordinates.
	fn test_eq_consistency_help(n_vars: usize) {
		type F = BinaryField32b;
		type P = PackedBinaryField4x32b;

		// Fixed seed keeps the test deterministic; `r` is drawn first, then the evaluation point.
		let mut rng = StdRng::seed_from_u64(0);
		let mut sample_point =
			|| -> Vec<F> { repeat_with(|| F::random(&mut rng)).take(n_vars).collect() };
		let r = sample_point();
		let eval_point = sample_point();
		let backend = make_portable_backend();

		// Evaluate through the `MultivariatePoly` implementation (product formula).
		let eq_poly = EqIndPartialEval::new(r);
		let direct_eval = eq_poly.evaluate(&eval_point).unwrap();

		// Evaluate through the materialized `MultilinearExtension`.
		let eq_mle = eq_poly.multilinear_extension::<P, _>(&backend).unwrap();
		let query = backend.multilinear_query::<P>(&eval_point).unwrap();
		let mle_eval = eq_mle.evaluate(&query).unwrap();

		// Both routes must produce the same field element.
		assert_eq!(mle_eval, direct_eval);
	}

	#[test]
	fn test_eq_consistency_schwartz_zippel() {
		(2..=10).for_each(test_eq_consistency_help);
	}
}