1use std::{iter, ops::DerefMut};
4
5use binius_field::{
6 Field, PackedField,
7 field::FieldOps,
8 packed::{get_packed_slice, set_packed_slice},
9};
10use binius_utils::rayon::prelude::*;
11
12use crate::FieldBuffer;
13
14fn tensor_prod_eq_ind<P: PackedField, Data: DerefMut<Target = [P]>>(
44 values: &mut FieldBuffer<P, Data>,
45 extra_query_coordinates: &[P::Scalar],
46) {
47 let new_log_len = values.log_len() + extra_query_coordinates.len();
48
49 assert!(
50 values.log_cap() >= new_log_len,
51 "precondition: values capacity must be sufficient for expansion"
52 );
53
54 for &r_i in extra_query_coordinates {
55 let packed_r_i = P::broadcast(r_i);
56
57 values.resize(values.log_len() + 1);
58 let mut split = values.split_half_mut();
59 let (mut lo, mut hi) = split.halves();
60
61 (lo.as_mut(), hi.as_mut())
62 .into_par_iter()
63 .for_each(|(lo_i, hi_i)| {
64 let prod = (*lo_i) * packed_r_i;
65 *lo_i -= prod;
66 *hi_i = prod;
67 });
68 }
69}
70
71pub fn tensor_prod_eq_ind_prepend<P: PackedField, Data: DerefMut<Target = [P]>>(
87 values: &mut FieldBuffer<P, Data>,
88 extra_query_coordinates: &[P::Scalar],
89) {
90 let new_log_len = values.log_len() + extra_query_coordinates.len();
91
92 assert!(
93 values.log_cap() >= new_log_len,
94 "precondition: values capacity must be sufficient for expansion"
95 );
96
97 for &r_i in extra_query_coordinates.iter().rev() {
98 values.zero_extend(values.log_len() + 1);
99 for i in (0..values.len() / 2).rev() {
100 let eval = get_packed_slice(values.as_ref(), i);
101 set_packed_slice(values.as_mut(), 2 * i, eval * (P::Scalar::ONE - r_i));
102 set_packed_slice(values.as_mut(), 2 * i + 1, eval * r_i);
103 }
104 }
105}
106
107pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> FieldBuffer<P> {
123 let log_size = point.len();
126 let mut buffer = FieldBuffer::zeros_truncated(0, log_size);
127 buffer.set(0, P::Scalar::ONE);
128 tensor_prod_eq_ind(&mut buffer, point);
129 buffer
130}
131
132pub fn eq_ind_truncate_low_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
143 values: &mut FieldBuffer<P, Data>,
144 truncated_log_len: usize,
145) {
146 assert!(
147 truncated_log_len <= values.log_len(),
148 "precondition: truncated_log_len must be at most values.log_len()"
149 );
150
151 for log_len in (truncated_log_len..values.log_len()).rev() {
152 {
153 let mut split = values.split_half_mut();
154 let (mut lo, hi) = split.halves();
155 (lo.as_mut(), hi.as_ref())
156 .into_par_iter()
157 .for_each(|(zero, one)| {
158 *zero += *one;
159 });
160 }
161
162 values.truncate(log_len);
163 }
164}
165
166#[inline(always)]
180pub fn eq_one_var<F: FieldOps>(x: F, y: F) -> F {
181 let one = F::one();
182 x.clone() * y.clone() + (one.clone() - x) * (one.clone() - y)
183}
184
185pub fn eq_ind<F: FieldOps>(x: &[F], y: &[F]) -> F {
199 assert_eq!(x.len(), y.len(), "pre-condition: x and y must be the same length");
200 iter::zip(x, y)
201 .map(|(x, y)| eq_one_var(x.clone(), y.clone()))
202 .product()
203}
204
205pub fn eq_ind_partial_eval_scalars<F: FieldOps>(point: &[F]) -> Vec<F> {
214 let mut result = Vec::with_capacity(1 << point.len());
215 result.push(F::one());
216
217 for r_i in point {
218 let len = result.len();
222 for j in 0..len {
223 let prod = result[j].clone() * r_i.clone();
224 result[j] -= prod.clone();
225 result.push(prod);
226 }
227 }
228 result
229}
230
#[cfg(test)]
mod tests {
    use rand::prelude::*;

    use super::*;
    use crate::test_utils::{B128, Packed128b, index_to_hypercube_point, random_scalars};

    type P = Packed128b;
    type F = B128;

    // Expanding two coordinates yields the four products of (1-v_i) / v_i.
    #[test]
    fn test_tensor_prod_eq_ind() {
        let v0 = F::from(1);
        let v1 = F::from(2);
        let query = vec![v0, v1];
        let mut result = FieldBuffer::zeros_truncated(0, query.len());
        result.set(0, F::ONE);
        tensor_prod_eq_ind(&mut result, &query);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        assert_eq!(
            result_vec,
            vec![
                (F::ONE - v0) * (F::ONE - v1),
                v0 * (F::ONE - v1),
                (F::ONE - v0) * v1,
                v0 * v1
            ]
        );
    }

