Skip to main content

binius_math/multilinear/
eq.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{iter, ops::DerefMut};
4
5use binius_field::{
6	Field, PackedField,
7	field::FieldOps,
8	packed::{get_packed_slice, set_packed_slice},
9};
10use binius_utils::rayon::prelude::*;
11
12use crate::FieldBuffer;
13
14/// Tensor of values with the eq indicator evaluated at extra_query_coordinates.
15///
16/// Let $n$ be log_n_values, $p$, $k$ be the lengths of `packed_values` and
17/// `extra_query_coordinates`. Requires
18///     * $n \geq k$
19///     * p = max(1, 2^{n+k} / P::WIDTH)
20/// Let $v$ be a vector corresponding to the first $2^n$ scalar values of `values`.
21/// Let $r = (r_0, \ldots, r_{k-1})$ be the vector of `extra_query_coordinates`.
22///
23/// ## Preconditions
24///
25/// * `values` must have enough capacity: `values.log_cap() >= values.log_len() +
26///   extra_query_coordinates.len()`
27/// * `values` must be zero-extended to the new log length before calling this function. This
28///   condition is necessary to get the best performance.
29///
30/// # Formal Definition
31/// `values` is updated to contain the result of:
32/// $v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})$
33/// which is now a vector of length $2^{n+k}$. If 2^{n+k} < P::WIDTH, then
34/// the result is packed into a single element of `values` where only the first
35/// 2^{n+k} elements have meaning.
36///
37/// # Interpretation
38/// Let $f$ be an $n$ variate multilinear polynomial that has evaluations over
39/// the $n$ dimensional hypercube corresponding to $v$.
40/// Then `values` is updated to contain the evaluations of $g$ over the $n+k$-dimensional
41/// hypercube where
42/// * $g(x_0, \ldots, x_{n+k-1}) = f(x_0, \ldots, x_{n-1}) * eq(x_n, \ldots, x_{n+k-1}, r)$
43fn tensor_prod_eq_ind<P: PackedField, Data: DerefMut<Target = [P]>>(
44	values: &mut FieldBuffer<P, Data>,
45	extra_query_coordinates: &[P::Scalar],
46) {
47	let new_log_len = values.log_len() + extra_query_coordinates.len();
48
49	assert!(
50		values.log_cap() >= new_log_len,
51		"precondition: values capacity must be sufficient for expansion"
52	);
53
54	for &r_i in extra_query_coordinates {
55		let packed_r_i = P::broadcast(r_i);
56
57		values.resize(values.log_len() + 1);
58		let mut split = values.split_half_mut();
59		let (mut lo, mut hi) = split.halves();
60
61		(lo.as_mut(), hi.as_mut())
62			.into_par_iter()
63			.for_each(|(lo_i, hi_i)| {
64				let prod = (*lo_i) * packed_r_i;
65				*lo_i -= prod;
66				*hi_i = prod;
67			});
68	}
69}
70
71/// Left tensor of values with the eq indicator evaluated at extra_query_coordinates.
72///
73/// # Formal definition
74/// This differs from `tensor_prod_eq_ind` in tensor product being applied on the left
75/// and in reversed order:
76/// $(1 - r_{k-1}, r_{k-1}) \otimes \ldots \otimes (1 - r_0, r_0) \otimes v$
77///
78/// # Implementation
79/// This operation is inplace, singlethreaded, and not very optimized. Main intent is to
80/// use it on small tensors out of the hot paths.
81///
82/// ## Preconditions
83///
84/// * `values` must have enough capacity: `values.log_cap() >= values.log_len() +
85///   extra_query_coordinates.len()`
86pub fn tensor_prod_eq_ind_prepend<P: PackedField, Data: DerefMut<Target = [P]>>(
87	values: &mut FieldBuffer<P, Data>,
88	extra_query_coordinates: &[P::Scalar],
89) {
90	let new_log_len = values.log_len() + extra_query_coordinates.len();
91
92	assert!(
93		values.log_cap() >= new_log_len,
94		"precondition: values capacity must be sufficient for expansion"
95	);
96
97	for &r_i in extra_query_coordinates.iter().rev() {
98		values.zero_extend(values.log_len() + 1);
99		for i in (0..values.len() / 2).rev() {
100			let eval = get_packed_slice(values.as_ref(), i);
101			set_packed_slice(values.as_mut(), 2 * i, eval * (P::Scalar::ONE - r_i));
102			set_packed_slice(values.as_mut(), 2 * i + 1, eval * r_i);
103		}
104	}
105}
106
107/// Computes the partial evaluation of the equality indicator polynomial.
108///
109/// Given an $n$-coordinate point $r_0, ..., r_n$, this computes the partial evaluation of the
110/// equality indicator polynomial $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ and
111/// returns its values over the $n$-dimensional hypercube.
112///
113/// The returned values are equal to the tensor product
114///
115/// $$
116/// (1 - r_0, r_0) \otimes ... \otimes (1 - r_{n-1}, r_{n-1}).
117/// $$
118///
119/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
120///
121/// [DP23]: <https://eprint.iacr.org/2023/1784>
122pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> FieldBuffer<P> {
123	// The buffer needs to have the correct size: 2^max(point.len(), P::LOG_WIDTH) elements
124	// but since tensor_prod_eq_ind starts with log_n_values=0, we need the final size
125	let log_size = point.len();
126	let mut buffer = FieldBuffer::zeros_truncated(0, log_size);
127	buffer.set(0, P::Scalar::ONE);
128	tensor_prod_eq_ind(&mut buffer, point);
129	buffer
130}
131
132/// Truncate the equality indicator expansion to the low indexed variables.
133///
134/// This routine computes $\widetilde{eq}(X_0, ..., X_{n'-1}, r_0, ..., r_{n'-1})$ from
135/// $\widetilde{eq}(X_0, ..., X_{n-1}, r_0, ..., r_{n-1})$ where $n' \le n$ by repeatedly summing
136/// field buffer "halves" inplace. The equality indicator expansion occupies a prefix of
137/// the field buffer; scalars after the truncated length are zeroed out.
138///
139/// ## Preconditions
140///
141/// * `truncated_log_len` must be at most `values.log_len()`
142pub fn eq_ind_truncate_low_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
143	values: &mut FieldBuffer<P, Data>,
144	truncated_log_len: usize,
145) {
146	assert!(
147		truncated_log_len <= values.log_len(),
148		"precondition: truncated_log_len must be at most values.log_len()"
149	);
150
151	for log_len in (truncated_log_len..values.log_len()).rev() {
152		{
153			let mut split = values.split_half_mut();
154			let (mut lo, hi) = split.halves();
155			(lo.as_mut(), hi.as_ref())
156				.into_par_iter()
157				.for_each(|(zero, one)| {
158					*zero += *one;
159				});
160		}
161
162		values.truncate(log_len);
163	}
164}
165
166/// Evaluates the 2-variate multilinear which indicates the equality condition over the hypercube.
167///
168/// This evaluates the bivariate polynomial
169///
170/// $$
171/// \widetilde{eq}(X, Y) = X Y + (1 - X) (1 - Y)
172/// $$
173///
174/// In the special case of binary fields, the evaluation can be simplified to
175///
176/// $$
177/// \widetilde{eq}(X, Y) = X + Y + 1
178/// $$
179#[inline(always)]
180pub fn eq_one_var<F: FieldOps>(x: F, y: F) -> F {
181	let one = F::one();
182	x.clone() * y.clone() + (one.clone() - x) * (one.clone() - y)
183}
184
185/// Evaluates the equality indicator multilinear at a pair of coordinates.
186///
187/// This evaluates the 2n-variate multilinear polynomial
188///
189/// $$
190/// \widetilde{eq}(X_0, \ldots, X_{n-1}, Y_0, \ldots, Y_{n-1}) = \prod_{i=0}^{n-1} X_i Y_i + (1 -
191/// X_i) (1 - Y_i) $$
192///
193/// In the special case of binary fields, the evaluation can be simplified to
194///
195/// See [DP23], Section 2.1 for more information about the equality indicator polynomial.
196///
197/// [DP23]: <https://eprint.iacr.org/2023/1784>
198pub fn eq_ind<F: FieldOps>(x: &[F], y: &[F]) -> F {
199	assert_eq!(x.len(), y.len(), "pre-condition: x and y must be the same length");
200	iter::zip(x, y)
201		.map(|(x, y)| eq_one_var(x.clone(), y.clone()))
202		.product()
203}
204
205/// Computes the partial evaluation of the equality indicator polynomial, returning scalars.
206///
207/// This is a scalar-only variant of [`eq_ind_partial_eval`] that returns a `Vec<F>` instead of
208/// a [`FieldBuffer`]. It computes the tensor product
209///
210/// $$
211/// (1 - r_0, r_0) \otimes ... \otimes (1 - r_{n-1}, r_{n-1}).
212/// $$
213pub fn eq_ind_partial_eval_scalars<F: FieldOps>(point: &[F]) -> Vec<F> {
214	let mut result = Vec::with_capacity(1 << point.len());
215	result.push(F::one());
216
217	for r_i in point {
218		// Double the buffer size. For each existing value in 0..size,
219		// the lo half gets val * (1 - r_i) and the hi half gets val * r_i.
220		// Process in reverse so that writes to hi don't overwrite values we need.
221		let len = result.len();
222		for j in 0..len {
223			let prod = result[j].clone() * r_i.clone();
224			result[j] -= prod.clone();
225			result.push(prod);
226		}
227	}
228	result
229}
230
#[cfg(test)]
mod tests {
	use rand::prelude::*;

