binius_math/
fold.rs

// Copyright 2025 Irreducible Inc.
//! Algorithms for matrix multiplications and generalized tensor contractions.
3
4use std::ops::Deref;
5
6use binius_field::Field;
7
8use crate::{
9	field_buffer::FieldBuffer,
10	inner_product::{inner_product, inner_product_buffers},
11};
12
13/// Computes a linear combination of the columns of a matrix.
14///
15/// A column-combination of a matrix is a matrix-vector multiplication.
16///
17/// This implementation is a naive, single-threaded implementation operating on buffers of scalar
18/// elements.
19///
20/// ## Mathematical Definition
21///
22/// This operation accepts
23///
24/// * $n \in \mathbb{N}$ (`out.len()`),
25/// * $m \in \mathbb{N}$ (`vec.len()`),
26/// * $M \in K^{n \times m}$ (`mat`),
27/// * $v \in K^m$ (`vec`),
28///
29/// and computes the vector $Mv$.
30///
31/// ## Args
32///
33/// * `mat` - a buffer of `n * m` `F` elements, interpreted as a row-major matrix.
34/// * `vec` - a buffer of `m` `F` elements containing the column scalars.
35///
36/// ## Returns
37///
38/// The matrix-vector product, as a buffer of `F` elements.
39///
40/// ## Preconditions
41///
42/// * `mat.log_len()` must be at least `vec.log_len()`
43pub fn fold_cols<F, DataMat, DataVec>(
44	mat: &FieldBuffer<F, DataMat>,
45	vec: &FieldBuffer<F, DataVec>,
46) -> FieldBuffer<F>
47where
48	F: Field,
49	DataMat: Deref<Target = [F]>,
50	DataVec: Deref<Target = [F]>,
51{
52	let log_m = vec.log_len();
53	let log_n = mat
54		.log_len()
55		.checked_sub(vec.log_len())
56		.expect("precondition: mat.log_len() must be at least vec.log_len()");
57
58	let ret_vals = mat
59		.chunks(log_m)
60		.map(|row| inner_product_buffers(&row, vec))
61		.collect::<Box<[_]>>();
62	FieldBuffer::new(log_n, ret_vals)
63}
64
65/// Computes a linear combination of the rows of a matrix.
66///
67/// A row-combination of a matrix is a vector-matrix multiplication.
68///
69/// This implementation is a naive, single-threaded implementation operating on buffers of scalar
70/// elements.
71///
72/// ## Mathematical Definition
73///
74/// This operation accepts
75///
76/// * $n \in \mathbb{N}$ (`vec.len()`),
77/// * $m \in \mathbb{N}$ (`out.len()`),
78/// * $M \in K^{n \times m}$ (`mat`),
79/// * $v \in K^m$ (`vec`),
80///
81/// and computes the vector $v^\top M^\top$.
82///
83/// ## Args
84///
85/// * `mat` - a buffer of `n * m` `F` elements, interpreted as a row-major matrix.
86/// * `vec` - a buffer of `n` `F` elements containing the row scalars.
87///
88/// ## Returns
89///
90/// The vector-matrix product, as a buffer of `F` elements.
91///
92/// ## Preconditions
93///
94/// * `mat.log_len()` must be at least `vec.log_len()`
95pub fn fold_rows<F, DataMat, DataVec>(
96	mat: &FieldBuffer<F, DataMat>,
97	vec: &FieldBuffer<F, DataVec>,
98) -> FieldBuffer<F>
99where
100	F: Field,
101	DataMat: Deref<Target = [F]>,
102	DataVec: Deref<Target = [F]>,
103{
104	let log_n = vec.log_len();
105	let log_m = mat
106		.log_len()
107		.checked_sub(vec.log_len())
108		.expect("precondition: mat.log_len() must be at least vec.log_len()");
109
110	let mat_vals = mat.as_ref();
111	let ret_vals = (0..1 << log_m)
112		.map(|col_i| {
113			let col = (0..1 << log_n).map(|row_i| mat_vals[(row_i << log_m) + col_i]);
114			inner_product(col, vec.as_ref().iter().copied())
115		})
116		.collect::<Box<[_]>>();
117	FieldBuffer::new(log_m, ret_vals)
118}
119
#[cfg(test)]
mod tests {
	use std::iter;

	use binius_field::Random;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::{B128, random_scalars};