    // Repeated incremental expansion matches eq_ind evaluated on every
    // hypercube vertex.
    #[test]
    fn test_tensor_prod_eq_ind_inplace_expansion() {
        let mut rng = StdRng::seed_from_u64(0);

        let exps = 4;
        let max_n_vars = exps * (exps + 1) / 2;
        let mut coords = Vec::with_capacity(max_n_vars);
        let mut eq_expansion = FieldBuffer::zeros_truncated(0, max_n_vars);
        eq_expansion.set(0, F::ONE);

        for extra_count in 1..=exps {
            let extra = random_scalars(&mut rng, extra_count);

            tensor_prod_eq_ind::<P, _>(&mut eq_expansion, &extra);
            coords.extend(&extra);

            assert_eq!(eq_expansion.log_len(), coords.len());
            for i in 0..eq_expansion.len() {
                let v = eq_expansion.get(i);
                let hypercube_point = index_to_hypercube_point(coords.len(), i);
                assert_eq!(v, eq_ind(&hypercube_point, &coords));
            }
        }
    }

    #[test]
    fn test_eq_ind_partial_eval_empty() {
        let result = eq_ind_partial_eval::<P>(&[]);
        assert_eq!(result.log_len(), 0);
        assert_eq!(result.len(), 1);
        assert_eq!(result.get(0), F::ONE);
    }

    #[test]
    fn test_eq_ind_partial_eval_single_var() {
        let r0 = F::new(2);
        let result = eq_ind_partial_eval::<P>(&[r0]);
        assert_eq!(result.log_len(), 1);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get(0), F::ONE - r0);
        assert_eq!(result.get(1), r0);
    }

    #[test]
    fn test_eq_ind_partial_eval_two_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let result = eq_ind_partial_eval::<P>(&[r0, r1]);
        assert_eq!(result.log_len(), 2);
        assert_eq!(result.len(), 4);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1),
            r0 * (F::ONE - r1),
            (F::ONE - r0) * r1,
            r0 * r1,
        ];
        assert_eq!(result_vec, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_three_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
        assert_eq!(result.log_len(), 3);
        assert_eq!(result.len(), 8);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();

        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
            r0 * (F::ONE - r1) * (F::ONE - r2),
            (F::ONE - r0) * r1 * (F::ONE - r2),
            r0 * r1 * (F::ONE - r2),
            (F::ONE - r0) * (F::ONE - r1) * r2,
            r0 * (F::ONE - r1) * r2,
            (F::ONE - r0) * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(result_vec, expected);
    }

    // Spot-checks a random entry of the packed expansion against the scalar
    // eq_ind evaluation at that hypercube vertex.
    #[test]
    fn test_eq_ind_partial_eval_consistent_on_hypercube() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 5;

        let point = random_scalars(&mut rng, n_vars);
        let result = eq_ind_partial_eval::<P>(&point);
        let index = rng.random_range(..1 << n_vars);

        let partial_eval_value = result.get(index);

        let index_bits = index_to_hypercube_point(n_vars, index);
        let eq_ind_value = eq_ind(&point, &index_bits);

        assert_eq!(partial_eval_value, eq_ind_value);
    }

    // Successive truncations must agree with a fresh expansion of the
    // corresponding point prefix.
    #[test]
    fn test_eq_ind_truncate_low_inplace() {
        let mut rng = StdRng::seed_from_u64(0);

        let reds = 4;
        let n_vars = reds * (reds + 1) / 2;
        let point = random_scalars(&mut rng, n_vars);

        let mut eq_ind = eq_ind_partial_eval::<P>(&point);
        let mut log_n_values = n_vars;

        for reduction in (0..=reds).rev() {
            let truncated_log_n_values = log_n_values - reduction;
            eq_ind_truncate_low_inplace(&mut eq_ind, truncated_log_n_values);

            let eq_ind_ref = eq_ind_partial_eval::<P>(&point[..truncated_log_n_values]);
            assert_eq!(eq_ind_ref.len(), eq_ind.len());
            for i in 0..eq_ind.len() {
                assert_eq!(eq_ind.get(i), eq_ind_ref.get(i));
            }

            log_n_values = truncated_log_n_values;
        }

        assert_eq!(log_n_values, 0);
    }

    // The scalar and packed expansions must produce identical values.
    #[test]
    fn test_eq_ind_partial_eval_scalars_consistency() {
        let mut rng = StdRng::seed_from_u64(0);

        for log_n in [0, 1, 2, 5, 8] {
            let point = random_scalars::<F>(&mut rng, log_n);

            let packed_result = eq_ind_partial_eval::<P>(&point);
            let scalar_result = eq_ind_partial_eval_scalars(&point);

            let packed_scalars: Vec<F> = packed_result.iter_scalars().collect();
            assert_eq!(packed_scalars, scalar_result, "mismatch at log_n={log_n}");
        }
    }

    // Prepending the point prefix after appending the suffix must equal
    // appending the whole point.
    #[test]
    fn test_tensor_prod_eq_prepend_conforms_to_append() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 10;
        let base_vars = 4;

        let point = random_scalars::<F>(&mut rng, n_vars);

        let append = eq_ind_partial_eval(&point);

        let mut prepend = FieldBuffer::<P>::zeros_truncated(0, n_vars);
        let (prefix, suffix) = point.split_at(n_vars - base_vars);
        prepend.set(0, F::ONE);
        tensor_prod_eq_ind(&mut prepend, suffix);
        tensor_prod_eq_ind_prepend(&mut prepend, prefix);

        assert_eq!(append, prepend);
    }
}