binius_math/multilinear/eq.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, ops::DerefMut};
4
5use binius_field::{
6	Field, PackedField,
7	packed::{get_packed_slice, set_packed_slice},
8};
9use binius_utils::rayon::prelude::*;
10
11use crate::{Error, FieldBuffer};
12
13/// Tensor of values with the eq indicator evaluated at extra_query_coordinates.
14///
15/// Let $n$ be log_n_values, $p$, $k$ be the lengths of `packed_values` and
16/// `extra_query_coordinates`. Requires
17///     * $n \geq k$
18///     * p = max(1, 2^{n+k} / P::WIDTH)
19/// Let $v$ be a vector corresponding to the first $2^n$ scalar values of `values`.
20/// Let $r = (r_0, \ldots, r_{k-1})$ be the vector of `extra_query_coordinates`.
21///
22/// # Precondition:
23/// * `values` must be zero-extended to the new log length before calling this function. This
24///   condition is necessary to get the best performance.
25///
26/// # Formal Definition
27/// `values` is updated to contain the result of:
28/// $v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})$
29/// which is now a vector of length $2^{n+k}$. If 2^{n+k} < P::WIDTH, then
30/// the result is packed into a single element of `values` where only the first
31/// 2^{n+k} elements have meaning.
32///
33/// # Interpretation
34/// Let $f$ be an $n$ variate multilinear polynomial that has evaluations over
35/// the $n$ dimensional hypercube corresponding to $v$.
36/// Then `values` is updated to contain the evaluations of $g$ over the $n+k$-dimensional
37/// hypercube where
38/// * $g(x_0, \ldots, x_{n+k-1}) = f(x_0, \ldots, x_{n-1}) * eq(x_n, \ldots, x_{n+k-1}, r)$
39fn tensor_prod_eq_ind<P: PackedField, Data: DerefMut<Target = [P]>>(
40	values: &mut FieldBuffer<P, Data>,
41	extra_query_coordinates: &[P::Scalar],
42) -> Result<(), Error> {
43	let new_log_len = values.log_len() + extra_query_coordinates.len();
44
45	if values.log_cap() < new_log_len {
46		return Err(Error::IncorrectArgumentLength {
47			arg: "values.log_cap()".to_string(),
48			expected: new_log_len,
49		});
50	}
51
52	for &r_i in extra_query_coordinates {
53		let packed_r_i = P::broadcast(r_i);
54
55		values.resize(values.log_len() + 1)?;
56		let mut split = values
57			.split_half_mut_no_closure()
58			.expect("doubled by zero_extend()");
59		let (mut lo, mut hi) = split.halves();
60
61		(lo.as_mut(), hi.as_mut())
62			.into_par_iter()
63			.for_each(|(lo_i, hi_i)| {
64				let prod = (*lo_i) * packed_r_i;
65				*lo_i -= prod;
66				*hi_i = prod;
67			});
68	}
69
70	Ok(())
71}
72
73/// Left tensor of values with the eq indicator evaluated at extra_query_coordinates.
74///
75/// # Formal definition
76/// This differs from `tensor_prod_eq_ind` in tensor product being applied on the left
77/// and in reversed order:
78/// $(1 - r_{k-1}, r_{k-1}) \otimes \ldots \otimes (1 - r_0, r_0) \otimes v$
79///
80/// # Implementation
81/// This operation is inplace, singlethreaded, and not very optimized. Main intent is to
82/// use it on small tensors out of the hot paths.
83pub fn tensor_prod_eq_ind_prepend<P: PackedField, Data: DerefMut<Target = [P]>>(
84	values: &mut FieldBuffer<P, Data>,
85	extra_query_coordinates: &[P::Scalar],
86) -> Result<(), Error> {
87	let new_log_len = values.log_len() + extra_query_coordinates.len();
88
89	if values.log_cap() < new_log_len {
90		return Err(Error::IncorrectArgumentLength {
91			arg: "values.log_cap()".to_string(),
92			expected: new_log_len,
93		});
94	}
95
96	for &r_i in extra_query_coordinates.iter().rev() {
97		values.zero_extend(values.log_len() + 1)?;
98		for i in (0..values.len() / 2).rev() {
99			let eval = get_packed_slice(values.as_ref(), i);
100			set_packed_slice(values.as_mut(), 2 * i, eval * (P::Scalar::ONE - r_i));
101			set_packed_slice(values.as_mut(), 2 * i + 1, eval * r_i);
102		}
103	}
104
105	Ok(())
106}
107
108/// Computes the partial evaluation of the equality indicator polynomial.
109///
110/// Given an $n$-coordinate point $r_0, ..., r_n$, this computes the partial evaluation of the
111/// equality indicator polynomial $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ and
112/// returns its values over the $n$-dimensional hypercube.
113///
114/// The returned values are equal to the tensor product
115///
116/// $$
117/// (1 - r_0, r_0) \otimes ... \otimes (1 - r_{n-1}, r_{n-1}).
118/// $$
119///
120/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
121///
122/// [DP23]: <https://eprint.iacr.org/2023/1784>
123pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> FieldBuffer<P> {
124	// The buffer needs to have the correct size: 2^max(point.len(), P::LOG_WIDTH) elements
125	// but since tensor_prod_eq_ind starts with log_n_values=0, we need the final size
126	let log_size = point.len();
127	let mut buffer = FieldBuffer::zeros_truncated(0, log_size).expect("log_size >= 0");
128	buffer
129		.set(0, P::Scalar::ONE)
130		.expect("buffer has length exactly 1");
131	tensor_prod_eq_ind(&mut buffer, point).expect("buffer is allocated with the correct length");
132	buffer
133}
134
135/// Truncate the equality indicator expansion to the low indexed variables.
136///
137/// This routine computes $\widetilde{eq}(X_0, ..., X_{n'-1}, r_0, ..., r_{n'-1})$ from
138/// $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ where $n' \le n$ by repeatedly summing
139/// field buffer "halves" inplace. The equality indicator expansion occupies a prefix of
140/// the field buffer; scalars after the truncated length are zeroed out.
141pub fn eq_ind_truncate_low_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
142	values: &mut FieldBuffer<P, Data>,
143	truncated_log_len: usize,
144) -> Result<(), Error> {
145	if truncated_log_len > values.log_len() {
146		return Err(Error::ArgumentRangeError {
147			arg: "truncated_log_len".to_string(),
148			range: 0..values.log_len() + 1,
149		});
150	}
151
152	for log_len in (truncated_log_len..values.log_len()).rev() {
153		values.split_half_mut(|lo, hi| {
154			(lo.as_mut(), hi.as_ref())
155				.into_par_iter()
156				.for_each(|(zero, one)| {
157					*zero += *one;
158				});
159		})?;
160
161		values.truncate(log_len);
162	}
163
164	Ok(())
165}
166
167/// Evaluates the 2-variate multilinear which indicates the equality condition over the hypercube.
168///
169/// This evaluates the bivariate polynomial
170///
171/// $$
172/// \widetilde{eq}(X, Y) = X Y + (1 - X) (1 - Y)
173/// $$
174///
175/// In the special case of binary fields, the evaluation can be simplified to
176///
177/// $$
178/// \widetilde{eq}(X, Y) = X + Y + 1
179/// $$
180#[inline(always)]
181pub fn eq_one_var<F: Field>(x: F, y: F) -> F {
182	if F::CHARACTERISTIC == 2 {
183		// Optimize away the multiplication for binary fields
184		x + y + F::ONE
185	} else {
186		x * y + (F::ONE - x) * (F::ONE - y)
187	}
188}
189
190/// Evaluates the equality indicator multilinear at a pair of coordinates.
191///
192/// This evaluates the 2n-variate multilinear polynomial
193///
194/// $$
195/// \widetilde{eq}(X_0, \ldots, X_{n-1}, Y_0, \ldots, Y_{n-1}) = \prod_{i=0}^{n-1} X_i Y_i + (1 -
196/// X_i) (1 - Y_i) $$
197///
198/// In the special case of binary fields, the evaluation can be simplified to
199///
200/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
201///
202/// [DP23]: <https://eprint.iacr.org/2023/1784>
203pub fn eq_ind<F: Field>(x: &[F], y: &[F]) -> F {
204	assert_eq!(x.len(), y.len(), "pre-condition: x and y must be the same length");
205	iter::zip(x, y).map(|(&x, &y)| eq_one_var(x, y)).product()
206}
207
#[cfg(test)]
mod tests {
	use rand::prelude::*;