	#[test]
	fn test_fold_cols_matches_naive() {
		let mut rng = StdRng::seed_from_u64(0);

		// Non-square matrix: 2^3 x 2^4 = 8 x 16 elements, so swapped dimensions are detected.
		let log_rows = 3;
		let log_cols = 4;
		let n_rows = 1 << log_rows;
		let n_cols = 1 << log_cols;

		let mat_values = random_scalars::<B128>(&mut rng, n_rows * n_cols);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);
		let vec_values = random_scalars::<B128>(&mut rng, n_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		// Naive definition: out[i] = sum_j mat[i, j] * vec[j], with mat stored row-major.
		let expected: Vec<B128> = (0..n_rows)
			.map(|i| {
				(0..n_cols).fold(B128::ZERO, |acc, j| {
					acc + mat_values[i * n_cols + j] * vec_values[j]
				})
			})
			.collect();

		let result = fold_cols(&mat, &vec);
		assert_eq!(result.as_ref(), expected.as_slice());
	}

	#[test]
	fn test_fold_rows_matches_naive() {
		let mut rng = StdRng::seed_from_u64(0);

		// Non-square matrix: 2^3 x 2^4 = 8 x 16 elements, so swapped dimensions are detected.
		let log_rows = 3;
		let log_cols = 4;
		let n_rows = 1 << log_rows;
		let n_cols = 1 << log_cols;

		let mat_values = random_scalars::<B128>(&mut rng, n_rows * n_cols);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);
		let vec_values = random_scalars::<B128>(&mut rng, n_rows);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		// Naive definition: out[j] = sum_i vec[i] * mat[i, j], with mat stored row-major.
		let expected: Vec<B128> = (0..n_cols)
			.map(|j| {
				(0..n_rows).fold(B128::ZERO, |acc, i| {
					acc + vec_values[i] * mat_values[i * n_cols + j]
				})
			})
			.collect();

