binius_math/
matrix.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::repeat_with,
5	ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
6};
7
8use binius_field::{ExtensionField, Field};
9use bytemuck::zeroed_slice_box;
10use getset::CopyGetters;
11use rand::RngCore;
12
13use super::error::Error;
14
15/// A matrix over a field.
16#[derive(Debug, Clone, PartialEq, Eq, CopyGetters)]
17pub struct Matrix<F: Field> {
18	/// The number of rows.
19	#[getset(get_copy = "pub")]
20	m: usize,
21	/// The number of columns.
22	#[getset(get_copy = "pub")]
23	n: usize,
24	elements: Box<[F]>,
25}
26
27impl<F: Field> Matrix<F> {
28	pub fn new(m: usize, n: usize, elements: &[F]) -> Result<Self, Error> {
29		if elements.len() != m * n {
30			return Err(Error::IncorrectArgumentLength {
31				arg: "elements".into(),
32				expected: m * n,
33			});
34		}
35		Ok(Self {
36			m,
37			n,
38			elements: elements.into(),
39		})
40	}
41
42	pub fn zeros(m: usize, n: usize) -> Self {
43		Self {
44			m,
45			n,
46			elements: zeroed_slice_box(m * n),
47		}
48	}
49
50	pub fn identity(n: usize) -> Self {
51		let mut out = Self::zeros(n, n);
52		for i in 0..n {
53			out[(i, i)] = F::ONE;
54		}
55		out
56	}
57
58	fn fill_identity(&mut self) {
59		assert_eq!(self.m, self.n);
60		self.elements.fill(F::ZERO);
61		for i in 0..self.n {
62			self[(i, i)] = F::ONE;
63		}
64	}
65
66	pub const fn elements(&self) -> &[F] {
67		&self.elements
68	}
69
70	pub fn random(m: usize, n: usize, mut rng: impl RngCore) -> Self {
71		Self {
72			m,
73			n,
74			elements: repeat_with(|| F::random(&mut rng)).take(m * n).collect(),
75		}
76	}
77
78	pub const fn dim(&self) -> (usize, usize) {
79		(self.m, self.n)
80	}
81
82	pub fn copy_from(&mut self, other: &Self) {
83		assert_eq!(self.dim(), other.dim());
84		self.elements.copy_from_slice(&other.elements);
85	}
86
87	pub fn mul_into(a: &Self, b: &Self, c: &mut Self) {
88		assert_eq!(a.n(), b.m());
89		assert_eq!(a.m(), c.m());
90		assert_eq!(b.n(), c.n());
91
92		for i in 0..c.m() {
93			for j in 0..c.n() {
94				c[(i, j)] = (0..a.n()).map(|k| a[(i, k)] * b[(k, j)]).sum();
95			}
96		}
97	}
98
99	pub fn mul_vec_into<FE: ExtensionField<F>>(&self, x: &[FE], y: &mut [FE]) {
100		assert_eq!(self.n(), x.len());
101		assert_eq!(self.m(), y.len());
102
103		for i in 0..y.len() {
104			y[i] = (0..self.n()).map(|j| x[j] * self[(i, j)]).sum();
105		}
106	}
107
108	/// Invert a square matrix
109	///
110	/// ## Throws
111	///
112	/// * [`Error::MatrixNotSquare`]
113	/// * [`Error::MatrixIsSingular`]
114	///
115	/// ## Preconditions
116	///
117	/// * `out` - must have the same dimensions as `self`
118	pub fn inverse_into(&self, out: &mut Self) -> Result<(), Error> {
119		assert_eq!(self.dim(), out.dim());
120
121		if self.m != self.n {
122			return Err(Error::MatrixNotSquare);
123		}
124
125		let n = self.n;
126
127		let mut tmp = self.clone();
128		out.fill_identity();
129
130		let mut row_buffer = vec![F::ZERO; n];
131
132		for i in 0..n {
133			// Find the pivot row
134			let pivot = (i..n)
135				.find(|&pivot| tmp[(pivot, i)] != F::ZERO)
136				.ok_or(Error::MatrixIsSingular)?;
137			if pivot != i {
138				tmp.swap_rows(i, pivot, &mut row_buffer);
139				out.swap_rows(i, pivot, &mut row_buffer);
140			}
141
142			// Normalize the pivot
143			let scalar = tmp[(i, i)]
144				.invert()
145				.expect("pivot is checked to be non-zero above");
146			tmp.scale_row(i, scalar);
147			out.scale_row(i, scalar);
148
149			// Clear the pivot column
150			for j in (0..i).chain(i + 1..n) {
151				let scalar = tmp[(j, i)];
152				tmp.sub_pivot_row(j, i, scalar);
153				out.sub_pivot_row(j, i, scalar);
154			}
155		}
156
157		debug_assert_eq!(tmp, Self::identity(n));
158
159		Ok(())
160	}
161
162	fn row_ref(&self, i: usize) -> &[F] {
163		assert!(i < self.m);
164		&self.elements[i * self.n..(i + 1) * self.n]
165	}
166
167	fn row_mut(&mut self, i: usize) -> &mut [F] {
168		assert!(i < self.m);
169		&mut self.elements[i * self.n..(i + 1) * self.n]
170	}
171
172	fn swap_rows(&mut self, i0: usize, i1: usize, buffer: &mut [F]) {
173		assert!(i0 < self.m);
174		assert!(i1 < self.m);
175		assert_eq!(buffer.len(), self.n);
176
177		if i0 == i1 {
178			return;
179		}
180
181		buffer.copy_from_slice(self.row_ref(i1));
182		self.elements
183			.copy_within(i0 * self.n..(i0 + 1) * self.n, i1 * self.n);
184		self.row_mut(i0).copy_from_slice(buffer);
185	}
186
187	fn scale_row(&mut self, i: usize, scalar: F) {
188		for x in self.row_mut(i) {
189			*x *= scalar;
190		}
191	}
192
193	fn sub_pivot_row(&mut self, i0: usize, i1: usize, scalar: F) {
194		assert!(i0 < self.m);
195		assert!(i1 < self.m);
196
197		for j in 0..self.n {
198			let x = self[(i1, j)];
199			self[(i0, j)] -= x * scalar;
200		}
201	}
202}
203
204impl<F: Field> Index<(usize, usize)> for Matrix<F> {
205	type Output = F;
206
207	fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
208		assert!(i < self.m);
209		assert!(j < self.n);
210		&self.elements[i * self.n + j]
211	}
212}
213
214impl<F: Field> IndexMut<(usize, usize)> for Matrix<F> {
215	fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
216		assert!(i < self.m);
217		assert!(j < self.n);
218		&mut self.elements[i * self.n + j]
219	}
220}
221
222impl<F: Field> Add<Self> for &Matrix<F> {
223	type Output = Matrix<F>;
224
225	fn add(self, rhs: Self) -> Matrix<F> {
226		let mut out = self.clone();
227		out += rhs;
228		out
229	}
230}
231
232impl<F: Field> Sub<Self> for &Matrix<F> {
233	type Output = Matrix<F>;
234
235	fn sub(self, rhs: Self) -> Matrix<F> {
236		let mut out = self.clone();
237		out -= rhs;
238		out
239	}
240}
241
242impl<F: Field> AddAssign<&Self> for Matrix<F> {
243	fn add_assign(&mut self, rhs: &Self) {
244		assert_eq!(self.dim(), rhs.dim());
245		for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
246			*a_ij += b_ij;
247		}
248	}
249}
250
251impl<F: Field> SubAssign<&Self> for Matrix<F> {
252	fn sub_assign(&mut self, rhs: &Self) {
253		assert_eq!(self.dim(), rhs.dim());
254		for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
255			*a_ij -= b_ij;
256		}
257	}
258}
259
#[cfg(test)]
mod tests {
	use proptest::prelude::*;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::test_utils::B128;

