binius_math/multilinear/
eq.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, ops::DerefMut};
4
5use binius_field::{
6	Field, PackedField,
7	packed::{get_packed_slice, set_packed_slice},
8};
9use binius_utils::rayon::prelude::*;
10
11use crate::{Error, FieldBuffer};
12
13/// Tensor of values with the eq indicator evaluated at extra_query_coordinates.
14///
15/// Let $n$ be log_n_values, $p$, $k$ be the lengths of `packed_values` and
16/// `extra_query_coordinates`. Requires
17///     * $n \geq k$
18///     * p = max(1, 2^{n+k} / P::WIDTH)
19/// Let $v$ be a vector corresponding to the first $2^n$ scalar values of `values`.
20/// Let $r = (r_0, \ldots, r_{k-1})$ be the vector of `extra_query_coordinates`.
21///
22/// # Precondition:
23/// * `values` must be zero-extended to the new log length before calling this function. This
24///   condition is necessary to get the best performance.
25///
26/// # Formal Definition
27/// `values` is updated to contain the result of:
28/// $v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})$
29/// which is now a vector of length $2^{n+k}$. If 2^{n+k} < P::WIDTH, then
30/// the result is packed into a single element of `values` where only the first
31/// 2^{n+k} elements have meaning.
32///
33/// # Interpretation
34/// Let $f$ be an $n$ variate multilinear polynomial that has evaluations over
35/// the $n$ dimensional hypercube corresponding to $v$.
36/// Then `values` is updated to contain the evaluations of $g$ over the $n+k$-dimensional
37/// hypercube where
38/// * $g(x_0, \ldots, x_{n+k-1}) = f(x_0, \ldots, x_{n-1}) * eq(x_n, \ldots, x_{n+k-1}, r)$
39fn tensor_prod_eq_ind<P: PackedField, Data: DerefMut<Target = [P]>>(
40	values: &mut FieldBuffer<P, Data>,
41	extra_query_coordinates: &[P::Scalar],
42) -> Result<(), Error> {
43	let new_log_len = values.log_len() + extra_query_coordinates.len();
44
45	if values.log_cap() < new_log_len {
46		return Err(Error::IncorrectArgumentLength {
47			arg: "values.log_cap()".to_string(),
48			expected: new_log_len,
49		});
50	}
51
52	for &r_i in extra_query_coordinates {
53		let packed_r_i = P::broadcast(r_i);
54
55		values.resize(values.log_len() + 1)?;
56		let mut split = values
57			.split_half_mut_no_closure()
58			.expect("doubled by zero_extend()");
59		let (mut lo, mut hi) = split.halves();
60
61		(lo.as_mut(), hi.as_mut())
62			.into_par_iter()
63			.for_each(|(lo_i, hi_i)| {
64				let prod = (*lo_i) * packed_r_i;
65				*lo_i -= prod;
66				*hi_i = prod;
67			});
68	}
69
70	Ok(())
71}
72
73/// Left tensor of values with the eq indicator evaluated at extra_query_coordinates.
74///
75/// # Formal definition
76/// This differs from `tensor_prod_eq_ind` in tensor product being applied on the left
77/// and in reversed order:
78/// $(1 - r_{k-1}, r_{k-1}) \otimes \ldots \otimes (1 - r_0, r_0) \otimes v$
79///
80/// # Implementation
81/// This operation is inplace, singlethreaded, and not very optimized. Main intent is to
82/// use it on small tensors out of the hot paths.
83pub fn tensor_prod_eq_ind_prepend<P: PackedField, Data: DerefMut<Target = [P]>>(
84	values: &mut FieldBuffer<P, Data>,
85	extra_query_coordinates: &[P::Scalar],
86) -> Result<(), Error> {
87	let new_log_len = values.log_len() + extra_query_coordinates.len();
88
89	if values.log_cap() < new_log_len {
90		return Err(Error::IncorrectArgumentLength {
91			arg: "values.log_cap()".to_string(),
92			expected: new_log_len,
93		});
94	}
95
96	for &r_i in extra_query_coordinates.iter().rev() {
97		values.zero_extend(values.log_len() + 1)?;
98		for i in (0..values.len() / 2).rev() {
99			let eval = get_packed_slice(values.as_ref(), i);
100			set_packed_slice(values.as_mut(), 2 * i, eval * (P::Scalar::ONE - r_i));
101			set_packed_slice(values.as_mut(), 2 * i + 1, eval * r_i);
102		}
103	}
104
105	Ok(())
106}
107
108/// Computes the partial evaluation of the equality indicator polynomial.
109///
110/// Given an $n$-coordinate point $r_0, ..., r_n$, this computes the partial evaluation of the
111/// equality indicator polynomial $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ and
112/// returns its values over the $n$-dimensional hypercube.
113///
114/// The returned values are equal to the tensor product
115///
116/// $$
117/// (1 - r_0, r_0) \otimes ... \otimes (1 - r_{n-1}, r_{n-1}).
118/// $$
119///
120/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
121///
122/// [DP23]: <https://eprint.iacr.org/2023/1784>
123pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> FieldBuffer<P> {
124	// The buffer needs to have the correct size: 2^max(point.len(), P::LOG_WIDTH) elements
125	// but since tensor_prod_eq_ind starts with log_n_values=0, we need the final size
126	let log_size = point.len();
127	let mut buffer = FieldBuffer::zeros_truncated(0, log_size).expect("log_size >= 0");
128	buffer.set(0, P::Scalar::ONE);
129	tensor_prod_eq_ind(&mut buffer, point).expect("buffer is allocated with the correct length");
130	buffer
131}
132
133/// Truncate the equality indicator expansion to the low indexed variables.
134///
135/// This routine computes $\widetilde{eq}(X_0, ..., X_{n'-1}, r_0, ..., r_{n'-1})$ from
136/// $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ where $n' \le n$ by repeatedly summing
137/// field buffer "halves" inplace. The equality indicator expansion occupies a prefix of
138/// the field buffer; scalars after the truncated length are zeroed out.
139pub fn eq_ind_truncate_low_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
140	values: &mut FieldBuffer<P, Data>,
141	truncated_log_len: usize,
142) -> Result<(), Error> {
143	if truncated_log_len > values.log_len() {
144		return Err(Error::ArgumentRangeError {
145			arg: "truncated_log_len".to_string(),
146			range: 0..values.log_len() + 1,
147		});
148	}
149
150	for log_len in (truncated_log_len..values.log_len()).rev() {
151		values.split_half_mut(|lo, hi| {
152			(lo.as_mut(), hi.as_ref())
153				.into_par_iter()
154				.for_each(|(zero, one)| {
155					*zero += *one;
156				});
157		})?;
158
159		values.truncate(log_len);
160	}
161
162	Ok(())
163}
164
165/// Evaluates the 2-variate multilinear which indicates the equality condition over the hypercube.
166///
167/// This evaluates the bivariate polynomial
168///
169/// $$
170/// \widetilde{eq}(X, Y) = X Y + (1 - X) (1 - Y)
171/// $$
172///
173/// In the special case of binary fields, the evaluation can be simplified to
174///
175/// $$
176/// \widetilde{eq}(X, Y) = X + Y + 1
177/// $$
178#[inline(always)]
179pub fn eq_one_var<F: Field>(x: F, y: F) -> F {
180	if F::CHARACTERISTIC == 2 {
181		// Optimize away the multiplication for binary fields
182		x + y + F::ONE
183	} else {
184		x * y + (F::ONE - x) * (F::ONE - y)
185	}
186}
187
188/// Evaluates the equality indicator multilinear at a pair of coordinates.
189///
190/// This evaluates the 2n-variate multilinear polynomial
191///
192/// $$
193/// \widetilde{eq}(X_0, \ldots, X_{n-1}, Y_0, \ldots, Y_{n-1}) = \prod_{i=0}^{n-1} X_i Y_i + (1 -
194/// X_i) (1 - Y_i) $$
195///
196/// In the special case of binary fields, the evaluation can be simplified to
197///
198/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
199///
200/// [DP23]: <https://eprint.iacr.org/2023/1784>
201pub fn eq_ind<F: Field>(x: &[F], y: &[F]) -> F {
202	assert_eq!(x.len(), y.len(), "pre-condition: x and y must be the same length");
203	iter::zip(x, y).map(|(&x, &y)| eq_one_var(x, y)).product()
204}
205
#[cfg(test)]
mod tests {
	use rand::prelude::*;