		let result = fold_rows(&mat, &vec);
		assert_eq!(result.as_ref(), expected.as_slice());
	}

	#[test]
	fn test_fold_cols_linear_in_matrix() {
		let mut rng = StdRng::seed_from_u64(0);

		// Matrix dimensions: 2^5 x 2^5 = 32 x 32 = 2^10 elements total
		let log_rows = 5;
		let log_cols = 5;
		let total_elements = 1 << (log_rows + log_cols);

		// Generate two random matrices
		let mat0_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat0 = FieldBuffer::<B128>::from_values(&mat0_values);

		let mat1_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat1 = FieldBuffer::<B128>::from_values(&mat1_values);

		// Generate random vector
		let vec_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		// Generate random scalars
		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		// Compute left side: (scalar0 * mat0 + scalar1 * mat1) * vec
		let scaled_mat_values: Vec<B128> = iter::zip(&mat0_values, &mat1_values)
			.map(|(&m0, &m1)| scalar0 * m0 + scalar1 * m1)
			.collect();
		let scaled_mat = FieldBuffer::<B128>::from_values(&scaled_mat_values);
		let left_side = fold_cols(&scaled_mat, &vec);

		// Compute right side: scalar0 * (mat0 * vec) + scalar1 * (mat1 * vec)
		let mat0_vec = fold_cols(&mat0, &vec);
		let mat1_vec = fold_cols(&mat1, &vec);
		let right_side_values: Vec<B128> = mat0_vec
			.as_ref()
			.iter()
			.zip(mat1_vec.as_ref().iter())
			.map(|(&m0v, &m1v)| scalar0 * m0v + scalar1 * m1v)
			.collect();
		let right_side = FieldBuffer::<B128>::from_values(&right_side_values);

		// Compare results
		assert_eq!(left_side.as_ref(), right_side.as_ref());
	}

	#[test]
	fn test_fold_cols_linear_in_vector() {
		let mut rng = StdRng::seed_from_u64(0);

		// Matrix dimensions: 2^5 x 2^5 = 32 x 32 = 2^10 elements total
		let log_rows = 5;
		let log_cols = 5;
		let total_elements = 1 << (log_rows + log_cols);

		// Generate random matrix
		let mat_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);

		// Generate two random vectors
		let vec0_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec0 = FieldBuffer::<B128>::from_values(&vec0_values);

		let vec1_values = random_scalars::<B128>(&mut rng, 1 << log_cols);
		let vec1 = FieldBuffer::<B128>::from_values(&vec1_values);

		// Generate random scalars
		let scalar0 = B128::random(&mut rng);
		let scalar1 = B128::random(&mut rng);

		// Compute left side: mat * (scalar0 * vec0 + scalar1 * vec1)
		let scaled_vec_values: Vec<B128> = vec0_values
			.iter()
			.zip(vec1_values.iter())
			.map(|(&v0, &v1)| scalar0 * v0 + scalar1 * v1)
			.collect();
		let scaled_vec = FieldBuffer::<B128>::from_values(&scaled_vec_values);
		let left_side = fold_cols(&mat, &scaled_vec);

		// Compute right side: scalar0 * (mat * vec0) + scalar1 * (mat * vec1)
		let mat_vec0 = fold_cols(&mat, &vec0);
		let mat_vec1 = fold_cols(&mat, &vec1);
		let right_side_values: Vec<B128> = mat_vec0
			.as_ref()
			.iter()
			.zip(mat_vec1.as_ref().iter())
			.map(|(&mv0, &mv1)| scalar0 * mv0 + scalar1 * mv1)
			.collect();
		let right_side = FieldBuffer::<B128>::from_values(&right_side_values);

		// Compare results
		assert_eq!(left_side.as_ref(), right_side.as_ref());
	}

	#[test]
	fn test_fold_cols_equals_fold_rows_transpose() {
		let mut rng = StdRng::seed_from_u64(0);

		// Use a non-square matrix, 2^3 x 2^5 = 8 x 32 = 2^8 elements total, so that any mix-up
		// between the row and column dimensions changes the output and fails the test.
		let log_rows = 3;
		let log_cols = 5;
		let n_rows = 1 << log_rows;
		let n_cols = 1 << log_cols;
		let total_elements = n_rows * n_cols;

		// Generate random matrix
		let mat_values = random_scalars::<B128>(&mut rng, total_elements);
		let mat = FieldBuffer::<B128>::from_values(&mat_values);

		// Generate random vector
		let vec_values = random_scalars::<B128>(&mut rng, n_cols);
		let vec = FieldBuffer::<B128>::from_values(&vec_values);

		// Compute fold_cols(mat, vec) which gives mat * vec
		let fold_cols_result = fold_cols(&mat, &vec);

		// Transpose the matrix: mat_T[j,i] = mat[i,j]
		// Original matrix is row-major: mat[i,j] = mat_values[i * n_cols + j]
		// Transposed matrix is row-major: mat_T[j,i] = mat_T_values[j * n_rows + i]
		let mut mat_t_values = vec![B128::ZERO; total_elements];
		for i in 0..n_rows {
			for j in 0..n_cols {
				let orig_idx = i * n_cols + j;
				let trans_idx = j * n_rows + i;
				mat_t_values[trans_idx] = mat_values[orig_idx];
			}
		}
		let mat_t = FieldBuffer::<B128>::from_values(&mat_t_values);

		// Compute fold_rows(mat_T, vec) which gives vec^T * mat^T
		// This should equal mat * vec (same as fold_cols_result)
		let fold_rows_result = fold_rows(&mat_t, &vec);

		// The product mat * vec has one entry per row of mat.
		assert_eq!(fold_cols_result.len(), n_rows);
		assert_eq!(fold_cols_result.as_ref(), fold_rows_result.as_ref());
	}

	#[test]
	fn test_fold_cols_tensor_product() {
		let mut rng = StdRng::seed_from_u64(0);

		// Parameters: n = 10, m1 = 3, m2 = 4
		let n = 10;
		let m1 = 3;
		let m2 = 4;

		// Generate two vectors
		let vec1_size = 1 << m1; // 2^3 = 8
		let vec1_values = random_scalars::<B128>(&mut rng, vec1_size);
		let vec1 = FieldBuffer::<B128>::from_values(&vec1_values);

		let vec2_size = 1 << m2; // 2^4 = 16
		let vec2_values = random_scalars::<B128>(&mut rng, vec2_size);
		let vec2 = FieldBuffer::<B128>::from_values(&vec2_values);

		// Compute tensor product of vec2 and vec1 (note the order!)
		// The tensor product v2 ⊗ v1 has components (v2 ⊗ v1)[j*|v1| + i] = v2[j] * v1[i]
		let tensor_product_size = vec2_size * vec1_size; // 16 * 8 = 128
		let mut tensor_product_values = Vec::with_capacity(tensor_product_size);
		for &v2 in vec2_values.iter() {
			for &v1 in vec1_values.iter() {
				tensor_product_values.push(v2 * v1);
			}
		}
		let tensor_product = FieldBuffer::<B128>::from_values(&tensor_product_values);

		// Generate a random matrix of size 2^n
		let matrix_values = random_scalars::<B128>(&mut rng, 1 << n);
		let matrix = FieldBuffer::<B128>::from_values(&matrix_values);

		// Method 1: Sequential folding
		// First fold: matrix is viewed as 2^(n-m1) x 2^m1
		let intermediate = fold_cols(&matrix, &vec1);
		// Second fold: intermediate is viewed as 2^(n-m1-m2) x 2^m2
		let sequential_result = fold_cols(&intermediate, &vec2);

		// Method 2: Direct tensor product folding
		// Matrix is viewed as 2^(n-m1-m2) x 2^(m1+m2)
		let direct_result = fold_cols(&matrix, &tensor_product);

		// Compare results
		assert_eq!(sequential_result.as_ref(), direct_result.as_ref());
	}
}