	use super::*;
	use crate::test_utils::{B128, Packed128b, index_to_hypercube_point, random_scalars};

	type P = Packed128b;
	type F = B128;

	#[test]
	fn test_tensor_prod_eq_ind() {
		let v0 = F::from(1);
		let v1 = F::from(2);
		let query = vec![v0, v1];
		// log_len = 0, query.len() = 2, so total log_cap = 2
		let mut result = FieldBuffer::zeros_truncated(0, query.len()).unwrap();
		result.set(0, F::ONE).unwrap();
		tensor_prod_eq_ind(&mut result, &query).unwrap();
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
		assert_eq!(
			result_vec,
			vec![
				(F::ONE - v0) * (F::ONE - v1),
				v0 * (F::ONE - v1),
				(F::ONE - v0) * v1,
				v0 * v1
			]
		);
	}

	#[test]
	fn test_tensor_prod_eq_ind_inplace_expansion() {
		let mut rng = StdRng::seed_from_u64(0);

		let exps = 4;
		let max_n_vars = exps * (exps + 1) / 2;
		let mut coords = Vec::with_capacity(max_n_vars);
		let mut eq_expansion = FieldBuffer::zeros_truncated(0, max_n_vars).unwrap();
		eq_expansion.set(0, F::ONE).unwrap();

		// Expand by batches of increasing size; after each batch the buffer must equal
		// the eq indicator over all coordinates accumulated so far.
		for extra_count in 1..=exps {
			let extra = random_scalars(&mut rng, extra_count);

			tensor_prod_eq_ind::<P, _>(&mut eq_expansion, &extra).unwrap();
			coords.extend(&extra);

			assert_eq!(eq_expansion.log_len(), coords.len());
			for i in 0..eq_expansion.len() {
				let v = eq_expansion.get(i).unwrap();
				let hypercube_point = index_to_hypercube_point(coords.len(), i);
				assert_eq!(v, eq_ind(&hypercube_point, &coords));
			}
		}
	}

