// binius_math/tensor_prod_eq_ind.rs

// Copyright 2024-2025 Irreducible Inc.
2
3use std::cmp::max;
4
5use binius_field::{Field, PackedField};
6use binius_maybe_rayon::prelude::*;
7use binius_utils::bail;
8use bytemuck::zeroed_vec;
9
10use crate::Error;
11
12/// Tensor Product expansion of values with partial eq indicator evaluated at extra_query_coordinates
13///
14/// Let $n$ be log_n_values, $p$, $k$ be the lengths of `packed_values` and `extra_query_coordinates`.
15/// Requires
16///     * $n \geq k$
17///     * p = max(1, 2^{n+k} / P::WIDTH)
18/// Let $v$ be a vector corresponding to the first $2^n$ scalar values of `values`.
19/// Let $r = (r_0, \ldots, r_{k-1})$ be the vector of `extra_query_coordinates`.
20///
21/// # Formal Definition
22/// `values` is updated to contain the result of:
23/// $v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})$
24/// which is now a vector of length $2^{n+k}$. If 2^{n+k} < P::WIDTH, then
25/// the result is packed into a single element of `values` where only the first
26/// 2^{n+k} elements have meaning.
27///
28/// # Interpretation
29/// Let $f$ be an $n$ variate multilinear polynomial that has evaluations over
30/// the $n$ dimensional hypercube corresponding to $v$.
31/// Then `values` is updated to contain the evaluations of $g$ over the $n+k$-dimensional
32/// hypercube where
33/// * $g(x_0, \ldots, x_{n+k-1}) = f(x_0, \ldots, x_{n-1}) * eq(x_n, \ldots, x_{n+k-1}, r)$
34pub fn tensor_prod_eq_ind<P: PackedField>(
35	log_n_values: usize,
36	packed_values: &mut [P],
37	extra_query_coordinates: &[P::Scalar],
38) -> Result<(), Error> {
39	let new_n_vars = log_n_values + extra_query_coordinates.len();
40	if packed_values.len() != max(1, (1 << new_n_vars) / P::WIDTH) {
41		bail!(Error::InvalidPackedValuesLength);
42	}
43
44	for (i, r_i) in extra_query_coordinates.iter().enumerate() {
45		let prev_length = 1 << (log_n_values + i);
46		if prev_length < P::WIDTH {
47			let q = &mut packed_values[0];
48			for h in 0..prev_length {
49				let x = q.get(h);
50				let prod = x * r_i;
51				q.set(h, x - prod);
52				q.set(prev_length | h, prod);
53			}
54		} else {
55			let prev_packed_length = prev_length / P::WIDTH;
56			let packed_r_i = P::broadcast(*r_i);
57			let (xs, ys) = packed_values.split_at_mut(prev_packed_length);
58			assert!(xs.len() <= ys.len());
59
60			// These magic numbers were chosen experimentally to have a reasonable performance
61			// for the calls with small number of elements.
62			xs.par_iter_mut()
63				.zip(ys.par_iter_mut())
64				.with_min_len(64)
65				.for_each(|(x, y)| {
66					// x = x * (1 - packed_r_i) = x - x * packed_r_i
67					// y = x * packed_r_i
68					// Notice that we can reuse the multiplication: (x * packed_r_i)
69					let prod = (*x) * packed_r_i;
70					*x -= prod;
71					*y = prod;
72				});
73		}
74	}
75	Ok(())
76}
77
78/// Computes the partial evaluation of the equality indicator polynomial.
79///
80/// Given an $n$-coordinate point $r_0, ..., r_n$, this computes the partial evaluation of the
81/// equality indicator polynomial $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ and
82/// returns its values over the $n$-dimensional hypercube.
83///
84/// The returned values are equal to the tensor product
85///
86/// $$
87/// (1 - r_0, r_0) \otimes ... \otimes (1 - r_{n-1}, r_{n-1}).
88/// $$
89///
90/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
91///
92/// [DP23]: <https://eprint.iacr.org/2023/1784>
93pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> Vec<P> {
94	let n = point.len();
95	let len = 1 << n.saturating_sub(P::LOG_WIDTH);
96	let mut buffer = zeroed_vec::<P>(len);
97	buffer[0].set(0, P::Scalar::ONE);
98	tensor_prod_eq_ind(0, &mut buffer, point).expect("buffer is allocated with the correct length");
99	buffer
100}
101
#[cfg(test)]
mod tests {
	use binius_field::{packed::set_packed_slice, Field, PackedBinaryField4x32b};
	use itertools::Itertools;

	use super::*;

