binius_math/
tensor_prod_eq_ind.rs1use std::cmp::max;
4
5use binius_field::{Field, PackedField};
6use binius_maybe_rayon::prelude::*;
7use binius_utils::bail;
8use bytemuck::zeroed_vec;
9
10use crate::Error;
11
12pub fn tensor_prod_eq_ind<P: PackedField>(
35 log_n_values: usize,
36 packed_values: &mut [P],
37 extra_query_coordinates: &[P::Scalar],
38) -> Result<(), Error> {
39 let new_n_vars = log_n_values + extra_query_coordinates.len();
40 if packed_values.len() != max(1, (1 << new_n_vars) / P::WIDTH) {
41 bail!(Error::InvalidPackedValuesLength);
42 }
43
44 for (i, r_i) in extra_query_coordinates.iter().enumerate() {
45 let prev_length = 1 << (log_n_values + i);
46 if prev_length < P::WIDTH {
47 let q = &mut packed_values[0];
48 for h in 0..prev_length {
49 let x = q.get(h);
50 let prod = x * r_i;
51 q.set(h, x - prod);
52 q.set(prev_length | h, prod);
53 }
54 } else {
55 let prev_packed_length = prev_length / P::WIDTH;
56 let packed_r_i = P::broadcast(*r_i);
57 let (xs, ys) = packed_values.split_at_mut(prev_packed_length);
58 assert!(xs.len() <= ys.len());
59
60 xs.par_iter_mut()
63 .zip(ys.par_iter_mut())
64 .with_min_len(64)
65 .for_each(|(x, y)| {
66 let prod = (*x) * packed_r_i;
70 *x -= prod;
71 *y = prod;
72 });
73 }
74 }
75 Ok(())
76}
77
78pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> Vec<P> {
94 let n = point.len();
95 let len = 1 << n.saturating_sub(P::LOG_WIDTH);
96 let mut buffer = zeroed_vec::<P>(len);
97 buffer[0].set(0, P::Scalar::ONE);
98 tensor_prod_eq_ind(0, &mut buffer, point).expect("buffer is allocated with the correct length");
99 buffer
100}
101
#[cfg(test)]
mod tests {
    use binius_field::{packed::set_packed_slice, Field, PackedBinaryField4x32b};
    use itertools::Itertools;

    use super::*;

    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    #[test]
    fn test_tensor_prod_eq_ind() {
        let v0 = F::new(1);
        let v1 = F::new(2);
        let query = vec![v0, v1];
        let mut result = vec![P::default(); 1 << (query.len() - P::LOG_WIDTH)];
        set_packed_slice(&mut result, 0, F::ONE);
        tensor_prod_eq_ind(0, &mut result, &query).unwrap();
        let result = PackedField::iter_slice(&result).collect_vec();
        assert_eq!(
            result,
            vec![
                (F::ONE - v0) * (F::ONE - v1),
                v0 * (F::ONE - v1),
                (F::ONE - v0) * v1,
                v0 * v1
            ]
        );
    }

    /// Extending an existing 1-variable table (scalar branch, table smaller than WIDTH)
    /// must agree with computing the 2-variable table from scratch.
    #[test]
    fn test_tensor_prod_eq_ind_extends_existing_table() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let mut result = eq_ind_partial_eval::<P>(&[r0]);
        tensor_prod_eq_ind(1, &mut result, &[r1]).unwrap();
        assert_eq!(result, eq_ind_partial_eval::<P>(&[r0, r1]));
    }

    /// Extending an existing 2-variable table that fills a whole packed element exercises
    /// the packed (parallel) branch with a non-zero `log_n_values`.
    #[test]
    fn test_tensor_prod_eq_ind_extends_existing_packed_table() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let mut result = eq_ind_partial_eval::<P>(&[r0, r1]);
        // Three variables over 4-wide packed elements need two packed values.
        result.resize(2, P::default());
        tensor_prod_eq_ind(2, &mut result, &[r2]).unwrap();
        assert_eq!(result, eq_ind_partial_eval::<P>(&[r0, r1, r2]));
    }

    #[test]
    fn test_tensor_prod_eq_ind_rejects_wrong_length() {
        let query = [F::new(1), F::new(2), F::new(3)];

        // Three variables over 4-wide packed elements need exactly two packed values.
        let mut too_short = vec![P::default(); 1];
        let err = tensor_prod_eq_ind(0, &mut too_short, &query).unwrap_err();
        assert!(matches!(err, Error::InvalidPackedValuesLength));

        let mut too_long = vec![P::default(); 3];
        let err = tensor_prod_eq_ind(0, &mut too_long, &query).unwrap_err();
        assert!(matches!(err, Error::InvalidPackedValuesLength));
    }

    #[test]
    fn test_eq_ind_partial_eval_empty() {
        let result = eq_ind_partial_eval::<P>(&[]);
        let expected = vec![P::set_single(F::ONE)];
        assert_eq!(result, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_single_var() {
        let r0 = F::new(2);
        let result = eq_ind_partial_eval::<P>(&[r0]);
        let expected = vec![(F::ONE - r0), r0, F::ZERO, F::ZERO];
        let result = PackedField::iter_slice(&result).collect_vec();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_two_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let result = eq_ind_partial_eval::<P>(&[r0, r1]);
        let result = PackedField::iter_slice(&result).collect_vec();
        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1),
            r0 * (F::ONE - r1),
            (F::ONE - r0) * r1,
            r0 * r1,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_three_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
        let result = PackedField::iter_slice(&result).collect_vec();

        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
            r0 * (F::ONE - r1) * (F::ONE - r2),
            (F::ONE - r0) * r1 * (F::ONE - r2),
            r0 * r1 * (F::ONE - r2),
            (F::ONE - r0) * (F::ONE - r1) * r2,
            r0 * (F::ONE - r1) * r2,
            (F::ONE - r0) * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(result, expected);
    }
}