1use std::ops::Deref;
5
6use binius_field::{Field, util::inner_product_unchecked};
7
8use super::error::Error;
9use crate::field_buffer::FieldBuffer;
10
11pub fn fold_cols<F, DataMat, DataVec>(
43 mat: &FieldBuffer<F, DataMat>,
44 vec: &FieldBuffer<F, DataVec>,
45) -> Result<FieldBuffer<F>, Error>
46where
47 F: Field,
48 DataMat: Deref<Target = [F]>,
49 DataVec: Deref<Target = [F]>,
50{
51 let log_m = vec.log_len();
52 let Some(log_n) = mat.log_len().checked_sub(vec.log_len()) else {
53 return Err(Error::ArgumentRangeError {
54 arg: "vec.log_len()".to_string(),
55 range: 0..mat.log_len(),
56 });
57 };
58
59 let ret_vals = mat
60 .chunks(log_m)?
61 .map(|row| {
62 inner_product_unchecked(row.as_ref().iter().copied(), vec.as_ref().iter().copied())
63 })
64 .collect::<Box<[_]>>();
65 FieldBuffer::new(log_n, ret_vals)
66}
67
68pub fn fold_rows<F, DataMat, DataVec>(
100 mat: &FieldBuffer<F, DataMat>,
101 vec: &FieldBuffer<F, DataVec>,
102) -> Result<FieldBuffer<F>, Error>
103where
104 F: Field,
105 DataMat: Deref<Target = [F]>,
106 DataVec: Deref<Target = [F]>,
107{
108 let log_n = vec.log_len();
109 let Some(log_m) = mat.log_len().checked_sub(vec.log_len()) else {
110 return Err(Error::ArgumentRangeError {
111 arg: "vec.log_len()".to_string(),
112 range: 0..mat.log_len(),
113 });
114 };
115
116 let mat_vals = mat.as_ref();
117 let ret_vals = (0..1 << log_m)
118 .map(|col_i| {
119 let col = (0..1 << log_n).map(|row_i| mat_vals[(row_i << log_m) + col_i]);
120 inner_product_unchecked(col, vec.as_ref().iter().copied())
121 })
122 .collect::<Box<[_]>>();
123 FieldBuffer::new(log_m, ret_vals)
124}
125
#[cfg(test)]
mod tests {
	use std::iter;

	use binius_field::Random;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::{B128, random_scalars};

	/// `fold_cols(a·M0 + b·M1, v) == a·fold_cols(M0, v) + b·fold_cols(M1, v)`.
	#[test]
	fn test_fold_cols_linear_in_matrix() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_rows = 5;
		let log_cols = 5;
		let total_elements = 1 << (log_rows + log_cols);

		let mat0_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat0 = FieldBuffer::<B128>::from_values(&mat0_values).unwrap();

		let mat1_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat1 = FieldBuffer::<B128>::from_values(&mat1_values).unwrap();