	type P = PackedBinaryField4x32b;
	type F = <P as PackedField>::Scalar;

	#[test]
	fn test_tensor_prod_eq_ind() {
		let v0 = F::new(1);
		let v1 = F::new(2);
		let query = vec![v0, v1];
		// `saturating_sub` keeps the length valid even when the query is shorter than the
		// packing width (a plain subtraction would underflow).
		let mut result = vec![P::default(); 1 << query.len().saturating_sub(P::LOG_WIDTH)];
		set_packed_slice(&mut result, 0, F::ONE);
		tensor_prod_eq_ind(0, &mut result, &query).unwrap();
		let result = PackedField::iter_slice(&result).collect_vec();
		assert_eq!(
			result,
			vec![
				(F::ONE - v0) * (F::ONE - v1),
				v0 * (F::ONE - v1),
				(F::ONE - v0) * v1,
				v0 * v1
			]
		);
	}

	#[test]
	fn test_tensor_prod_eq_ind_nonzero_log_n_values_scalar_branch() {
		// Expand [a, b] (log_n_values = 1) by one coordinate; all 4 results fit in one element.
		let a = F::new(7);
		let b = F::new(11);
		let r = F::new(3);
		let mut result = vec![P::default(); 1];
		set_packed_slice(&mut result, 0, a);
		set_packed_slice(&mut result, 1, b);
		tensor_prod_eq_ind(1, &mut result, &[r]).unwrap();
		let result = PackedField::iter_slice(&result).collect_vec();
		assert_eq!(result, vec![a * (F::ONE - r), b * (F::ONE - r), a * r, b * r]);
	}

	#[test]
	fn test_tensor_prod_eq_ind_nonzero_log_n_values_packed_branch() {
		// log_n_values = 2 fills one packed element, so every coordinate takes the packed path.
		let values = [F::new(1), F::new(2), F::new(3), F::new(4)];
		let r0 = F::new(5);
		let r1 = F::new(6);
		let mut result = vec![P::default(); 4];
		for (i, &v) in values.iter().enumerate() {
			set_packed_slice(&mut result, i, v);
		}
		tensor_prod_eq_ind(2, &mut result, &[r0, r1]).unwrap();
		let result = PackedField::iter_slice(&result).collect_vec();

		let eq = [
			(F::ONE - r0) * (F::ONE - r1),
			r0 * (F::ONE - r1),
			(F::ONE - r0) * r1,
			r0 * r1,
		];
		let mut expected = Vec::with_capacity(16);
		for &e in &eq {
			for &v in &values {
				expected.push(v * e);
			}
		}
		assert_eq!(result, expected);
	}

	#[test]
	fn test_tensor_prod_eq_ind_rejects_wrong_length() {
		// One coordinate over a single seed value needs exactly one packed element, not two.
		let mut packed_values = vec![P::default(); 2];
		let res = tensor_prod_eq_ind(0, &mut packed_values, &[F::new(2)]);
		assert!(matches!(res, Err(Error::InvalidPackedValuesLength)));
	}

	#[test]
	fn test_eq_ind_partial_eval_empty() {
		let result = eq_ind_partial_eval::<P>(&[]);
		let expected = vec![P::set_single(F::ONE)];
		assert_eq!(result, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_single_var() {
		// Only one query coordinate
		let r0 = F::new(2);
		let result = eq_ind_partial_eval::<P>(&[r0]);
		let expected = vec![(F::ONE - r0), r0, F::ZERO, F::ZERO];
		let result = PackedField::iter_slice(&result).collect_vec();
		assert_eq!(result, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_two_vars() {
		// Two query coordinates
		let r0 = F::new(2);
		let r1 = F::new(3);
		let result = eq_ind_partial_eval::<P>(&[r0, r1]);
		let result = PackedField::iter_slice(&result).collect_vec();
		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1),
			r0 * (F::ONE - r1),
			(F::ONE - r0) * r1,
			r0 * r1,
		];
		assert_eq!(result, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_three_vars() {
		// Case with three query coordinates
		let r0 = F::new(2);
		let r1 = F::new(3);
		let r2 = F::new(5);
		let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
		let result = PackedField::iter_slice(&result).collect_vec();

		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
			r0 * (F::ONE - r1) * (F::ONE - r2),
			(F::ONE - r0) * r1 * (F::ONE - r2),
			r0 * r1 * (F::ONE - r2),
			(F::ONE - r0) * (F::ONE - r1) * r2,
			r0 * (F::ONE - r1) * r2,
			(F::ONE - r0) * r1 * r2,
			r0 * r1 * r2,
		];
		assert_eq!(result, expected);
	}
}