1use std::ops::Deref;
5
6use binius_field::Field;
7
8use super::error::Error;
9use crate::{
10 field_buffer::FieldBuffer,
11 inner_product::{inner_product, inner_product_buffers},
12};
13
14pub fn fold_cols<F, DataMat, DataVec>(
46 mat: &FieldBuffer<F, DataMat>,
47 vec: &FieldBuffer<F, DataVec>,
48) -> Result<FieldBuffer<F>, Error>
49where
50 F: Field,
51 DataMat: Deref<Target = [F]>,
52 DataVec: Deref<Target = [F]>,
53{
54 let log_m = vec.log_len();
55 let Some(log_n) = mat.log_len().checked_sub(vec.log_len()) else {
56 return Err(Error::ArgumentRangeError {
57 arg: "vec.log_len()".to_string(),
58 range: 0..mat.log_len(),
59 });
60 };
61
62 let ret_vals = mat
63 .chunks(log_m)?
64 .map(|row| inner_product_buffers(&row, vec))
65 .collect::<Box<[_]>>();
66 FieldBuffer::new(log_n, ret_vals)
67}
68
69pub fn fold_rows<F, DataMat, DataVec>(
101 mat: &FieldBuffer<F, DataMat>,
102 vec: &FieldBuffer<F, DataVec>,
103) -> Result<FieldBuffer<F>, Error>
104where
105 F: Field,
106 DataMat: Deref<Target = [F]>,
107 DataVec: Deref<Target = [F]>,
108{
109 let log_n = vec.log_len();
110 let Some(log_m) = mat.log_len().checked_sub(vec.log_len()) else {
111 return Err(Error::ArgumentRangeError {
112 arg: "vec.log_len()".to_string(),
113 range: 0..mat.log_len(),
114 });
115 };
116
117 let mat_vals = mat.as_ref();
118 let ret_vals = (0..1 << log_m)
119 .map(|col_i| {
120 let col = (0..1 << log_n).map(|row_i| mat_vals[(row_i << log_m) + col_i]);
121 inner_product(col, vec.as_ref().iter().copied())
122 })
123 .collect::<Box<[_]>>();
124 FieldBuffer::new(log_m, ret_vals)
125}
126
#[cfg(test)]
mod tests {
	use std::iter;

	use binius_field::Random;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::{B128, random_scalars};

	/// Wraps scalars in a `FieldBuffer`; `values.len()` must be a power of two.
	fn buffer(values: &[B128]) -> FieldBuffer<B128> {
		FieldBuffer::<B128>::from_values(values).unwrap()
	}

	/// Borrows the elements of a buffer as a slice.
	fn elems(buf: &FieldBuffer<B128>) -> &[B128] {
		buf.as_ref()
	}

	/// Element-wise `scalar0 * a + scalar1 * b`.
	fn linear_combination(scalar0: B128, a: &[B128], scalar1: B128, b: &[B128]) -> Vec<B128> {
		assert_eq!(a.len(), b.len());
		iter::zip(a, b)
			.map(|(&x, &y)| scalar0 * x + scalar1 * y)
			.collect()
	}

	/// Naive reference for `fold_cols`: each `vec.len()`-wide row of `mat` dotted with `vec`.
	fn naive_fold_cols(mat: &[B128], vec: &[B128]) -> Vec<B128> {
		mat.chunks(vec.len())
			.map(|row| iter::zip(row, vec).fold(B128::ZERO, |acc, (&m, &v)| acc + m * v))
			.collect()
	}

	/// Naive reference for `fold_rows`: `Σ_i vec[i] * mat[i][j]` for every column `j`.
	fn naive_fold_rows(mat: &[B128], vec: &[B128]) -> Vec<B128> {
		let n_cols = mat.len() / vec.len();
		(0..n_cols)
			.map(|j| {
				iter::zip(mat.chunks(n_cols), vec)
					.fold(B128::ZERO, |acc, (row, &v)| acc + row[j] * v)
			})
			.collect()
	}

	#[test]
	fn test_fold_cols_matches_naive() {
		let mut rng = StdRng::seed_from_u64(0);

		// Non-square so that a row/column mix-up cannot cancel out.
		let log_rows = 3;
		let log_cols = 5;
		let mat_values = random_scalars::<B128>(&mut rng, 1 << (log_rows + log_cols));
		let vec_values = random_scalars::<B128>(&mut rng, 1 << log_cols);

		let result = fold_cols(&buffer(&mat_values), &buffer(&vec_values)).unwrap();
		let expected = naive_fold_cols(&mat_values, &vec_values);

		assert_eq!(expected.len(), 1 << log_rows);
		assert_eq!(elems(&result), expected.as_slice());
	}

	#[test]
	fn test_fold_rows_matches_naive() {
		let mut rng = StdRng::seed_from_u64(0);

		// Non-square so that a row/column mix-up cannot cancel out.
		let log_rows = 5;
		let log_cols = 3;
		let mat_values = random_scalars::<B128>(&mut rng, 1 << (log_rows + log_cols));
		let vec_values = random_scalars::<B128>(&mut rng, 1 << log_rows);

		let result = fold_rows(&buffer(&mat_values), &buffer(&vec_values)).unwrap();
		let expected = naive_fold_rows(&mat_values, &vec_values);

		assert_eq!(expected.len(), 1 << log_cols);
		assert_eq!(elems(&result), expected.as_slice());
	}