	use super::*;
	use crate::test_utils::{B128, Packed128b, index_to_hypercube_point, random_scalars};

	type P = Packed128b;
	type F = B128;

	// Checks the two-coordinate expansion against the tensor product written out by hand.
	#[test]
	fn test_tensor_prod_eq_ind() {
		let v0 = F::from(1);
		let v1 = F::from(2);
		let query = vec![v0, v1];
		// log_len = 0, query.len() = 2, so total log_cap = 2
		let mut result = FieldBuffer::zeros_truncated(0, query.len());
		result.set(0, F::ONE);
		tensor_prod_eq_ind(&mut result, &query);
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
		assert_eq!(
			result_vec,
			vec![
				(F::ONE - v0) * (F::ONE - v1),
				v0 * (F::ONE - v1),
				(F::ONE - v0) * v1,
				v0 * v1
			]
		);
	}

	// Grows one expansion incrementally (1, then 2, then 3, ... extra coordinates at a
	// time) and cross-checks every scalar against eq_ind at the matching hypercube point.
	#[test]
	fn test_tensor_prod_eq_ind_inplace_expansion() {
		let mut rng = StdRng::seed_from_u64(0);

		let exps = 4;
		// Total variables after all expansion rounds: 1 + 2 + ... + exps.
		let max_n_vars = exps * (exps + 1) / 2;
		let mut coords = Vec::with_capacity(max_n_vars);
		let mut eq_expansion = FieldBuffer::zeros_truncated(0, max_n_vars);
		eq_expansion.set(0, F::ONE);

		for extra_count in 1..=exps {
			let extra = random_scalars(&mut rng, extra_count);

			tensor_prod_eq_ind::<P, _>(&mut eq_expansion, &extra);
			coords.extend(&extra);

			assert_eq!(eq_expansion.log_len(), coords.len());
			for i in 0..eq_expansion.len() {
				let v = eq_expansion.get(i);
				let hypercube_point = index_to_hypercube_point(coords.len(), i);
				assert_eq!(v, eq_ind(&hypercube_point, &coords));
			}
		}
	}

