binius_math/
tensor_prod_eq_ind.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::cmp::max;
4
5use binius_field::{Field, PackedField};
6use binius_maybe_rayon::prelude::*;
7use binius_utils::bail;
8use bytemuck::zeroed_vec;
9
10use crate::Error;
11
12/// Tensor Product expansion of values with partial eq indicator evaluated at
13/// extra_query_coordinates
14///
15/// Let $n$ be log_n_values, $p$, $k$ be the lengths of `packed_values` and
16/// `extra_query_coordinates`. Requires
17///     * $n \geq k$
18///     * p = max(1, 2^{n+k} / P::WIDTH)
19/// Let $v$ be a vector corresponding to the first $2^n$ scalar values of `values`.
20/// Let $r = (r_0, \ldots, r_{k-1})$ be the vector of `extra_query_coordinates`.
21///
22/// # Formal Definition
23/// `values` is updated to contain the result of:
24/// $v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})$
25/// which is now a vector of length $2^{n+k}$. If 2^{n+k} < P::WIDTH, then
26/// the result is packed into a single element of `values` where only the first
27/// 2^{n+k} elements have meaning.
28///
29/// # Interpretation
30/// Let $f$ be an $n$ variate multilinear polynomial that has evaluations over
31/// the $n$ dimensional hypercube corresponding to $v$.
32/// Then `values` is updated to contain the evaluations of $g$ over the $n+k$-dimensional
33/// hypercube where
34/// * $g(x_0, \ldots, x_{n+k-1}) = f(x_0, \ldots, x_{n-1}) * eq(x_n, \ldots, x_{n+k-1}, r)$
35pub fn tensor_prod_eq_ind<P: PackedField>(
36	log_n_values: usize,
37	packed_values: &mut [P],
38	extra_query_coordinates: &[P::Scalar],
39) -> Result<(), Error> {
40	let new_n_vars = log_n_values + extra_query_coordinates.len();
41	if packed_values.len() != max(1, (1 << new_n_vars) / P::WIDTH) {
42		bail!(Error::InvalidPackedValuesLength);
43	}
44
45	for (i, r_i) in extra_query_coordinates.iter().enumerate() {
46		let prev_length = 1 << (log_n_values + i);
47		if prev_length < P::WIDTH {
48			let q = &mut packed_values[0];
49			for h in 0..prev_length {
50				let x = q.get(h);
51				let prod = x * r_i;
52				q.set(h, x - prod);
53				q.set(prev_length | h, prod);
54			}
55		} else {
56			let prev_packed_length = prev_length / P::WIDTH;
57			let packed_r_i = P::broadcast(*r_i);
58			let (xs, ys) = packed_values.split_at_mut(prev_packed_length);
59			assert!(xs.len() <= ys.len());
60
61			// These magic numbers were chosen experimentally to have a reasonable performance
62			// for the calls with small number of elements.
63			xs.par_iter_mut()
64				.zip(ys.par_iter_mut())
65				.with_min_len(64)
66				.for_each(|(x, y)| {
67					// x = x * (1 - packed_r_i) = x - x * packed_r_i
68					// y = x * packed_r_i
69					// Notice that we can reuse the multiplication: (x * packed_r_i)
70					let prod = (*x) * packed_r_i;
71					*x -= prod;
72					*y = prod;
73				});
74		}
75	}
76	Ok(())
77}
78
79/// Computes the partial evaluation of the equality indicator polynomial.
80///
81/// Given an $n$-coordinate point $r_0, ..., r_n$, this computes the partial evaluation of the
82/// equality indicator polynomial $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ and
83/// returns its values over the $n$-dimensional hypercube.
84///
85/// The returned values are equal to the tensor product
86///
87/// $$
88/// (1 - r_0, r_0) \otimes ... \otimes (1 - r_{n-1}, r_{n-1}).
89/// $$
90///
91/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
92///
93/// [DP23]: <https://eprint.iacr.org/2023/1784>
94pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> Vec<P> {
95	let n = point.len();
96	let len = 1 << n.saturating_sub(P::LOG_WIDTH);
97	let mut buffer = zeroed_vec::<P>(len);
98	buffer[0].set(0, P::Scalar::ONE);
99	tensor_prod_eq_ind(0, &mut buffer, point).expect("buffer is allocated with the correct length");
100	buffer
101}
102
#[cfg(test)]
mod tests {
	use binius_field::{Field, PackedBinaryField4x32b, packed::set_packed_slice};
	use itertools::Itertools;

	use super::*;

	type P = PackedBinaryField4x32b;
	type F = <P as PackedField>::Scalar;

	/// Scalar reference implementation of the eq indicator expansion: entry `i` is the
	/// product over `j` of `r_j` if bit `j` of `i` is set, and `1 - r_j` otherwise.
	fn naive_eq_ind(point: &[F]) -> Vec<F> {
		(0..1usize << point.len())
			.map(|i| {
				point.iter().enumerate().fold(F::ONE, |acc, (j, &r_j)| {
					if (i >> j) & 1 == 1 {
						acc * r_j
					} else {
						acc * (F::ONE - r_j)
					}
				})
			})
			.collect()
	}

