use std::{iter, ops::DerefMut};

use binius_field::{
    Field, PackedField,
    packed::{get_packed_slice, set_packed_slice},
};
use binius_utils::rayon::prelude::*;

use crate::FieldBuffer;

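/// Expands a partially evaluated equality indicator in place by appending extra query
/// coordinates as new high-order variables.
///
/// Given a buffer holding the `2^n` hypercube evaluations of `eq(x, r)` and `k` extra
/// coordinates, this produces the `2^(n+k)` evaluations of the extended indicator. Each
/// coordinate `r_i` doubles the buffer: the low half is scaled by `1 - r_i` and the high
/// half by `r_i`.
///
/// The buffer's capacity must be sufficient to hold the expanded values.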
fn tensor_prod_eq_ind<P: PackedField, Data: DerefMut<Target = [P]>>(
    values: &mut FieldBuffer<P, Data>,
    extra_query_coordinates: &[P::Scalar],
) {
    let new_log_len = values.log_len() + extra_query_coordinates.len();

    assert!(
        values.log_cap() >= new_log_len,
        "precondition: values capacity must be sufficient for expansion"
    );

    for &r_i in extra_query_coordinates {
        let packed_r_i = P::broadcast(r_i);

        values.resize(values.log_len() + 1);
        let mut split = values.split_half_mut();
        let (mut lo, mut hi) = split.halves();

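        // For each entry v in the low half, write v * (1 - r_i) back to the low half
        // and v * r_i to the high half.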
        (lo.as_mut(), hi.as_mut())
            .into_par_iter()
            .for_each(|(lo_i, hi_i)| {
                let prod = (*lo_i) * packed_r_i;
                *lo_i -= prod;
                *hi_i = prod;
            });
    }
}

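/// Expands a partially evaluated equality indicator in place by prepending extra query
/// coordinates as new low-order variables.
///
/// This is the prepend counterpart of [`tensor_prod_eq_ind`]: instead of scaling the two
/// halves of the buffer, each coordinate `r_i` interleaves the evaluations, writing
/// `v * (1 - r_i)` to even indices and `v * r_i` to odd indices.
///
/// The buffer's capacity must be sufficient to hold the expanded values.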
pub fn tensor_prod_eq_ind_prepend<P: PackedField, Data: DerefMut<Target = [P]>>(
    values: &mut FieldBuffer<P, Data>,
    extra_query_coordinates: &[P::Scalar],
) {
    let new_log_len = values.log_len() + extra_query_coordinates.len();

    assert!(
        values.log_cap() >= new_log_len,
        "precondition: values capacity must be sufficient for expansion"
    );

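    // Process coordinates in reverse so that extra_query_coordinates[0] ends up as the
    // lowest-order variable of the expanded buffer.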
    for &r_i in extra_query_coordinates.iter().rev() {
        values.zero_extend(values.log_len() + 1);
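        // Iterate indices in reverse: the writes at 2i and 2i + 1 land at positions >= i,
        // so no entry is overwritten before it has been read.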
        for i in (0..values.len() / 2).rev() {
            let eval = get_packed_slice(values.as_ref(), i);
            set_packed_slice(values.as_mut(), 2 * i, eval * (P::Scalar::ONE - r_i));
            set_packed_slice(values.as_mut(), 2 * i + 1, eval * r_i);
        }
    }
}

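/// Computes the `2^n` hypercube evaluations of the equality indicator `eq(x, point)`,
/// where `n = point.len()`.
///
/// Index `i` of the returned buffer holds `eq(bits(i), point)`, with coordinate `j` of
/// the point corresponding to bit `j` of the index.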
pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> FieldBuffer<P> {
    let log_size = point.len();
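    // Allocate the full capacity up front, then expand from the constant polynomial 1.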
    let mut buffer = FieldBuffer::zeros_truncated(0, log_size);
    buffer.set(0, P::Scalar::ONE);
    tensor_prod_eq_ind(&mut buffer, point);
    buffer
}

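/// Truncates a partially evaluated equality indicator in place to its low
/// `truncated_log_len` variables.
///
/// Each step removes the current highest variable by summing the two halves of the
/// buffer: since `(1 - r) + r = 1`, the sum is the indicator over one fewer variable.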
pub fn eq_ind_truncate_low_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
    values: &mut FieldBuffer<P, Data>,
    truncated_log_len: usize,
) {
    assert!(
        truncated_log_len <= values.log_len(),
        "precondition: truncated_log_len must be at most values.log_len()"
    );

    for log_len in (truncated_log_len..values.log_len()).rev() {
        {
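            // Fold away the highest variable by summing the two halves.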
            let mut split = values.split_half_mut();
            let (mut lo, hi) = split.halves();
            (lo.as_mut(), hi.as_ref())
                .into_par_iter()
                .for_each(|(zero, one)| {
                    *zero += *one;
                });
        }

        values.truncate(log_len);
    }
}

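/// Evaluates the single-variable equality indicator `eq(x, y) = x*y + (1 - x)*(1 - y)`,
/// which equals 1 when `x == y` on Boolean inputs and 0 otherwise.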
#[inline(always)]
pub fn eq_one_var<F: Field>(x: F, y: F) -> F {
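    // Over characteristic 2, x*y + (1 - x)*(1 - y) = 2*x*y - x - y + 1 collapses to
    // x + y + 1, saving a multiplication.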
    if F::CHARACTERISTIC == 2 {
        x + y + F::ONE
    } else {
        x * y + (F::ONE - x) * (F::ONE - y)
    }
}

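/// Evaluates the multilinear equality indicator `eq(x, y)` as the product of
/// [`eq_one_var`] over corresponding coordinates.
///
/// Both slices must have the same length.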
pub fn eq_ind<F: Field>(x: &[F], y: &[F]) -> F {
    assert_eq!(x.len(), y.len(), "precondition: x and y must be the same length");
    iter::zip(x, y).map(|(&x, &y)| eq_one_var(x, y)).product()
}

#[cfg(test)]
mod tests {
    use rand::prelude::*;

    use super::*;
    use crate::test_utils::{B128, Packed128b, index_to_hypercube_point, random_scalars};

    type P = Packed128b;
    type F = B128;

    #[test]
    fn test_tensor_prod_eq_ind() {
        let v0 = F::from(1);
        let v1 = F::from(2);
        let query = vec![v0, v1];
        let mut result = FieldBuffer::zeros_truncated(0, query.len());
        result.set(0, F::ONE);
        tensor_prod_eq_ind(&mut result, &query);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        assert_eq!(
            result_vec,
            vec![
                (F::ONE - v0) * (F::ONE - v1),
                v0 * (F::ONE - v1),
                (F::ONE - v0) * v1,
                v0 * v1
            ]
        );
    }

    #[test]
    fn test_tensor_prod_eq_ind_inplace_expansion() {
        let mut rng = StdRng::seed_from_u64(0);

        let exps = 4;
        let max_n_vars = exps * (exps + 1) / 2;
        let mut coords = Vec::with_capacity(max_n_vars);
        let mut eq_expansion = FieldBuffer::zeros_truncated(0, max_n_vars);
        eq_expansion.set(0, F::ONE);

        for extra_count in 1..=exps {
            let extra = random_scalars(&mut rng, extra_count);

            tensor_prod_eq_ind::<P, _>(&mut eq_expansion, &extra);
            coords.extend(&extra);

            assert_eq!(eq_expansion.log_len(), coords.len());
            for i in 0..eq_expansion.len() {
                let v = eq_expansion.get(i);
                let hypercube_point = index_to_hypercube_point(coords.len(), i);
                assert_eq!(v, eq_ind(&hypercube_point, &coords));
            }
        }
    }

    #[test]
    fn test_eq_ind_partial_eval_empty() {
        let result = eq_ind_partial_eval::<P>(&[]);
        assert_eq!(result.log_len(), 0);
        assert_eq!(result.len(), 1);
        assert_eq!(result.get(0), F::ONE);
    }

    #[test]
    fn test_eq_ind_partial_eval_single_var() {
        let r0 = F::new(2);
        let result = eq_ind_partial_eval::<P>(&[r0]);
        assert_eq!(result.log_len(), 1);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get(0), F::ONE - r0);
        assert_eq!(result.get(1), r0);
    }

    #[test]
    fn test_eq_ind_partial_eval_two_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let result = eq_ind_partial_eval::<P>(&[r0, r1]);
        assert_eq!(result.log_len(), 2);
        assert_eq!(result.len(), 4);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1),
            r0 * (F::ONE - r1),
            (F::ONE - r0) * r1,
            r0 * r1,
        ];
        assert_eq!(result_vec, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_three_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
        assert_eq!(result.log_len(), 3);
        assert_eq!(result.len(), 8);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();

        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
            r0 * (F::ONE - r1) * (F::ONE - r2),
            (F::ONE - r0) * r1 * (F::ONE - r2),
            r0 * r1 * (F::ONE - r2),
            (F::ONE - r0) * (F::ONE - r1) * r2,
            r0 * (F::ONE - r1) * r2,
            (F::ONE - r0) * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(result_vec, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_consistent_on_hypercube() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 5;

        let point = random_scalars(&mut rng, n_vars);
        let result = eq_ind_partial_eval::<P>(&point);
        let index = rng.random_range(..1 << n_vars);

        let partial_eval_value = result.get(index);

        let index_bits = index_to_hypercube_point(n_vars, index);
        let eq_ind_value = eq_ind(&point, &index_bits);

        assert_eq!(partial_eval_value, eq_ind_value);
    }

    #[test]
    fn test_eq_ind_truncate_low_inplace() {
        let mut rng = StdRng::seed_from_u64(0);

        let reds = 4;
        let n_vars = reds * (reds + 1) / 2;
        let point = random_scalars(&mut rng, n_vars);

        let mut eq_ind = eq_ind_partial_eval::<P>(&point);
        let mut log_n_values = n_vars;

        for reduction in (0..=reds).rev() {
            let truncated_log_n_values = log_n_values - reduction;
            eq_ind_truncate_low_inplace(&mut eq_ind, truncated_log_n_values);

            let eq_ind_ref = eq_ind_partial_eval::<P>(&point[..truncated_log_n_values]);
            assert_eq!(eq_ind_ref.len(), eq_ind.len());
            for i in 0..eq_ind.len() {
                assert_eq!(eq_ind.get(i), eq_ind_ref.get(i));
            }

            log_n_values = truncated_log_n_values;
        }

        assert_eq!(log_n_values, 0);
    }

    #[test]
    fn test_tensor_prod_eq_prepend_conforms_to_append() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 10;
        let base_vars = 4;

        let point = random_scalars::<F>(&mut rng, n_vars);

        let append = eq_ind_partial_eval(&point);

        let mut prepend = FieldBuffer::<P>::zeros_truncated(0, n_vars);
        let (prefix, suffix) = point.split_at(n_vars - base_vars);
        prepend.set(0, F::ONE);
        tensor_prod_eq_ind(&mut prepend, suffix);
        tensor_prod_eq_ind_prepend(&mut prepend, prefix);

        assert_eq!(append, prepend);
    }
}