	// The empty point yields the length-1 expansion [1].
	#[test]
	fn test_eq_ind_partial_eval_empty() {
		let result = eq_ind_partial_eval::<P>(&[]);
		// log_len is 0, so exactly one scalar (the constant 1) is meaningful, even though
		// the underlying packed storage holds at least P::WIDTH scalar slots.
		assert_eq!(result.log_len(), 0);
		assert_eq!(result.len(), 1);
		let result_mut = result;
		assert_eq!(result_mut.get(0), F::ONE);
	}

	#[test]
	fn test_eq_ind_partial_eval_single_var() {
		// Only one query coordinate: expansion is (1 - r0, r0).
		let r0 = F::new(2);
		let result = eq_ind_partial_eval::<P>(&[r0]);
		assert_eq!(result.log_len(), 1);
		assert_eq!(result.len(), 2);
		let result_mut = result;
		assert_eq!(result_mut.get(0), F::ONE - r0);
		assert_eq!(result_mut.get(1), r0);
	}

	#[test]
	fn test_eq_ind_partial_eval_two_vars() {
		// Two query coordinates: expansion is (1 - r0, r0) ⊗ (1 - r1, r1).
		let r0 = F::new(2);
		let r1 = F::new(3);
		let result = eq_ind_partial_eval::<P>(&[r0, r1]);
		assert_eq!(result.log_len(), 2);
		assert_eq!(result.len(), 4);
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1),
			r0 * (F::ONE - r1),
			(F::ONE - r0) * r1,
			r0 * r1,
		];
		assert_eq!(result_vec, expected);
	}

	#[test]
	fn test_eq_ind_partial_eval_three_vars() {
		// Case with three query coordinates; expected values spelled out in full.
		let r0 = F::new(2);
		let r1 = F::new(3);
		let r2 = F::new(5);
		let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
		assert_eq!(result.log_len(), 3);
		assert_eq!(result.len(), 8);
		let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();

		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
			r0 * (F::ONE - r1) * (F::ONE - r2),
			(F::ONE - r0) * r1 * (F::ONE - r2),
			r0 * r1 * (F::ONE - r2),
			(F::ONE - r0) * (F::ONE - r1) * r2,
			r0 * (F::ONE - r1) * r2,
			(F::ONE - r0) * r1 * r2,
			r0 * r1 * r2,
		];
		assert_eq!(result_vec, expected);
	}

	// Property-based test that eq_ind_partial_eval is consistent with eq_ind at a random index.
	#[test]
	fn test_eq_ind_partial_eval_consistent_on_hypercube() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 5;

		let point = random_scalars(&mut rng, n_vars);
		let result = eq_ind_partial_eval::<P>(&point);
		let index = rng.random_range(..1 << n_vars);

		// Query the value at that index
		let result_mut = result;
		let partial_eval_value = result_mut.get(index);

		let index_bits = index_to_hypercube_point(n_vars, index);
		let eq_ind_value = eq_ind(&point, &index_bits);

		assert_eq!(partial_eval_value, eq_ind_value);
	}

	// Truncates a full expansion by progressively larger amounts (reds, reds-1, ..., 0
	// variables per round) and compares each result against a freshly computed expansion
	// of the coordinate prefix.
	#[test]
	fn test_eq_ind_truncate_low_inplace() {
		let mut rng = StdRng::seed_from_u64(0);

		let reds = 4;
		let n_vars = reds * (reds + 1) / 2;
		let point = random_scalars(&mut rng, n_vars);

		let mut eq_ind = eq_ind_partial_eval::<P>(&point);
		let mut log_n_values = n_vars;

		for reduction in (0..=reds).rev() {
			let truncated_log_n_values = log_n_values - reduction;
			eq_ind_truncate_low_inplace(&mut eq_ind, truncated_log_n_values);

			let eq_ind_ref = eq_ind_partial_eval::<P>(&point[..truncated_log_n_values]);
			assert_eq!(eq_ind_ref.len(), eq_ind.len());
			for i in 0..eq_ind.len() {
				assert_eq!(eq_ind.get(i), eq_ind_ref.get(i));
			}

			log_n_values = truncated_log_n_values;
		}

		// The rounds above sum to n_vars, so the final expansion is fully truncated.
		assert_eq!(log_n_values, 0);
	}

	// The scalar-only variant must agree element-for-element with the packed variant.
	#[test]
	fn test_eq_ind_partial_eval_scalars_consistency() {
		let mut rng = StdRng::seed_from_u64(0);

		for log_n in [0, 1, 2, 5, 8] {
			let point = random_scalars::<F>(&mut rng, log_n);

			let packed_result = eq_ind_partial_eval::<P>(&point);
			let scalar_result = eq_ind_partial_eval_scalars(&point);

			let packed_scalars: Vec<F> = packed_result.iter_scalars().collect();
			assert_eq!(packed_scalars, scalar_result, "mismatch at log_n={log_n}");
		}
	}

	// Appending the suffix then prepending the prefix must equal appending the whole point.
	#[test]
	fn test_tensor_prod_eq_prepend_conforms_to_append() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 10;
		let base_vars = 4;

		let point = random_scalars::<F>(&mut rng, n_vars);

		let append = eq_ind_partial_eval(&point);

		let mut prepend = FieldBuffer::<P>::zeros_truncated(0, n_vars);
		let (prefix, suffix) = point.split_at(n_vars - base_vars);
		prepend.set(0, F::ONE);
		tensor_prod_eq_ind(&mut prepend, suffix);
		tensor_prod_eq_ind_prepend(&mut prepend, prefix);

		assert_eq!(append, prepend);
	}
}
433}