	type F = B128;

	proptest! {
		#[test]
		fn test_left_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a0 = Matrix::<F>::random(c_m, a_n, &mut rng);
			let a1 = Matrix::<F>::random(c_m, a_n, &mut rng);
			let b = Matrix::<F>::random(a_n, c_n, &mut rng);
			let mut c0 = Matrix::<F>::zeros(c_m, c_n);
			let mut c1 = Matrix::<F>::zeros(c_m, c_n);

			let a0p1 = &a0 + &a1;
			let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

			Matrix::mul_into(&a0, &b, &mut c0);
			Matrix::mul_into(&a1, &b, &mut c1);
			Matrix::mul_into(&a0p1, &b, &mut c0p1);

			assert_eq!(c0p1, &c0 + &c1);
		}

		#[test]
		fn test_right_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a = Matrix::<F>::random(c_m, a_n, &mut rng);
			let b0 = Matrix::<F>::random(a_n, c_n, &mut rng);
			let b1 = Matrix::<F>::random(a_n, c_n, &mut rng);
			let mut c0 = Matrix::<F>::zeros(c_m, c_n);
			let mut c1 = Matrix::<F>::zeros(c_m, c_n);

			let b0p1 = &b0 + &b1;
			let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

			Matrix::mul_into(&a, &b0, &mut c0);
			Matrix::mul_into(&a, &b1, &mut c1);
			Matrix::mul_into(&a, &b0p1, &mut c0p1);

