1use std::ops::Deref;
5
6use binius_field::Field;
7
8use crate::{
9 field_buffer::FieldBuffer,
10 inner_product::{inner_product, inner_product_buffers},
11};
12
13pub fn fold_cols<F, DataMat, DataVec>(
44 mat: &FieldBuffer<F, DataMat>,
45 vec: &FieldBuffer<F, DataVec>,
46) -> FieldBuffer<F>
47where
48 F: Field,
49 DataMat: Deref<Target = [F]>,
50 DataVec: Deref<Target = [F]>,
51{
52 let log_m = vec.log_len();
53 let log_n = mat
54 .log_len()
55 .checked_sub(vec.log_len())
56 .expect("precondition: mat.log_len() must be at least vec.log_len()");
57
58 let ret_vals = mat
59 .chunks(log_m)
60 .map(|row| inner_product_buffers(&row, vec))
61 .collect::<Box<[_]>>();
62 FieldBuffer::new(log_n, ret_vals)
63}
64
65pub fn fold_rows<F, DataMat, DataVec>(
96 mat: &FieldBuffer<F, DataMat>,
97 vec: &FieldBuffer<F, DataVec>,
98) -> FieldBuffer<F>
99where
100 F: Field,
101 DataMat: Deref<Target = [F]>,
102 DataVec: Deref<Target = [F]>,
103{
104 let log_n = vec.log_len();
105 let log_m = mat
106 .log_len()
107 .checked_sub(vec.log_len())
108 .expect("precondition: mat.log_len() must be at least vec.log_len()");
109
110 let mat_vals = mat.as_ref();
111 let ret_vals = (0..1 << log_m)
112 .map(|col_i| {
113 let col = (0..1 << log_n).map(|row_i| mat_vals[(row_i << log_m) + col_i]);
114 inner_product(col, vec.as_ref().iter().copied())
115 })
116 .collect::<Box<[_]>>();
117 FieldBuffer::new(log_m, ret_vals)
118}
119
#[cfg(test)]
mod tests {
	use std::iter;

	use binius_field::Random;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::{B128, random_scalars};

	// Non-square dimensions are used throughout so that a mix-up between the row and column
	// dimensions (or their strides) cannot cancel out and go unnoticed.

	/// Returns the element-wise linear combination `scalar0 * a + scalar1 * b`.
	fn linear_combination(scalar0: B128, a: &[B128], scalar1: B128, b: &[B128]) -> Vec<B128> {
		assert_eq!(a.len(), b.len());
		iter::zip(a, b)
			.map(|(&x, &y)| scalar0 * x + scalar1 * y)
			.collect()
	}

	/// Transposes a row-major `n_rows x n_cols` matrix into a row-major `n_cols x n_rows` one.
	fn transpose(values: &[B128], n_rows: usize, n_cols: usize) -> Vec<B128> {
		assert_eq!(values.len(), n_rows * n_cols);
		let mut transposed = vec![B128::ZERO; values.len()];
		for i in 0..n_rows {
			for j in 0..n_cols {
				transposed[j * n_rows + i] = values[i * n_cols + j];
			}
		}
		transposed
	}

	#[test]
	fn test_fold_cols_matches_naive() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_rows = 1 << 3;
		let n_cols = 1 << 4;