	#[test]
	fn test_eq_ind_partial_eval_empty() {
		let result = eq_ind_partial_eval::<P>(&[]);
		// An empty point yields the single scalar 1: logical length is exactly 1.
		assert_eq!(result.log_len(), 0);
		assert_eq!(result.len(), 1);
		assert_eq!(result.get(0).unwrap(), F::ONE);
	}

	#[test]
	fn test_eq_ind_partial_eval_single_var() {
		// Only one query coordinate
		let r0 = F::new(2);
		let result = eq_ind_partial_eval::<P>(&[r0]);
		assert_eq!(result.log_len(), 1);
		assert_eq!(result.len(), 2);
		assert_eq!(result.get(0).unwrap(), F::ONE - r0);
		assert_eq!(result.get(1).unwrap(), r0);
	}

	#[test]
	fn test_eq_ind_partial_eval_two_vars() {
		// Two query coordinates
		let r0 = F::new(2);
		let r1 = F::new(3);
		let result = eq_ind_partial_eval::<P>(&[r0, r1]);
		assert_eq!(result.log_len(), 2);
		assert_eq!(result.len(), 4);
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1),
			r0 * (F::ONE - r1),
			(F::ONE - r0) * r1,
			r0 * r1,
		];
		assert_eq!(result_vec, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_three_vars() {
		// Case with three query coordinates
		let r0 = F::new(2);
		let r1 = F::new(3);
		let r2 = F::new(5);
		let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
		assert_eq!(result.log_len(), 3);
		assert_eq!(result.len(), 8);
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();

		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
			r0 * (F::ONE - r1) * (F::ONE - r2),
			(F::ONE - r0) * r1 * (F::ONE - r2),
			r0 * r1 * (F::ONE - r2),
			(F::ONE - r0) * (F::ONE - r1) * r2,
			r0 * (F::ONE - r1) * r2,
			(F::ONE - r0) * r1 * r2,
			r0 * r1 * r2,
		];
		assert_eq!(result_vec, expected);
	}

	// Property-based test that eq_ind_partial_eval is consistent with eq_ind at a random index.
	#[test]
	fn test_eq_ind_partial_eval_consistent_on_hypercube() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 5;

		let point = random_scalars(&mut rng, n_vars);
		let result = eq_ind_partial_eval::<P>(&point);
		let index = rng.random_range(..1 << n_vars);

		// Query the value at that index
		let partial_eval_value = result.get(index).unwrap();

		let index_bits = index_to_hypercube_point(n_vars, index);
		let eq_ind_value = eq_ind(&point, &index_bits);

		assert_eq!(partial_eval_value, eq_ind_value);
	}

	#[test]
	fn test_eq_ind_truncate_low_inplace() {
		let mut rng = StdRng::seed_from_u64(0);

		let reds = 4;
		let n_vars = reds * (reds + 1) / 2;
		let point = random_scalars(&mut rng, n_vars);

		let mut eq_ind = eq_ind_partial_eval::<P>(&point);
		let mut log_n_values = n_vars;

		// Apply successively smaller truncations, checking against a fresh partial
		// evaluation of the prefix after each step.
		for reduction in (0..=reds).rev() {
			let truncated_log_n_values = log_n_values - reduction;
			eq_ind_truncate_low_inplace(&mut eq_ind, truncated_log_n_values).unwrap();

			let eq_ind_ref = eq_ind_partial_eval::<P>(&point[..truncated_log_n_values]);
			assert_eq!(eq_ind_ref.len(), eq_ind.len());
			for i in 0..eq_ind.len() {
				assert_eq!(eq_ind.get(i).unwrap(), eq_ind_ref.get(i).unwrap());
			}

			log_n_values = truncated_log_n_values;
		}

		assert_eq!(log_n_values, 0);
	}

	#[test]
	fn test_tensor_prod_eq_prepend_conforms_to_append() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 10;
		let base_vars = 4;

		let point = random_scalars::<F>(&mut rng, n_vars);

		let append = eq_ind_partial_eval(&point);

		// Appending the suffix then prepending the prefix must reproduce the full expansion.
		let mut prepend = FieldBuffer::<P>::zeros_truncated(0, n_vars).unwrap();
		let (prefix, suffix) = point.split_at(n_vars - base_vars);
		prepend.set(0, F::ONE).unwrap();
		tensor_prod_eq_ind(&mut prepend, suffix).unwrap();
		tensor_prod_eq_ind_prepend(&mut prepend, prefix).unwrap();

		assert_eq!(append, prepend);
	}
}