			assert_eq!(c0p1, &c0 + &c1);
		}

		#[test]
		fn test_double_inverse(n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a = Matrix::<F>::random(n, n, &mut rng);
			let mut a_inv = Matrix::<F>::zeros(n, n);
			let mut a_inv_inv = Matrix::<F>::zeros(n, n);

			a.inverse_into(&mut a_inv).unwrap();
			a_inv.inverse_into(&mut a_inv_inv).unwrap();
			assert_eq!(a_inv_inv, a);
		}

		#[test]
		fn test_inverse(n in 0..8usize) {
			let mut rng = StdRng::seed_from_u64(0);
			let a = Matrix::<F>::random(n, n, &mut rng);
			let mut a_inv = Matrix::<F>::zeros(n, n);
			let mut prod = Matrix::<F>::zeros(n, n);

			a.inverse_into(&mut a_inv).unwrap();

			Matrix::mul_into(&a, &a_inv, &mut prod);
			assert_eq!(prod, Matrix::<F>::identity(n));

			Matrix::mul_into(&a_inv, &a, &mut prod);
			assert_eq!(prod, Matrix::<F>::identity(n));
		}
	}

	#[test]
	fn test_new_rejects_wrong_length() {
		let result = Matrix::<F>::new(2, 2, &[F::ONE; 3]);
		assert!(matches!(result, Err(Error::IncorrectArgumentLength { .. })));
	}

	#[test]
	fn test_inverse_rejects_non_square() {
		let a = Matrix::<F>::zeros(2, 3);
		let mut out = Matrix::<F>::zeros(2, 3);
		assert!(matches!(a.inverse_into(&mut out), Err(Error::MatrixNotSquare)));
	}

	#[test]
	fn test_inverse_rejects_singular() {
		// The zero matrix has no pivot in its first column.
		for n in 1..4 {
			let a = Matrix::<F>::zeros(n, n);
			let mut out = Matrix::<F>::zeros(n, n);
			assert!(matches!(a.inverse_into(&mut out), Err(Error::MatrixIsSingular)));
		}

		// Rows 0 and 2 coincide, so elimination runs out of pivots in the last column.
		let (o, z) = (F::ONE, F::ZERO);
		let a = Matrix::new(3, 3, &[o, o, z, z, o, o, o, o, z]).unwrap();
		let mut out = Matrix::zeros(3, 3);
		assert!(matches!(a.inverse_into(&mut out), Err(Error::MatrixIsSingular)));
	}

	#[test]
	fn test_inverse_with_row_swaps() {
		// A cyclic permutation matrix: its zero diagonal entries force row swaps during
		// elimination, a path random matrices practically never exercise.
		let (o, z) = (F::ONE, F::ZERO);
		let p = Matrix::new(3, 3, &[z, o, z, z, z, o, o, z, z]).unwrap();
		let p_transpose = Matrix::new(3, 3, &[z, z, o, o, z, z, z, o, z]).unwrap();

		let mut p_inv = Matrix::zeros(3, 3);
		p.inverse_into(&mut p_inv).unwrap();

		// Permutation matrices are orthogonal, so the inverse is the transpose.
		assert_eq!(p_inv, p_transpose);

		let mut prod = Matrix::zeros(3, 3);
		Matrix::mul_into(&p, &p_inv, &mut prod);
		assert_eq!(prod, Matrix::<F>::identity(3));
	}
}