1use std::{iter, ops::DerefMut};
4
5use binius_field::{
6 Field, PackedField,
7 packed::{get_packed_slice, set_packed_slice},
8};
9use binius_utils::rayon::prelude::*;
10
11use crate::{Error, FieldBuffer};
12
13fn tensor_prod_eq_ind<P: PackedField, Data: DerefMut<Target = [P]>>(
40 values: &mut FieldBuffer<P, Data>,
41 extra_query_coordinates: &[P::Scalar],
42) -> Result<(), Error> {
43 let new_log_len = values.log_len() + extra_query_coordinates.len();
44
45 if values.log_cap() < new_log_len {
46 return Err(Error::IncorrectArgumentLength {
47 arg: "values.log_cap()".to_string(),
48 expected: new_log_len,
49 });
50 }
51
52 for &r_i in extra_query_coordinates {
53 let packed_r_i = P::broadcast(r_i);
54
55 values.resize(values.log_len() + 1)?;
56 let mut split = values
57 .split_half_mut_no_closure()
58 .expect("doubled by zero_extend()");
59 let (mut lo, mut hi) = split.halves();
60
61 (lo.as_mut(), hi.as_mut())
62 .into_par_iter()
63 .for_each(|(lo_i, hi_i)| {
64 let prod = (*lo_i) * packed_r_i;
65 *lo_i -= prod;
66 *hi_i = prod;
67 });
68 }
69
70 Ok(())
71}
72
73pub fn tensor_prod_eq_ind_prepend<P: PackedField, Data: DerefMut<Target = [P]>>(
84 values: &mut FieldBuffer<P, Data>,
85 extra_query_coordinates: &[P::Scalar],
86) -> Result<(), Error> {
87 let new_log_len = values.log_len() + extra_query_coordinates.len();
88
89 if values.log_cap() < new_log_len {
90 return Err(Error::IncorrectArgumentLength {
91 arg: "values.log_cap()".to_string(),
92 expected: new_log_len,
93 });
94 }
95
96 for &r_i in extra_query_coordinates.iter().rev() {
97 values.zero_extend(values.log_len() + 1)?;
98 for i in (0..values.len() / 2).rev() {
99 let eval = get_packed_slice(values.as_ref(), i);
100 set_packed_slice(values.as_mut(), 2 * i, eval * (P::Scalar::ONE - r_i));
101 set_packed_slice(values.as_mut(), 2 * i + 1, eval * r_i);
102 }
103 }
104
105 Ok(())
106}
107
108pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> FieldBuffer<P> {
124 let log_size = point.len();
127 let mut buffer = FieldBuffer::zeros_truncated(0, log_size).expect("log_size >= 0");
128 buffer.set(0, P::Scalar::ONE);
129 tensor_prod_eq_ind(&mut buffer, point).expect("buffer is allocated with the correct length");
130 buffer
131}
132
133pub fn eq_ind_truncate_low_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
140 values: &mut FieldBuffer<P, Data>,
141 truncated_log_len: usize,
142) -> Result<(), Error> {
143 if truncated_log_len > values.log_len() {
144 return Err(Error::ArgumentRangeError {
145 arg: "truncated_log_len".to_string(),
146 range: 0..values.log_len() + 1,
147 });
148 }
149
150 for log_len in (truncated_log_len..values.log_len()).rev() {
151 values.split_half_mut(|lo, hi| {
152 (lo.as_mut(), hi.as_ref())
153 .into_par_iter()
154 .for_each(|(zero, one)| {
155 *zero += *one;
156 });
157 })?;
158
159 values.truncate(log_len);
160 }
161
162 Ok(())
163}
164
165#[inline(always)]
179pub fn eq_one_var<F: Field>(x: F, y: F) -> F {
180 if F::CHARACTERISTIC == 2 {
181 x + y + F::ONE
183 } else {
184 x * y + (F::ONE - x) * (F::ONE - y)
185 }
186}
187
188pub fn eq_ind<F: Field>(x: &[F], y: &[F]) -> F {
202 assert_eq!(x.len(), y.len(), "pre-condition: x and y must be the same length");
203 iter::zip(x, y).map(|(&x, &y)| eq_one_var(x, y)).product()
204}
205
#[cfg(test)]
mod tests {
    use rand::prelude::*;

    use super::*;
    use crate::test_utils::{B128, Packed128b, index_to_hypercube_point, random_scalars};

    type P = Packed128b;
    type F = B128;

    #[test]
    fn test_tensor_prod_eq_ind() {
        let v0 = F::from(1);
        let v1 = F::from(2);
        let query = vec![v0, v1];
        let mut result = FieldBuffer::zeros_truncated(0, query.len()).unwrap();
        result.set_checked(0, F::ONE).unwrap();
        tensor_prod_eq_ind(&mut result, &query).unwrap();
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        assert_eq!(
            result_vec,
            vec![
                (F::ONE - v0) * (F::ONE - v1),
                v0 * (F::ONE - v1),
                (F::ONE - v0) * v1,
                v0 * v1
            ]
        );
    }

    #[test]
    fn test_tensor_prod_eq_ind_inplace_expansion() {
        let mut rng = StdRng::seed_from_u64(0);

        let exps = 4;
        let max_n_vars = exps * (exps + 1) / 2;
        let mut coords = Vec::with_capacity(max_n_vars);
        let mut eq_expansion = FieldBuffer::zeros_truncated(0, max_n_vars).unwrap();
        eq_expansion.set_checked(0, F::ONE).unwrap();

        for extra_count in 1..=exps {
            let extra = random_scalars(&mut rng, extra_count);

            tensor_prod_eq_ind::<P, _>(&mut eq_expansion, &extra).unwrap();
            coords.extend(&extra);

            assert_eq!(eq_expansion.log_len(), coords.len());
            for i in 0..eq_expansion.len() {
                let v = eq_expansion.get_checked(i).unwrap();
                let hypercube_point = index_to_hypercube_point(coords.len(), i);
                assert_eq!(v, eq_ind(&hypercube_point, &coords));
            }
        }
    }