		let mat_values = random_scalars::<B128>(&mut rng, n_rows * n_cols);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);
		let vec_values = random_scalars::<B128>(&mut rng, n_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		// Entry i is the inner product of row i with `vec`.
		let expected_values: Vec<B128> = (0..n_rows)
			.map(|i| {
				(0..n_cols)
					.map(|j| mat_values[i * n_cols + j] * vec_values[j])
					.fold(B128::ZERO, |acc, term| acc + term)
			})
			.collect();
		let expected = FieldBuffer::<B128>::from_values(&expected_values);

		let result = fold_cols(&mat, &vec);
		assert_eq!(result.as_ref(), expected.as_ref());
	}

	#[test]
	fn test_fold_rows_matches_naive() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_rows = 1 << 3;
		let n_cols = 1 << 4;

		let mat_values = random_scalars::<B128>(&mut rng, n_rows * n_cols);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);
		let vec_values = random_scalars::<B128>(&mut rng, n_rows);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		// Entry j is the inner product of column j with `vec`.
		let expected_values: Vec<B128> = (0..n_cols)
			.map(|j| {
				(0..n_rows)
					.map(|i| mat_values[i * n_cols + j] * vec_values[i])
					.fold(B128::ZERO, |acc, term| acc + term)
			})
			.collect();
		let expected = FieldBuffer::<B128>::from_values(&expected_values);

		let result = fold_rows(&mat, &vec);
		assert_eq!(result.as_ref(), expected.as_ref());
	}

	#[test]
	fn test_fold_cols_linear_in_matrix() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_rows = 4;
		let log_cols = 6;
		let total_elements = 1 << (log_rows + log_cols);

		let mat0_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat0 = FieldBuffer::<B128>::from_values(&mat0_values);

		let mat1_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat1 = FieldBuffer::<B128>::from_values(&mat1_values);

		let vec_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		// fold(s0 * M0 + s1 * M1, v) == s0 * fold(M0, v) + s1 * fold(M1, v)
		let scaled_mat_values = linear_combination(scalar0, &mat0_values, scalar1, &mat1_values);
		let scaled_mat = FieldBuffer::<B128>::from_values(&scaled_mat_values);
		let left_side = fold_cols(&scaled_mat, &vec);

		let mat0_vec = fold_cols(&mat0, &vec);
		let mat1_vec = fold_cols(&mat1, &vec);
		let right_side_values =
			linear_combination(scalar0, mat0_vec.as_ref(), scalar1, mat1_vec.as_ref());
		let right_side = FieldBuffer::<B128>::from_values(&right_side_values);

		assert_eq!(left_side.as_ref(), right_side.as_ref());
	}

	#[test]
	fn test_fold_cols_linear_in_vector() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_rows = 4;
		let log_cols = 6;
		let total_elements = 1 << (log_rows + log_cols);

		let mat_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);

		let vec0_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec0 = FieldBuffer::<B128>::from_values(&vec0_values);

		let vec1_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec1 = FieldBuffer::<B128>::from_values(&vec1_values);

		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		// fold(M, s0 * v0 + s1 * v1) == s0 * fold(M, v0) + s1 * fold(M, v1)
		let scaled_vec_values = linear_combination(scalar0, &vec0_values, scalar1, &vec1_values);
		let scaled_vec = FieldBuffer::<B128>::from_values(&scaled_vec_values);
		let left_side = fold_cols(&mat, &scaled_vec);

		let mat_vec0 = fold_cols(&mat, &vec0);
		let mat_vec1 = fold_cols(&mat, &vec1);
		let right_side_values =
			linear_combination(scalar0, mat_vec0.as_ref(), scalar1, mat_vec1.as_ref());
		let right_side = FieldBuffer::<B128>::from_values(&right_side_values);

		assert_eq!(left_side.as_ref(), right_side.as_ref());
	}

	#[test]
	fn test_fold_cols_equals_fold_rows_transpose() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_rows = 1 << 3;
		let n_cols = 1 << 6;

		let mat_values = random_scalars::<B128>(&mut rng, n_rows * n_cols);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);

		let vec_values = random_scalars::<B128>(&mut rng, n_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		let fold_cols_result = fold_cols(&mat, &vec);

		// Mᵀ has `n_cols` rows, so folding its rows against `vec` computes the same `M · vec`.
		let mat_t_values = transpose(&mat_values, n_rows, n_cols);
		let mat_t = FieldBuffer::<B128>::from_values(&mat_t_values);
		let fold_rows_result = fold_rows(&mat_t, &vec);

		assert_eq!(fold_cols_result.as_ref(), fold_rows_result.as_ref());
	}

	#[test]
	fn test_fold_cols_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_len = 10;
		let log_vec1 = 3;
		let log_vec2 = 4;

		let vec1_values = random_scalars::<B128>(&mut rng, 1 << log_vec1);
		let vec1 = FieldBuffer::<B128>::from_values(&vec1_values);

		let vec2_values = random_scalars::<B128>(&mut rng, 1 << log_vec2);
		let vec2 = FieldBuffer::<B128>::from_values(&vec2_values);

		// Index `i2 * len(vec1) + i1` holds `vec2[i2] * vec1[i1]`: `vec1` varies fastest,
		// matching the order in which the two sequential folds consume the index bits.
		let tensor_product_values: Vec<B128> = vec2_values
			.iter()
			.flat_map(|&v2| vec1_values.iter().map(move |&v1| v2 * v1))
			.collect();
		let tensor_product = FieldBuffer::<B128>::from_values(&tensor_product_values);

		let matrix_values = random_scalars::<B128>(&mut rng, 1 << log_len);
		let matrix = FieldBuffer::<B128>::from_values(&matrix_values);

		// Folding by vec1 and then by vec2 must equal a single fold by vec2 ⊗ vec1.
		let intermediate = fold_cols(&matrix, &vec1);
		let sequential_result = fold_cols(&intermediate, &vec2);
		let direct_result = fold_cols(&matrix, &tensor_product);

		assert_eq!(sequential_result.as_ref(), direct_result.as_ref());
	}
}