	use super::*;
	use crate::test_utils::{B128, Packed128b, index_to_hypercube_point, random_scalars};

	type P = Packed128b;
	type F = B128;

	// Expanding from the singleton [1] by two coordinates must yield the full
	// 2x2 tensor product (1 - v0, v0) ⊗ (1 - v1, v1).
	#[test]
	fn test_tensor_prod_eq_ind() {
		let v0 = F::from(1);
		let v1 = F::from(2);
		let query = vec![v0, v1];
		// log_len = 0, query.len() = 2, so total log_cap = 2
		let mut result = FieldBuffer::zeros_truncated(0, query.len()).unwrap();
		result.set_checked(0, F::ONE).unwrap();
		tensor_prod_eq_ind(&mut result, &query).unwrap();
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
		assert_eq!(
			result_vec,
			vec![
				(F::ONE - v0) * (F::ONE - v1),
				v0 * (F::ONE - v1),
				(F::ONE - v0) * v1,
				v0 * v1
			]
		);
	}

	// Grows the same buffer in several steps (1, 2, 3, 4 coordinates at a time)
	// and cross-checks every scalar against eq_ind on the hypercube point.
	#[test]
	fn test_tensor_prod_eq_ind_inplace_expansion() {
		let mut rng = StdRng::seed_from_u64(0);

		let exps = 4;
		// 1 + 2 + ... + exps coordinates in total.
		let max_n_vars = exps * (exps + 1) / 2;
		let mut coords = Vec::with_capacity(max_n_vars);
		let mut eq_expansion = FieldBuffer::zeros_truncated(0, max_n_vars).unwrap();
		eq_expansion.set_checked(0, F::ONE).unwrap();

		for extra_count in 1..=exps {
			let extra = random_scalars(&mut rng, extra_count);

			tensor_prod_eq_ind::<P, _>(&mut eq_expansion, &extra).unwrap();
			coords.extend(&extra);

			assert_eq!(eq_expansion.log_len(), coords.len());
			for i in 0..eq_expansion.len() {
				let v = eq_expansion.get_checked(i).unwrap();
				let hypercube_point = index_to_hypercube_point(coords.len(), i);
				assert_eq!(v, eq_ind(&hypercube_point, &coords));
			}
		}
	}