    #[test]
    fn test_tensor_prod_eq_ind_insufficient_capacity() {
        // Capacity for one variable, but two coordinates are requested.
        let mut buffer = FieldBuffer::<P>::zeros_truncated(0, 1).unwrap();
        buffer.set_checked(0, F::ONE).unwrap();
        let query = vec![F::new(2), F::new(3)];

        assert!(tensor_prod_eq_ind(&mut buffer, &query).is_err());
        assert!(tensor_prod_eq_ind_prepend(&mut buffer, &query).is_err());

        // Failed calls must leave the buffer untouched.
        assert_eq!(buffer.log_len(), 0);
        assert_eq!(buffer.get_checked(0).unwrap(), F::ONE);
    }

    #[test]
    fn test_eq_one_var_on_booleans() {
        assert_eq!(eq_one_var(F::ZERO, F::ZERO), F::ONE);
        assert_eq!(eq_one_var(F::ONE, F::ONE), F::ONE);
        assert_eq!(eq_one_var(F::ZERO, F::ONE), F::ZERO);
        assert_eq!(eq_one_var(F::ONE, F::ZERO), F::ZERO);
    }

    #[test]
    fn test_eq_ind_partial_eval_empty() {
        let result = eq_ind_partial_eval::<P>(&[]);
        assert_eq!(result.log_len(), 0);
        assert_eq!(result.len(), 1);
        assert_eq!(result.get_checked(0).unwrap(), F::ONE);
    }

    #[test]
    fn test_eq_ind_partial_eval_single_var() {
        let r0 = F::new(2);
        let result = eq_ind_partial_eval::<P>(&[r0]);
        assert_eq!(result.log_len(), 1);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get_checked(0).unwrap(), F::ONE - r0);
        assert_eq!(result.get_checked(1).unwrap(), r0);
    }

    #[test]
    fn test_eq_ind_partial_eval_two_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let result = eq_ind_partial_eval::<P>(&[r0, r1]);
        assert_eq!(result.log_len(), 2);
        assert_eq!(result.len(), 4);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1),
            r0 * (F::ONE - r1),
            (F::ONE - r0) * r1,
            r0 * r1,
        ];
        assert_eq!(result_vec, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_three_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
        assert_eq!(result.log_len(), 3);
        assert_eq!(result.len(), 8);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();

        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
            r0 * (F::ONE - r1) * (F::ONE - r2),
            (F::ONE - r0) * r1 * (F::ONE - r2),
            r0 * r1 * (F::ONE - r2),
            (F::ONE - r0) * (F::ONE - r1) * r2,
            r0 * (F::ONE - r1) * r2,
            (F::ONE - r0) * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(result_vec, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_consistent_on_hypercube() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 5;

        let point = random_scalars(&mut rng, n_vars);
        let result = eq_ind_partial_eval::<P>(&point);
        let index = rng.random_range(..1 << n_vars);

        let partial_eval_value = result.get_checked(index).unwrap();

        let index_bits = index_to_hypercube_point(n_vars, index);
        let eq_ind_value = eq_ind(&point, &index_bits);

        assert_eq!(partial_eval_value, eq_ind_value);
    }

    #[test]
    fn test_eq_ind_truncate_low_inplace() {
        let mut rng = StdRng::seed_from_u64(0);

        let reds = 4;
        let n_vars = reds * (reds + 1) / 2;
        let point = random_scalars(&mut rng, n_vars);

        // Named `eq_evals` rather than `eq_ind` so the `eq_ind` function is not shadowed.
        let mut eq_evals = eq_ind_partial_eval::<P>(&point);
        let mut log_n_values = n_vars;

        for reduction in (0..=reds).rev() {
            let truncated_log_n_values = log_n_values - reduction;
            eq_ind_truncate_low_inplace(&mut eq_evals, truncated_log_n_values).unwrap();

            let eq_evals_ref = eq_ind_partial_eval::<P>(&point[..truncated_log_n_values]);
            assert_eq!(eq_evals_ref.len(), eq_evals.len());
            for i in 0..eq_evals.len() {
                assert_eq!(
                    eq_evals.get_checked(i).unwrap(),
                    eq_evals_ref.get_checked(i).unwrap()
                );
            }

            log_n_values = truncated_log_n_values;
        }

        assert_eq!(log_n_values, 0);
    }

    #[test]
    fn test_eq_ind_truncate_low_inplace_rejects_growth() {
        let mut eq_evals = eq_ind_partial_eval::<P>(&[F::new(2)]);
        assert!(eq_ind_truncate_low_inplace(&mut eq_evals, 2).is_err());
        // A rejected truncation must not modify the buffer.
        assert_eq!(eq_evals.log_len(), 1);
    }

    #[test]
    fn test_tensor_prod_eq_prepend_conforms_to_append() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 10;
        let base_vars = 4;

        let point = random_scalars::<F>(&mut rng, n_vars);

        let append = eq_ind_partial_eval(&point);

        let mut prepend = FieldBuffer::<P>::zeros_truncated(0, n_vars).unwrap();
        let (prefix, suffix) = point.split_at(n_vars - base_vars);
        prepend.set_checked(0, F::ONE).unwrap();
        tensor_prod_eq_ind(&mut prepend, suffix).unwrap();
        tensor_prod_eq_ind_prepend(&mut prepend, prefix).unwrap();

        assert_eq!(append, prepend);
    }
}