binius_math/
matrix.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::repeat_with,
5	ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
6};
7
8use binius_field::{ExtensionField, Field};
9use bytemuck::zeroed_slice_box;
10use getset::CopyGetters;
11use rand::RngCore;
12
13/// A matrix over a field.
14#[derive(Debug, Clone, PartialEq, Eq, CopyGetters)]
15pub struct Matrix<F: Field> {
16	/// The number of rows.
17	#[getset(get_copy = "pub")]
18	m: usize,
19	/// The number of columns.
20	#[getset(get_copy = "pub")]
21	n: usize,
22	elements: Box<[F]>,
23}
24
25impl<F: Field> Matrix<F> {
26	pub fn new(m: usize, n: usize, elements: &[F]) -> Self {
27		assert_eq!(elements.len(), m * n, "precondition: elements length must equal m * n");
28		Self {
29			m,
30			n,
31			elements: elements.into(),
32		}
33	}
34
35	pub fn zeros(m: usize, n: usize) -> Self {
36		Self {
37			m,
38			n,
39			elements: zeroed_slice_box(m * n),
40		}
41	}
42
43	pub fn identity(n: usize) -> Self {
44		let mut out = Self::zeros(n, n);
45		for i in 0..n {
46			out[(i, i)] = F::ONE;
47		}
48		out
49	}
50
51	fn fill_identity(&mut self) {
52		assert_eq!(self.m, self.n);
53		self.elements.fill(F::ZERO);
54		for i in 0..self.n {
55			self[(i, i)] = F::ONE;
56		}
57	}
58
59	pub const fn elements(&self) -> &[F] {
60		&self.elements
61	}
62
63	pub fn random(m: usize, n: usize, mut rng: impl RngCore) -> Self {
64		Self {
65			m,
66			n,
67			elements: repeat_with(|| F::random(&mut rng)).take(m * n).collect(),
68		}
69	}
70
71	pub const fn dim(&self) -> (usize, usize) {
72		(self.m, self.n)
73	}
74
75	pub fn copy_from(&mut self, other: &Self) {
76		assert_eq!(self.dim(), other.dim());
77		self.elements.copy_from_slice(&other.elements);
78	}
79
80	pub fn mul_into(a: &Self, b: &Self, c: &mut Self) {
81		assert_eq!(a.n(), b.m());
82		assert_eq!(a.m(), c.m());
83		assert_eq!(b.n(), c.n());
84
85		for i in 0..c.m() {
86			for j in 0..c.n() {
87				c[(i, j)] = (0..a.n()).map(|k| a[(i, k)] * b[(k, j)]).sum();
88			}
89		}
90	}
91
92	pub fn mul_vec_into<FE: ExtensionField<F>>(&self, x: &[FE], y: &mut [FE]) {
93		assert_eq!(self.n(), x.len());
94		assert_eq!(self.m(), y.len());
95
96		for i in 0..y.len() {
97			y[i] = (0..self.n()).map(|j| x[j] * self[(i, j)]).sum();
98		}
99	}
100
101	/// Invert a square matrix
102	///
103	/// ## Preconditions
104	///
105	/// * Matrix must be square (m == n)
106	/// * Matrix must be non-singular (invertible)
107	/// * `out` must have the same dimensions as `self`
108	pub fn inverse_into(&self, out: &mut Self) {
109		assert_eq!(self.dim(), out.dim());
110		assert_eq!(self.m, self.n, "precondition: matrix must be square");
111
112		let n = self.n;
113
114		let mut tmp = self.clone();
115		out.fill_identity();
116
117		let mut row_buffer = vec![F::ZERO; n];
118
119		for i in 0..n {
120			// Find the pivot row
121			let pivot = (i..n)
122				.find(|&pivot| tmp[(pivot, i)] != F::ZERO)
123				.expect("precondition: matrix must be non-singular");
124			if pivot != i {
125				tmp.swap_rows(i, pivot, &mut row_buffer);
126				out.swap_rows(i, pivot, &mut row_buffer);
127			}
128
129			// Normalize the pivot
130			let scalar = tmp[(i, i)]
131				.invert()
132				.expect("pivot is checked to be non-zero above");
133			tmp.scale_row(i, scalar);
134			out.scale_row(i, scalar);
135
136			// Clear the pivot column
137			for j in (0..i).chain(i + 1..n) {
138				let scalar = tmp[(j, i)];
139				tmp.sub_pivot_row(j, i, scalar);
140				out.sub_pivot_row(j, i, scalar);
141			}
142		}
143
144		debug_assert_eq!(tmp, Self::identity(n));
145	}
146
147	fn row_ref(&self, i: usize) -> &[F] {
148		assert!(i < self.m);
149		&self.elements[i * self.n..(i + 1) * self.n]
150	}
151
152	fn row_mut(&mut self, i: usize) -> &mut [F] {
153		assert!(i < self.m);
154		&mut self.elements[i * self.n..(i + 1) * self.n]
155	}
156
157	fn swap_rows(&mut self, i0: usize, i1: usize, buffer: &mut [F]) {
158		assert!(i0 < self.m);
159		assert!(i1 < self.m);
160		assert_eq!(buffer.len(), self.n);
161
162		if i0 == i1 {
163			return;
164		}
165
166		buffer.copy_from_slice(self.row_ref(i1));
167		self.elements
168			.copy_within(i0 * self.n..(i0 + 1) * self.n, i1 * self.n);
169		self.row_mut(i0).copy_from_slice(buffer);
170	}
171
172	fn scale_row(&mut self, i: usize, scalar: F) {
173		for x in self.row_mut(i) {
174			*x *= scalar;
175		}
176	}
177
178	fn sub_pivot_row(&mut self, i0: usize, i1: usize, scalar: F) {
179		assert!(i0 < self.m);
180		assert!(i1 < self.m);
181
182		for j in 0..self.n {
183			let x = self[(i1, j)];
184			self[(i0, j)] -= x * scalar;
185		}
186	}
187}
188
189impl<F: Field> Index<(usize, usize)> for Matrix<F> {
190	type Output = F;
191
192	fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
193		assert!(i < self.m);
194		assert!(j < self.n);
195		&self.elements[i * self.n + j]
196	}
197}
198
199impl<F: Field> IndexMut<(usize, usize)> for Matrix<F> {
200	fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
201		assert!(i < self.m);
202		assert!(j < self.n);
203		&mut self.elements[i * self.n + j]
204	}
205}
206
207impl<F: Field> Add<Self> for &Matrix<F> {
208	type Output = Matrix<F>;
209
210	fn add(self, rhs: Self) -> Matrix<F> {
211		let mut out = self.clone();
212		out += rhs;
213		out
214	}
215}
216
217impl<F: Field> Sub<Self> for &Matrix<F> {
218	type Output = Matrix<F>;
219
220	fn sub(self, rhs: Self) -> Matrix<F> {
221		let mut out = self.clone();
222		out -= rhs;
223		out
224	}
225}
226
227impl<F: Field> AddAssign<&Self> for Matrix<F> {
228	fn add_assign(&mut self, rhs: &Self) {
229		assert_eq!(self.dim(), rhs.dim());
230		for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
231			*a_ij += b_ij;
232		}
233	}
234}
235
236impl<F: Field> SubAssign<&Self> for Matrix<F> {
237	fn sub_assign(&mut self, rhs: &Self) {
238		assert_eq!(self.dim(), rhs.dim());
239		for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
240			*a_ij -= b_ij;
241		}
242	}
243}
244
#[cfg(test)]
mod tests {
	use proptest::prelude::*;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::B128;