	#[test]
	fn test_eq_ind_partial_eval_empty() {
		let result = eq_ind_partial_eval::<P>(&[]);
		// The empty-point expansion is the single scalar 1: log_len = 0, one
		// meaningful element (regardless of the packed width of P).
		assert_eq!(result.log_len(), 0);
		assert_eq!(result.len(), 1);
		let result_mut = result;
		assert_eq!(result_mut.get_checked(0).unwrap(), F::ONE);
	}

	#[test]
	fn test_eq_ind_partial_eval_single_var() {
		// Only one query coordinate: expansion is (1 - r0, r0).
		let r0 = F::new(2);
		let result = eq_ind_partial_eval::<P>(&[r0]);
		assert_eq!(result.log_len(), 1);
		assert_eq!(result.len(), 2);
		let result_mut = result;
		assert_eq!(result_mut.get_checked(0).unwrap(), F::ONE - r0);
		assert_eq!(result_mut.get_checked(1).unwrap(), r0);
	}

	#[test]
	fn test_eq_ind_partial_eval_two_vars() {
		// Two query coordinates: expansion is (1 - r0, r0) ⊗ (1 - r1, r1).
		let r0 = F::new(2);
		let r1 = F::new(3);
		let result = eq_ind_partial_eval::<P>(&[r0, r1]);
		assert_eq!(result.log_len(), 2);
		assert_eq!(result.len(), 4);
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1),
			r0 * (F::ONE - r1),
			(F::ONE - r0) * r1,
			r0 * r1,
		];
		assert_eq!(result_vec, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_three_vars() {
		// Case with three query coordinates; expected values enumerate the full
		// 8-element tensor product in little-endian variable order.
		let r0 = F::new(2);
		let r1 = F::new(3);
		let r2 = F::new(5);
		let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
		assert_eq!(result.log_len(), 3);
		assert_eq!(result.len(), 8);
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();

		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
			r0 * (F::ONE - r1) * (F::ONE - r2),
			(F::ONE - r0) * r1 * (F::ONE - r2),
			r0 * r1 * (F::ONE - r2),
			(F::ONE - r0) * (F::ONE - r1) * r2,
			r0 * (F::ONE - r1) * r2,
			(F::ONE - r0) * r1 * r2,
			r0 * r1 * r2,
		];
		assert_eq!(result_vec, expected);
	}

	// Property-based test that eq_ind_partial_eval is consistent with eq_ind at a random index.
	#[test]
	fn test_eq_ind_partial_eval_consistent_on_hypercube() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 5;

		let point = random_scalars(&mut rng, n_vars);
		let result = eq_ind_partial_eval::<P>(&point);
		let index = rng.random_range(..1 << n_vars);

		// Query the value at that index
		let result_mut = result;
		let partial_eval_value = result_mut.get_checked(index).unwrap();

		// The partial evaluation at `index` must equal the full eq indicator
		// evaluated at (point, bits-of-index).
		let index_bits = index_to_hypercube_point(n_vars, index);
		let eq_ind_value = eq_ind(&point, &index_bits);

		assert_eq!(partial_eval_value, eq_ind_value);
	}

	// Truncating the expansion by k variables must match a fresh expansion of
	// the first (n - k) coordinates; checks progressively larger reductions.
	#[test]
	fn test_eq_ind_truncate_low_inplace() {
		let mut rng = StdRng::seed_from_u64(0);

		let reds = 4;
		let n_vars = reds * (reds + 1) / 2;
		let point = random_scalars(&mut rng, n_vars);

		let mut eq_ind = eq_ind_partial_eval::<P>(&point);
		let mut log_n_values = n_vars;

		for reduction in (0..=reds).rev() {
			let truncated_log_n_values = log_n_values - reduction;
			eq_ind_truncate_low_inplace(&mut eq_ind, truncated_log_n_values).unwrap();

			// Reference: recompute the expansion from scratch on the prefix.
			let eq_ind_ref = eq_ind_partial_eval::<P>(&point[..truncated_log_n_values]);
			assert_eq!(eq_ind_ref.len(), eq_ind.len());
			for i in 0..eq_ind.len() {
				assert_eq!(eq_ind.get_checked(i).unwrap(), eq_ind_ref.get_checked(i).unwrap());
			}

			log_n_values = truncated_log_n_values;
		}

		// After all reductions the buffer is fully collapsed.
		assert_eq!(log_n_values, 0);
	}

	// Appending the suffix then prepending the prefix must agree with a single
	// append-order expansion of the whole point.
	#[test]
	fn test_tensor_prod_eq_prepend_conforms_to_append() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 10;
		let base_vars = 4;

		let point = random_scalars::<F>(&mut rng, n_vars);

		let append = eq_ind_partial_eval(&point);

		let mut prepend = FieldBuffer::<P>::zeros_truncated(0, n_vars).unwrap();
		let (prefix, suffix) = point.split_at(n_vars - base_vars);
		prepend.set_checked(0, F::ONE).unwrap();
		tensor_prod_eq_ind(&mut prepend, suffix).unwrap();
		tensor_prod_eq_ind_prepend(&mut prepend, prefix).unwrap();

		assert_eq!(append, prepend);
	}
}
393}