		let vec_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values).unwrap();

		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		let scaled_mat_values: Vec<B128> = iter::zip(&mat0_values, &mat1_values)
			.map(|(&m0, &m1)| scalar0 * m0 + scalar1 * m1)
			.collect();
		let scaled_mat = FieldBuffer::<B128>::from_values(&scaled_mat_values).unwrap();
		let left_side = fold_cols(&scaled_mat, &vec).unwrap();

		let mat0_vec = fold_cols(&mat0, &vec).unwrap();
		let mat1_vec = fold_cols(&mat1, &vec).unwrap();
		let right_side_values: Vec<B128> = iter::zip(mat0_vec.as_ref(), mat1_vec.as_ref())
			.map(|(&m0v, &m1v)| scalar0 * m0v + scalar1 * m1v)
			.collect();
		let right_side = FieldBuffer::<B128>::from_values(&right_side_values).unwrap();

		assert_eq!(left_side.as_ref(), right_side.as_ref());
	}

	/// `fold_cols(M, a·v0 + b·v1) == a·fold_cols(M, v0) + b·fold_cols(M, v1)`.
	#[test]
	fn test_fold_cols_linear_in_vector() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_rows = 5;
		let log_cols = 5;
		let total_elements = 1 << (log_rows + log_cols);

		let mat_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat = FieldBuffer::<B128>::from_values(&mat_values).unwrap();

		let vec0_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec0 = FieldBuffer::<B128>::from_values(&vec0_values).unwrap();

		let vec1_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec1 = FieldBuffer::<B128>::from_values(&vec1_values).unwrap();

		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		let scaled_vec_values: Vec<B128> = iter::zip(&vec0_values, &vec1_values)
			.map(|(&v0, &v1)| scalar0 * v0 + scalar1 * v1)
			.collect();
		let scaled_vec = FieldBuffer::<B128>::from_values(&scaled_vec_values).unwrap();
		let left_side = fold_cols(&mat, &scaled_vec).unwrap();

		let mat_vec0 = fold_cols(&mat, &vec0).unwrap();
		let mat_vec1 = fold_cols(&mat, &vec1).unwrap();
		let right_side_values: Vec<B128> = iter::zip(mat_vec0.as_ref(), mat_vec1.as_ref())
			.map(|(&mv0, &mv1)| scalar0 * mv0 + scalar1 * mv1)
			.collect();
		let right_side = FieldBuffer::<B128>::from_values(&right_side_values).unwrap();

		assert_eq!(left_side.as_ref(), right_side.as_ref());
	}

	/// `fold_cols(M, v) == fold_rows(Mᵀ, v)`.
	///
	/// The matrix is deliberately non-square so that swapping the row and column dimensions in
	/// either function makes the test fail.
	#[test]
	fn test_fold_cols_equals_fold_rows_transpose() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_rows = 4;
		let log_cols = 6;
		let n_rows = 1 << log_rows;
		let n_cols = 1 << log_cols;
		let total_elements = n_rows * n_cols;

		let mat_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat = FieldBuffer::<B128>::from_values(&mat_values).unwrap();

		let vec_values = random_scalars::<B128>(&mut rng, n_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values).unwrap();

		let fold_cols_result = fold_cols(&mat, &vec).unwrap();

		// Build the row-major transpose: an `n_cols × n_rows` matrix.
		let mut mat_t_values = vec![B128::ZERO; total_elements];
		for i in 0..n_rows {
			for j in 0..n_cols {
				mat_t_values[j * n_rows + i] = mat_values[i * n_cols + j];
			}
		}
		let mat_t = FieldBuffer::<B128>::from_values(&mat_t_values).unwrap();

		let fold_rows_result = fold_rows(&mat_t, &vec).unwrap();

		assert_eq!(fold_cols_result.len(), n_rows);
		assert_eq!(fold_cols_result.as_ref(), fold_rows_result.as_ref());
	}

	/// Folding by `v1` then by `v2` equals folding once by the tensor product `v2 ⊗ v1`.
	#[test]
	fn test_fold_cols_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let n = 10;
		let m1 = 3;
		let m2 = 4;

		let vec1_size = 1 << m1;
		let vec1_values = random_scalars::<B128>(&mut rng, vec1_size);
		let vec1 = FieldBuffer::<B128>::from_values(&vec1_values).unwrap();

		let vec2_size = 1 << m2;
		let vec2_values = random_scalars::<B128>(&mut rng, vec2_size);
		let vec2 = FieldBuffer::<B128>::from_values(&vec2_values).unwrap();

		// `vec1` varies fastest, matching the order in which the sequential folds consume it.
		let tensor_product_values: Vec<B128> = vec2_values
			.iter()
			.flat_map(|&v2| vec1_values.iter().map(move |&v1| v2 * v1))
			.collect();
		assert_eq!(tensor_product_values.len(), vec2_size * vec1_size);
		let tensor_product = FieldBuffer::<B128>::from_values(&tensor_product_values).unwrap();

		let matrix_values = random_scalars::<B128>(&mut rng, 1 << n);
		let matrix = FieldBuffer::<B128>::from_values(&matrix_values).unwrap();

		let intermediate = fold_cols(&matrix, &vec1).unwrap();
		let sequential_result = fold_cols(&intermediate, &vec2).unwrap();

		let direct_result = fold_cols(&matrix, &tensor_product).unwrap();

		assert_eq!(sequential_result.as_ref(), direct_result.as_ref());
	}

	/// When `vec` is as long as `mat`, both folds reduce to a single inner product.
	#[test]
	fn test_fold_full_length_vec_is_inner_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_len = 6;
		let mat_values = random_scalars::<B128>(&mut rng, 1 << log_len);
		let mat = FieldBuffer::<B128>::from_values(&mat_values).unwrap();
		let vec_values = random_scalars::<B128>(&mut rng, 1 << log_len);
		let vec = FieldBuffer::<B128>::from_values(&vec_values).unwrap();

		let expected =
			iter::zip(&mat_values, &vec_values).fold(B128::ZERO, |acc, (&m, &v)| acc + m * v);

		let cols = fold_cols(&mat, &vec).unwrap();
		let rows = fold_rows(&mat, &vec).unwrap();
		assert_eq!(cols.as_ref(), &[expected][..]);
		assert_eq!(rows.as_ref(), &[expected][..]);
	}

	/// A vector longer than the matrix is rejected with an argument-range error.
	#[test]
	fn test_fold_rejects_vec_longer_than_mat() {
		let mut rng = StdRng::seed_from_u64(0);

		let mat_values = random_scalars::<B128>(&mut rng, 1 << 3);
		let mat = FieldBuffer::<B128>::from_values(&mat_values).unwrap();
		let vec_values = random_scalars::<B128>(&mut rng, 1 << 4);
		let vec = FieldBuffer::<B128>::from_values(&vec_values).unwrap();

		assert!(matches!(fold_cols(&mat, &vec), Err(Error::ArgumentRangeError { .. })));
		assert!(matches!(fold_rows(&mat, &vec), Err(Error::ArgumentRangeError { .. })));
	}
}