binius_math/
tensor_prod_eq_ind.rs1use std::cmp::max;
4
5use binius_field::{Field, PackedField};
6use binius_maybe_rayon::prelude::*;
7use binius_utils::bail;
8use bytemuck::zeroed_vec;
9
10use crate::Error;
11
12pub fn tensor_prod_eq_ind<P: PackedField>(
36 log_n_values: usize,
37 packed_values: &mut [P],
38 extra_query_coordinates: &[P::Scalar],
39) -> Result<(), Error> {
40 let new_n_vars = log_n_values + extra_query_coordinates.len();
41 if packed_values.len() != max(1, (1 << new_n_vars) / P::WIDTH) {
42 bail!(Error::InvalidPackedValuesLength);
43 }
44
45 for (i, r_i) in extra_query_coordinates.iter().enumerate() {
46 let prev_length = 1 << (log_n_values + i);
47 if prev_length < P::WIDTH {
48 let q = &mut packed_values[0];
49 for h in 0..prev_length {
50 let x = q.get(h);
51 let prod = x * r_i;
52 q.set(h, x - prod);
53 q.set(prev_length | h, prod);
54 }
55 } else {
56 let prev_packed_length = prev_length / P::WIDTH;
57 let packed_r_i = P::broadcast(*r_i);
58 let (xs, ys) = packed_values.split_at_mut(prev_packed_length);
59 assert!(xs.len() <= ys.len());
60
61 xs.par_iter_mut()
64 .zip(ys.par_iter_mut())
65 .with_min_len(64)
66 .for_each(|(x, y)| {
67 let prod = (*x) * packed_r_i;
71 *x -= prod;
72 *y = prod;
73 });
74 }
75 }
76 Ok(())
77}
78
79pub fn eq_ind_partial_eval<P: PackedField>(point: &[P::Scalar]) -> Vec<P> {
95 let n = point.len();
96 let len = 1 << n.saturating_sub(P::LOG_WIDTH);
97 let mut buffer = zeroed_vec::<P>(len);
98 buffer[0].set(0, P::Scalar::ONE);
99 tensor_prod_eq_ind(0, &mut buffer, point).expect("buffer is allocated with the correct length");
100 buffer
101}
102
#[cfg(test)]
mod tests {
    use binius_field::{Field, PackedBinaryField4x32b, packed::set_packed_slice};
    use itertools::Itertools;

    use super::*;

    type P = PackedBinaryField4x32b;
    type F = <P as PackedField>::Scalar;

    #[test]
    fn test_tensor_prod_eq_ind() {
        let v0 = F::new(1);
        let v1 = F::new(2);
        let query = vec![v0, v1];
        let mut result = vec![P::default(); 1 << (query.len() - P::LOG_WIDTH)];
        set_packed_slice(&mut result, 0, F::ONE);
        tensor_prod_eq_ind(0, &mut result, &query).unwrap();
        let result = PackedField::iter_slice(&result).collect_vec();
        assert_eq!(
            result,
            vec![
                (F::ONE - v0) * (F::ONE - v1),
                v0 * (F::ONE - v1),
                (F::ONE - v0) * v1,
                v0 * v1
            ]
        );
    }

    /// Pre-existing values (`log_n_values > 0`) that fit in one packed element are scaled
    /// by `(1 - r)` in the lower half and copied, scaled by `r`, into the upper half.
    #[test]
    fn test_tensor_prod_eq_ind_with_existing_values() {
        let a = F::new(7);
        let b = F::new(11);
        let r = F::new(3);
        let mut result = vec![P::default(); 1];
        set_packed_slice(&mut result, 0, a);
        set_packed_slice(&mut result, 1, b);
        tensor_prod_eq_ind(1, &mut result, &[r]).unwrap();
        let result = PackedField::iter_slice(&result).collect_vec();
        assert_eq!(result, vec![a * (F::ONE - r), b * (F::ONE - r), a * r, b * r]);
    }

    /// Pre-existing values that already fill a whole packed element go through the
    /// packed (parallel) path.
    #[test]
    fn test_tensor_prod_eq_ind_packed_path_with_existing_values() {
        let values = [F::new(2), F::new(3), F::new(5), F::new(7)];
        let r = F::new(13);
        let mut result = vec![P::default(); 2];
        for (i, &v) in values.iter().enumerate() {
            set_packed_slice(&mut result, i, v);
        }
        tensor_prod_eq_ind(2, &mut result, &[r]).unwrap();
        let result = PackedField::iter_slice(&result).collect_vec();
        let expected = values
            .iter()
            .map(|&v| v * (F::ONE - r))
            .chain(values.iter().map(|&v| v * r))
            .collect_vec();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_tensor_prod_eq_ind_rejects_wrong_length() {
        // Two coordinates need max(1, 2^2 / 4) = 1 packed element, not 2.
        let mut values = vec![P::default(); 2];
        let result = tensor_prod_eq_ind(0, &mut values, &[F::new(1), F::new(2)]);
        assert!(matches!(result, Err(Error::InvalidPackedValuesLength)));
    }

    #[test]
    fn test_eq_ind_partial_eval_empty() {
        let result = eq_ind_partial_eval::<P>(&[]);
        let expected = vec![P::set_single(F::ONE)];
        assert_eq!(result, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_single_var() {
        let r0 = F::new(2);
        let result = eq_ind_partial_eval::<P>(&[r0]);
        let expected = vec![(F::ONE - r0), r0, F::ZERO, F::ZERO];
        let result = PackedField::iter_slice(&result).collect_vec();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_two_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let result = eq_ind_partial_eval::<P>(&[r0, r1]);
        let result = PackedField::iter_slice(&result).collect_vec();
        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1),
            r0 * (F::ONE - r1),
            (F::ONE - r0) * r1,
            r0 * r1,
        ];
        assert_eq!(result, expected);
    }

    #[test]
    fn test_eq_ind_partial_eval_three_vars() {
        let r0 = F::new(2);
        let r1 = F::new(3);
        let r2 = F::new(5);
        let result = eq_ind_partial_eval::<P>(&[r0, r1, r2]);
        let result = PackedField::iter_slice(&result).collect_vec();

        let expected = vec![
            (F::ONE - r0) * (F::ONE - r1) * (F::ONE - r2),
            r0 * (F::ONE - r1) * (F::ONE - r2),
            (F::ONE - r0) * r1 * (F::ONE - r2),
            r0 * r1 * (F::ONE - r2),
            (F::ONE - r0) * (F::ONE - r1) * r2,
            r0 * (F::ONE - r1) * r2,
            (F::ONE - r0) * r1 * r2,
            r0 * r1 * r2,
        ];
        assert_eq!(result, expected);
    }
}