	#[test]
	fn test_fold_with_full_length_vector_is_inner_product() {
		let mut rng = StdRng::seed_from_u64(0);

		// `vec.log_len() == mat.log_len()` is the boundary case: a single output element.
		let mat_values = random_scalars::<B128>(&mut rng, 1 << 4);
		let vec_values = random_scalars::<B128>(&mut rng, 1 << 4);
		let mat = buffer(&mat_values);
		let vec = buffer(&vec_values);

		let expected = naive_fold_cols(&mat_values, &vec_values);
		assert_eq!(expected.len(), 1);
		assert_eq!(elems(&fold_cols(&mat, &vec).unwrap()), expected.as_slice());
		assert_eq!(elems(&fold_rows(&mat, &vec).unwrap()), expected.as_slice());
	}

	#[test]
	fn test_fold_rejects_vector_longer_than_matrix() {
		let mut rng = StdRng::seed_from_u64(0);

		let mat = buffer(&random_scalars::<B128>(&mut rng, 1 << 3));
		let vec = buffer(&random_scalars::<B128>(&mut rng, 1 << 4));

		assert!(fold_cols(&mat, &vec).is_err());
		assert!(fold_rows(&mat, &vec).is_err());
	}

	#[test]
	fn test_fold_cols_linear_in_matrix() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_rows = 5;
		let log_cols = 5;
		let total_elements = 1 << (log_rows + log_cols);

		let mat0_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat1_values = random_scalars::<B128>(&mut rng, total_elements);
		let vec = buffer(&random_scalars::<B128>(&mut rng, 1 << log_cols));

		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		// fold_cols(s0 * M0 + s1 * M1, v) == s0 * fold_cols(M0, v) + s1 * fold_cols(M1, v)
		let scaled_mat = buffer(&linear_combination(scalar0, &mat0_values, scalar1, &mat1_values));
		let left_side = fold_cols(&scaled_mat, &vec).unwrap();

		let mat0_vec = fold_cols(&buffer(&mat0_values), &vec).unwrap();
		let mat1_vec = fold_cols(&buffer(&mat1_values), &vec).unwrap();
		let right_side = linear_combination(scalar0, elems(&mat0_vec), scalar1, elems(&mat1_vec));

		assert_eq!(elems(&left_side), right_side.as_slice());
	}

	#[test]
	fn test_fold_cols_linear_in_vector() {
		let mut rng = StdRng::seed_from_u64(0);

		let log_rows = 5;
		let log_cols = 5;
		let total_elements = 1 << (log_rows + log_cols);

		let mat = buffer(&random_scalars::<B128>(&mut rng, total_elements));
		let vec0_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec1_values = random_scalars::<B128>(&mut rng, 1 << log_cols);

		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		// fold_cols(M, s0 * v0 + s1 * v1) == s0 * fold_cols(M, v0) + s1 * fold_cols(M, v1)
		let scaled_vec = buffer(&linear_combination(scalar0, &vec0_values, scalar1, &vec1_values));
		let left_side = fold_cols(&mat, &scaled_vec).unwrap();

		let mat_vec0 = fold_cols(&mat, &buffer(&vec0_values)).unwrap();
		let mat_vec1 = fold_cols(&mat, &buffer(&vec1_values)).unwrap();
		let right_side = linear_combination(scalar0, elems(&mat_vec0), scalar1, elems(&mat_vec1));

		assert_eq!(elems(&left_side), right_side.as_slice());
	}

	#[test]
	fn test_fold_cols_equals_fold_rows_transpose() {
		let mut rng = StdRng::seed_from_u64(0);

		// Non-square so that confusing the row and column dimensions is detected.
		let log_rows = 3;
		let log_cols = 5;
		let n_rows = 1 << log_rows;
		let n_cols = 1 << log_cols;
		let total_elements = n_rows * n_cols;

		let mat_values = random_scalars::<B128>(&mut rng, total_elements);
		let vec = buffer(&random_scalars::<B128>(&mut rng, n_cols));

		let fold_cols_result = fold_cols(&buffer(&mat_values), &vec).unwrap();

		// Transpose the `n_rows × n_cols` matrix into an `n_cols × n_rows` row-major matrix.
		let mut mat_t_values = vec![B128::ZERO; total_elements];
		for i in 0..n_rows {
			for j in 0..n_cols {
				mat_t_values[j * n_rows + i] = mat_values[i * n_cols + j];
			}
		}
		let fold_rows_result = fold_rows(&buffer(&mat_t_values), &vec).unwrap();

		assert_eq!(elems(&fold_cols_result), elems(&fold_rows_result));
	}

	#[test]
	fn test_fold_cols_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		let n = 10;
		let m1 = 3;
		let m2 = 4;

		let vec1_values = random_scalars::<B128>(&mut rng, 1 << m1);
		let vec2_values = random_scalars::<B128>(&mut rng, 1 << m2);

		// Tensor product with `vec1` as the fast-varying (low-order) index, matching the order
		// in which the sequential folds consume the matrix columns.
		let tensor_product_values: Vec<B128> = vec2_values
			.iter()
			.flat_map(|&v2| vec1_values.iter().map(move |&v1| v2 * v1))
			.collect();
		let tensor_product = buffer(&tensor_product_values);

		let matrix = buffer(&random_scalars::<B128>(&mut rng, 1 << n));

		let intermediate = fold_cols(&matrix, &buffer(&vec1_values)).unwrap();
		let sequential_result = fold_cols(&intermediate, &buffer(&vec2_values)).unwrap();

		let direct_result = fold_cols(&matrix, &tensor_product).unwrap();

		assert_eq!(elems(&sequential_result), elems(&direct_result));
	}
}