use std::{iter, ops::DerefMut};

use binius_field::{
    Field, PackedField,
    packed::{get_packed_slice, set_packed_slice},
};
use binius_utils::rayon::prelude::*;

use crate::{Error, FieldBuffer};

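/// Expands a partial equality indicator ("eq") tensor with additional query coordinates,
/// appending each coordinate as the new highest-index variable.
///
/// For every coordinate `r` in `extra_query_coordinates`, the buffer length is doubled and each
/// existing value `v` at index `j` is replaced by `v * (1 - r)` at index `j` and `v * r` at
/// index `j + old_len`.
///
/// Returns an error if the buffer's capacity (`log_cap`) cannot hold the expanded length.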
fn tensor_prod_eq_ind<P: PackedField, Data: DerefMut<Target = [P]>>(
    values: &mut FieldBuffer<P, Data>,
    extra_query_coordinates: &[P::Scalar],
) -> Result<(), Error> {
    let new_log_len = values.log_len() + extra_query_coordinates.len();

    if values.log_cap() < new_log_len {
        return Err(Error::IncorrectArgumentLength {
            arg: "values.log_cap()".to_string(),
            expected: new_log_len,
        });
    }

    for &r_i in extra_query_coordinates {
        let packed_r_i = P::broadcast(r_i);

        values.resize(values.log_len() + 1)?;
        let mut split = values
            .split_half_mut_no_closure()
            .expect("doubled by resize()");
        let (mut lo, mut hi) = split.halves();

        (lo.as_mut(), hi.as_mut())
            .into_par_iter()
            .for_each(|(lo_i, hi_i)| {
                let prod = (*lo_i) * packed_r_i;
                *lo_i -= prod;
                *hi_i = prod;
            });
    }

    Ok(())
}

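/// Like [`tensor_prod_eq_ind`], but prepends the coordinates as the lowest-index variables.
///
/// The coordinates are processed in reverse so that the order of `extra_query_coordinates` is
/// preserved: for each coordinate `r`, every existing value `v` at index `i` becomes
/// `v * (1 - r)` at index `2 * i` and `v * r` at index `2 * i + 1`.
///
/// Returns an error if the buffer's capacity (`log_cap`) cannot hold the expanded length.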
pub fn tensor_prod_eq_ind_prepend<P: PackedField, Data: DerefMut<Target = [P]>>(
    values: &mut FieldBuffer<P, Data>,
    extra_query_coordinates: &[P::Scalar],
) -> Result<(), Error> {
    let new_log_len = values.log_len() + extra_query_coordinates.len();

    if values.log_cap() < new_log_len {
        return Err(Error::IncorrectArgumentLength {
            arg: "values.log_cap()".to_string(),
            expected: new_log_len,
        });
    }

    for &r_i in extra_query_coordinates.iter().rev() {
        values.zero_extend(values.log_len() + 1)?;
        for i in (0..values.len() / 2).rev() {
            let eval = get_packed_slice(values.as_ref(), i);
            set_packed_slice(values.as_mut(), 2 * i, eval * (P::Scalar::ONE - r_i));
            set_packed_slice(values.as_mut(), 2 * i + 1, eval * r_i);
        }
    }

    Ok(())
}

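/// Computes the equality indicator expansion of `point` over the boolean hypercube.
///
/// Returns a buffer of `2^n` values, where `n = point.len()`; the entry at index `i` equals
/// `eq_ind(&b, point)` with `b` the bit decomposition of `i` (bit 0 corresponds to `point[0]`).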
pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> FieldBuffer<P> {
    let log_size = point.len();
    let mut buffer = FieldBuffer::zeros_truncated(0, log_size).expect("log_size >= 0");
    buffer
        .set(0, P::Scalar::ONE)
        .expect("buffer has length exactly 1");
    tensor_prod_eq_ind(&mut buffer, point).expect("buffer is allocated with the correct length");
    buffer
}

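/// Truncates a buffer in place to its low `truncated_log_len` variables by summing each
/// discarded high variable out (folding the upper half into the lower half).
///
/// For an equality indicator expansion of a point, this yields the expansion of the point's
/// first `truncated_log_len` coordinates.
///
/// Returns an error if `truncated_log_len` exceeds `values.log_len()`.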
pub fn eq_ind_truncate_low_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
    values: &mut FieldBuffer<P, Data>,
    truncated_log_len: usize,
) -> Result<(), Error> {
    if truncated_log_len > values.log_len() {
        return Err(Error::ArgumentRangeError {
            arg: "truncated_log_len".to_string(),
            range: 0..values.log_len() + 1,
        });
    }

    for log_len in (truncated_log_len..values.log_len()).rev() {
        values.split_half_mut(|lo, hi| {
            (lo.as_mut(), hi.as_ref())
                .into_par_iter()
                .for_each(|(zero, one)| {
                    *zero += *one;
                });
        })?;

        values.truncate(log_len);
    }

    Ok(())
}

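/// Evaluates the equality indicator polynomial in a single variable:
/// `eq(x, y) = x * y + (1 - x) * (1 - y)`.
///
/// In characteristic 2 this simplifies to `x + y + 1`, saving a multiplication.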
#[inline(always)]
pub fn eq_one_var<F: Field>(x: F, y: F) -> F {
    if F::CHARACTERISTIC == 2 {
        x + y + F::ONE
    } else {
        x * y + (F::ONE - x) * (F::ONE - y)
    }
}

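/// Evaluates the multilinear equality indicator at a pair of points, i.e. the product of
/// [`eq_one_var`] over all coordinates.
///
/// # Panics
/// Panics if `x` and `y` have different lengths.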
pub fn eq_ind<F: Field>(x: &[F], y: &[F]) -> F {
    assert_eq!(x.len(), y.len(), "pre-condition: x and y must be the same length");
    iter::zip(x, y).map(|(&x, &y)| eq_one_var(x, y)).product()
}

#[cfg(test)]
mod tests {
    use rand::prelude::*;

    use super::*;
    use crate::test_utils::{B128, Packed128b, index_to_hypercube_point, random_scalars};

    type P = Packed128b;
    type F = B128;

    #[test]
    fn test_tensor_prod_eq_ind() {
        let v0 = F::from(1);
        let v1 = F::from(2);
        let query = vec![v0, v1];
        let mut result = FieldBuffer::zeros_truncated(0, query.len()).unwrap();
        result.set(0, F::ONE).unwrap();
        tensor_prod_eq_ind(&mut result, &query).unwrap();
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        assert_eq!(
            result_vec,
            vec![
                (F::ONE - v0) * (F::ONE - v1),
                v0 * (F::ONE - v1),
                (F::ONE - v0) * v1,
                v0 * v1
            ]
        );
    }

    #[test]
    fn test_tensor_prod_eq_ind_inplace_expansion() {
        let mut rng = StdRng::seed_from_u64(0);

        let exps = 4;
        let max_n_vars = exps * (exps + 1) / 2;
        let mut coords = Vec::with_capacity(max_n_vars);
        let mut eq_expansion = FieldBuffer::zeros_truncated(0, max_n_vars).unwrap();
        eq_expansion.set(0, F::ONE).unwrap();

        for extra_count in 1..=exps {
            let extra = random_scalars(&mut rng, extra_count);

            tensor_prod_eq_ind::<P, _>(&mut eq_expansion, &extra).unwrap();
            coords.extend(&extra);

            assert_eq!(eq_expansion.log_len(), coords.len());
            for i in 0..eq_expansion.len() {
                let v = eq_expansion.get(i).unwrap();
                let hypercube_point = index_to_hypercube_point(coords.len(), i);
                assert_eq!(v, eq_ind(&hypercube_point, &coords));
            }
        }
    }

    #[test]
    fn test_eq_ind_partial_eval_empty() {
        let result = eq_ind_partial_eval::<P>(&[]);
        assert_eq!(result.log_len(), 0);
        assert_eq!(result.len(), 1);
        assert_eq!(result.get(0).unwrap(), F::ONE);
    }

    #[test]
    fn test_eq_ind_partial_eval_single_var() {
        let r0 = F::new(2);
        let result = eq_ind_partial_eval::<P>(&[r0]);
        assert_eq!(result.log_len(), 1);
        assert_eq!(result.len(), 2);
        assert_eq!(result.get(0).unwrap(), F::ONE - r0);
        assert_eq!(result.get(1).unwrap(), r0);
    }

    #[test]
    fn test_eq_ind_partial_eval_two_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let result = eq_ind_partial_eval::<P>(&[r0, r1]);
        assert_eq!(result.log_len(), 2);
        assert_eq!(result.len(), 4);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();
        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1),
            r0 * (F::ONE - r1),
            (F::ONE - r0) * r1,
            r0 * r1,
        ];
        assert_eq!(result_vec, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_three_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
        assert_eq!(result.log_len(), 3);
        assert_eq!(result.len(), 8);
        let result_vec: Vec<F> = P::iter_slice(result.as_ref()).collect();

        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
            r0 * (F::ONE - r1) * (F::ONE - r2),
            (F::ONE - r0) * r1 * (F::ONE - r2),
            r0 * r1 * (F::ONE - r2),
            (F::ONE - r0) * (F::ONE - r1) * r2,
            r0 * (F::ONE - r1) * r2,
            (F::ONE - r0) * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(result_vec, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_consistent_on_hypercube() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 5;

        let point = random_scalars(&mut rng, n_vars);
        let result = eq_ind_partial_eval::<P>(&point);
        let index = rng.random_range(0..1 << n_vars);

        let partial_eval_value = result.get(index).unwrap();

        let index_bits = index_to_hypercube_point(n_vars, index);
        let eq_ind_value = eq_ind(&point, &index_bits);

        assert_eq!(partial_eval_value, eq_ind_value);
    }

    #[test]
    fn test_eq_ind_truncate_low_inplace() {
        let mut rng = StdRng::seed_from_u64(0);

        let reds = 4;
        let n_vars = reds * (reds + 1) / 2;
        let point = random_scalars(&mut rng, n_vars);

        let mut eq_ind = eq_ind_partial_eval::<P>(&point);
        let mut log_n_values = n_vars;

        for reduction in (0..=reds).rev() {
            let truncated_log_n_values = log_n_values - reduction;
            eq_ind_truncate_low_inplace(&mut eq_ind, truncated_log_n_values).unwrap();

            let eq_ind_ref = eq_ind_partial_eval::<P>(&point[..truncated_log_n_values]);
            assert_eq!(eq_ind_ref.len(), eq_ind.len());
            for i in 0..eq_ind.len() {
                assert_eq!(eq_ind.get(i).unwrap(), eq_ind_ref.get(i).unwrap());
            }

            log_n_values = truncated_log_n_values;
        }

        assert_eq!(log_n_values, 0);
    }

    #[test]
    fn test_tensor_prod_eq_prepend_conforms_to_append() {
        let mut rng = StdRng::seed_from_u64(0);

        let n_vars = 10;
        let base_vars = 4;

        let point = random_scalars::<F>(&mut rng, n_vars);

        let append = eq_ind_partial_eval(&point);

        let mut prepend = FieldBuffer::<P>::zeros_truncated(0, n_vars).unwrap();
        let (prefix, suffix) = point.split_at(n_vars - base_vars);
        prepend.set(0, F::ONE).unwrap();
        tensor_prod_eq_ind(&mut prepend, suffix).unwrap();
        tensor_prod_eq_ind_prepend(&mut prepend, prefix).unwrap();

        assert_eq!(append, prepend);
    }
}