	type F = B128;

	/// Builds the `n × n` anti-diagonal permutation matrix (ones at `(i, n - 1 - i)`).
	///
	/// For `n >= 2` its `(0, 0)` entry is zero, so inverting it forces the row-swap
	/// branch of `inverse_into`. Random matrices over a 128-bit field almost surely
	/// never reach that branch.
	fn anti_identity(n: usize) -> Matrix<F> {
		let mut out = Matrix::<F>::zeros(n, n);
		for i in 0..n {
			out[(i, n - 1 - i)] = F::ONE;
		}
		out
	}

	proptest! {
		#[test]
		fn test_left_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a0 = Matrix::<F>::random(c_m, a_n, &mut rng);
			let a1 = Matrix::<F>::random(c_m, a_n, &mut rng);
			let b = Matrix::<F>::random(a_n, c_n, &mut rng);
			let mut c0 = Matrix::<F>::zeros(c_m, c_n);
			let mut c1 = Matrix::<F>::zeros(c_m, c_n);

			let a0p1 = &a0 + &a1;
			let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

			Matrix::mul_into(&a0, &b, &mut c0);
			Matrix::mul_into(&a1, &b, &mut c1);
			Matrix::mul_into(&a0p1, &b, &mut c0p1);

			assert_eq!(c0p1, &c0 + &c1);
		}

		#[test]
		fn test_right_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a = Matrix::<F>::random(c_m, a_n, &mut rng);
			let b0 = Matrix::<F>::random(a_n, c_n, &mut rng);
			let b1 = Matrix::<F>::random(a_n, c_n, &mut rng);
			let mut c0 = Matrix::<F>::zeros(c_m, c_n);
			let mut c1 = Matrix::<F>::zeros(c_m, c_n);

			let b0p1 = &b0 + &b1;
			let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

			Matrix::mul_into(&a, &b0, &mut c0);
			Matrix::mul_into(&a, &b1, &mut c1);
			Matrix::mul_into(&a, &b0p1, &mut c0p1);

			assert_eq!(c0p1, &c0 + &c1);
		}

		#[test]
		fn test_double_inverse(n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a = Matrix::<F>::random(n, n, &mut rng);
			let mut a_inv = Matrix::<F>::zeros(n, n);
			let mut a_inv_inv = Matrix::<F>::zeros(n, n);

			a.inverse_into(&mut a_inv);
			a_inv.inverse_into(&mut a_inv_inv);
			assert_eq!(a_inv_inv, a);
		}

		#[test]
		fn test_inverse(n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a = Matrix::<F>::random(n, n, &mut rng);
			let mut a_inv = Matrix::<F>::zeros(n, n);
			let mut prod = Matrix::<F>::zeros(n, n);

			a.inverse_into(&mut a_inv);

			Matrix::mul_into(&a, &a_inv, &mut prod);
			assert_eq!(prod, Matrix::<F>::identity(n));

			Matrix::mul_into(&a_inv, &a, &mut prod);
			assert_eq!(prod, Matrix::<F>::identity(n));
		}

		#[test]
		fn test_inverse_with_pivoting(n in 0..8usize) {
			let p = anti_identity(n);
			let mut p_inv = Matrix::<F>::zeros(n, n);
			let mut prod = Matrix::<F>::zeros(n, n);

			p.inverse_into(&mut p_inv);

			// The anti-diagonal permutation is its own inverse.
			assert_eq!(p_inv, p);

			Matrix::mul_into(&p, &p_inv, &mut prod);
			assert_eq!(prod, Matrix::<F>::identity(n));
		}
	}

	#[test]
	#[should_panic(expected = "precondition: matrix must be non-singular")]
	fn test_inverse_of_singular_matrix_panics() {
		let a = Matrix::<F>::zeros(2, 2);
		let mut a_inv = Matrix::<F>::zeros(2, 2);
		a.inverse_into(&mut a_inv);
	}
}