	/// Deterministic, nonzero test scalars `offset, offset + 3, offset + 6, ...`.
	fn test_scalars(n: usize, offset: u32) -> Vec<F> {
		(0..n as u32).map(|i| F::new(offset + 3 * i)).collect()
	}

	#[test]
	fn test_tensor_prod_eq_ind() {
		let v0 = F::new(1);
		let v1 = F::new(2);
		let query = vec![v0, v1];
		let mut result = vec![P::default(); 1 << (query.len() - P::LOG_WIDTH)];
		set_packed_slice(&mut result, 0, F::ONE);
		tensor_prod_eq_ind(0, &mut result, &query).unwrap();
		let result = PackedField::iter_slice(&result).collect_vec();
		assert_eq!(
			result,
			vec![
				(F::ONE - v0) * (F::ONE - v1),
				v0 * (F::ONE - v1),
				(F::ONE - v0) * v1,
				v0 * v1
			]
		);
	}

	#[test]
	fn test_tensor_prod_eq_ind_with_nonzero_log_n_values() {
		// Start from 2^3 = 8 scalars (two packed elements) and expand by two coordinates,
		// which runs the packed (parallel) branch for every coordinate.
		let log_n_values = 3;
		let values = test_scalars(1 << log_n_values, 7);
		let query = test_scalars(2, 100);

		let n_vars = log_n_values + query.len();
		let mut packed = vec![P::default(); 1 << (n_vars - P::LOG_WIDTH)];
		for (i, &v) in values.iter().enumerate() {
			set_packed_slice(&mut packed, i, v);
		}
		tensor_prod_eq_ind(log_n_values, &mut packed, &query).unwrap();

		// The low `log_n_values` bits of the index select the value, the high bits the eq term.
		let eq = naive_eq_ind(&query);
		let expected = (0..1usize << n_vars)
			.map(|i| values[i % values.len()] * eq[i >> log_n_values])
			.collect_vec();
		assert_eq!(PackedField::iter_slice(&packed).collect_vec(), expected);
	}

	#[test]
	fn test_tensor_prod_eq_ind_rejects_wrong_length() {
		// One coordinate on top of a single value needs exactly one packed element.
		let mut too_long = vec![P::default(); 2];
		assert!(tensor_prod_eq_ind(0, &mut too_long, &[F::ONE]).is_err());
	}

	#[test]
	fn test_eq_ind_partial_eval_empty() {
		let result = eq_ind_partial_eval::<P>(&[]);
		let expected = vec![P::set_single(F::ONE)];
		assert_eq!(result, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_single_var() {
		// Only one query coordinate
		let r0 = F::new(2);
		let result = eq_ind_partial_eval::<P>(&[r0]);
		let expected = vec![(F::ONE - r0), r0, F::ZERO, F::ZERO];
		let result = PackedField::iter_slice(&result).collect_vec();
		assert_eq!(result, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_two_vars() {
		// Two query coordinates
		let r0 = F::new(2);
		let r1 = F::new(3);
		let result = eq_ind_partial_eval::<P>(&[r0, r1]);
		let result = PackedField::iter_slice(&result).collect_vec();
		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1),
			r0 * (F::ONE - r1),
			(F::ONE - r0) * r1,
			r0 * r1,
		];
		assert_eq!(result, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_three_vars() {
		// Case with three query coordinates
		let r0 = F::new(2);
		let r1 = F::new(3);
		let r2 = F::new(5);
		let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
		let result = PackedField::iter_slice(&result).collect_vec();

		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
			r0 * (F::ONE - r1) * (F::ONE - r2),
			(F::ONE - r0) * r1 * (F::ONE - r2),
			r0 * r1 * (F::ONE - r2),
			(F::ONE - r0) * (F::ONE - r1) * r2,
			r0 * (F::ONE - r1) * r2,
			(F::ONE - r0) * r1 * r2,
			r0 * r1 * r2,
		];
		assert_eq!(result, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_matches_naive() {
		// Sizes above P::LOG_WIDTH exercise the packed (parallel) branch of
		// `tensor_prod_eq_ind`, which the small hand-written cases above never reach.
		for n_vars in 0..=6 {
			let point = test_scalars(n_vars, 2);
			let result = eq_ind_partial_eval::<P>(&point);
			assert_eq!(result.len(), 1usize << n_vars.saturating_sub(P::LOG_WIDTH));

			let scalars = PackedField::iter_slice(&result).collect_vec();
			let n_scalars = 1usize << n_vars;
			assert_eq!(&scalars[..n_scalars], naive_eq_ind(&point).as_slice());
			// Unused scalars of a single packed element keep their zero initialization.
			assert!(scalars[n_scalars..].iter().all(|&x| x == F::ZERO));
		}
	}
}