binius_math/
matrix.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	iter::repeat_with,
5	ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
6};
7
8use binius_field::{ExtensionField, Field};
9use binius_utils::bail;
10use bytemuck::zeroed_slice_box;
11use getset::CopyGetters;
12use rand::RngCore;
13
14use super::error::Error;
15
/// A dense matrix over a field.
///
/// Entries are stored in row-major order: entry `(i, j)` lives at flat index `i * n + j`
/// of the backing storage.
#[derive(Debug, Clone, PartialEq, Eq, CopyGetters)]
pub struct Matrix<F: Field> {
	/// The number of rows.
	#[getset(get_copy = "pub")]
	m: usize,
	/// The number of columns.
	#[getset(get_copy = "pub")]
	n: usize,
	/// The `m * n` entries, in row-major order.
	elements: Box<[F]>,
}
27
28impl<F: Field> Matrix<F> {
29	pub fn new(m: usize, n: usize, elements: &[F]) -> Result<Self, Error> {
30		if elements.len() != m * n {
31			bail!(Error::IncorrectArgumentLength {
32				arg: "elements".into(),
33				expected: m * n,
34			});
35		}
36		Ok(Self {
37			m,
38			n,
39			elements: elements.into(),
40		})
41	}
42
43	pub fn zeros(m: usize, n: usize) -> Self {
44		Self {
45			m,
46			n,
47			elements: zeroed_slice_box(m * n),
48		}
49	}
50
51	pub fn identity(n: usize) -> Self {
52		let mut out = Self::zeros(n, n);
53		for i in 0..n {
54			out[(i, i)] = F::ONE;
55		}
56		out
57	}
58
59	fn fill_identity(&mut self) {
60		assert_eq!(self.m, self.n);
61		self.elements.fill(F::ZERO);
62		for i in 0..self.n {
63			self[(i, i)] = F::ONE;
64		}
65	}
66
67	pub const fn elements(&self) -> &[F] {
68		&self.elements
69	}
70
71	pub fn random(m: usize, n: usize, mut rng: impl RngCore) -> Self {
72		Self {
73			m,
74			n,
75			elements: repeat_with(|| F::random(&mut rng)).take(m * n).collect(),
76		}
77	}
78
79	pub const fn dim(&self) -> (usize, usize) {
80		(self.m, self.n)
81	}
82
83	pub fn copy_from(&mut self, other: &Self) {
84		assert_eq!(self.dim(), other.dim());
85		self.elements.copy_from_slice(&other.elements);
86	}
87
88	pub fn mul_into(a: &Self, b: &Self, c: &mut Self) {
89		assert_eq!(a.n(), b.m());
90		assert_eq!(a.m(), c.m());
91		assert_eq!(b.n(), c.n());
92
93		for i in 0..c.m() {
94			for j in 0..c.n() {
95				c[(i, j)] = (0..a.n()).map(|k| a[(i, k)] * b[(k, j)]).sum();
96			}
97		}
98	}
99
100	pub fn mul_vec_into<FE: ExtensionField<F>>(&self, x: &[FE], y: &mut [FE]) {
101		assert_eq!(self.n(), x.len());
102		assert_eq!(self.m(), y.len());
103
104		for i in 0..y.len() {
105			y[i] = (0..self.n()).map(|j| x[j] * self[(i, j)]).sum();
106		}
107	}
108
109	/// Invert a square matrix
110	///
111	/// ## Throws
112	///
113	/// * [`Error::MatrixNotSquare`]
114	/// * [`Error::MatrixIsSingular`]
115	///
116	/// ## Preconditions
117	///
118	/// * `out` - must have the same dimensions as `self`
119	pub fn inverse_into(&self, out: &mut Self) -> Result<(), Error> {
120		assert_eq!(self.dim(), out.dim());
121
122		if self.m != self.n {
123			bail!(Error::MatrixNotSquare);
124		}
125
126		let n = self.n;
127
128		let mut tmp = self.clone();
129		out.fill_identity();
130
131		let mut row_buffer = vec![F::ZERO; n];
132
133		for i in 0..n {
134			// Find the pivot row
135			let pivot = (i..n)
136				.find(|&pivot| tmp[(pivot, i)] != F::ZERO)
137				.ok_or(Error::MatrixIsSingular)?;
138			if pivot != i {
139				tmp.swap_rows(i, pivot, &mut row_buffer);
140				out.swap_rows(i, pivot, &mut row_buffer);
141			}
142
143			// Normalize the pivot
144			let scalar = tmp[(i, i)]
145				.invert()
146				.expect("pivot is checked to be non-zero above");
147			tmp.scale_row(i, scalar);
148			out.scale_row(i, scalar);
149
150			// Clear the pivot column
151			for j in (0..i).chain(i + 1..n) {
152				let scalar = tmp[(j, i)];
153				tmp.sub_pivot_row(j, i, scalar);
154				out.sub_pivot_row(j, i, scalar);
155			}
156		}
157
158		debug_assert_eq!(tmp, Self::identity(n));
159
160		Ok(())
161	}
162
163	fn row_ref(&self, i: usize) -> &[F] {
164		assert!(i < self.m);
165		&self.elements[i * self.n..(i + 1) * self.n]
166	}
167
168	fn row_mut(&mut self, i: usize) -> &mut [F] {
169		assert!(i < self.m);
170		&mut self.elements[i * self.n..(i + 1) * self.n]
171	}
172
173	fn swap_rows(&mut self, i0: usize, i1: usize, buffer: &mut [F]) {
174		assert!(i0 < self.m);
175		assert!(i1 < self.m);
176		assert_eq!(buffer.len(), self.n);
177
178		if i0 == i1 {
179			return;
180		}
181
182		buffer.copy_from_slice(self.row_ref(i1));
183		self.elements
184			.copy_within(i0 * self.n..(i0 + 1) * self.n, i1 * self.n);
185		self.row_mut(i0).copy_from_slice(buffer);
186	}
187
188	fn scale_row(&mut self, i: usize, scalar: F) {
189		for x in self.row_mut(i) {
190			*x *= scalar;
191		}
192	}
193
194	fn sub_pivot_row(&mut self, i0: usize, i1: usize, scalar: F) {
195		assert!(i0 < self.m);
196		assert!(i1 < self.m);
197
198		for j in 0..self.n {
199			let x = self[(i1, j)];
200			self[(i0, j)] -= x * scalar;
201		}
202	}
203}
204
205impl<F: Field> Index<(usize, usize)> for Matrix<F> {
206	type Output = F;
207
208	fn index(&self, index: (usize, usize)) -> &Self::Output {
209		let (i, j) = index;
210		assert!(i < self.m);
211		assert!(j < self.n);
212		&self.elements[i * self.n + j]
213	}
214}
215
216impl<F: Field> IndexMut<(usize, usize)> for Matrix<F> {
217	fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
218		let (i, j) = index;
219		assert!(i < self.m);
220		assert!(j < self.n);
221		&mut self.elements[i * self.n + j]
222	}
223}
224
225impl<F: Field> Add<Self> for &Matrix<F> {
226	type Output = Matrix<F>;
227
228	fn add(self, rhs: Self) -> Matrix<F> {
229		let mut out = self.clone();
230		out += rhs;
231		out
232	}
233}
234
235impl<F: Field> Sub<Self> for &Matrix<F> {
236	type Output = Matrix<F>;
237
238	fn sub(self, rhs: Self) -> Matrix<F> {
239		let mut out = self.clone();
240		out -= rhs;
241		out
242	}
243}
244
245impl<F: Field> AddAssign<&Self> for Matrix<F> {
246	fn add_assign(&mut self, rhs: &Self) {
247		assert_eq!(self.dim(), rhs.dim());
248		for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
249			*a_ij += b_ij;
250		}
251	}
252}
253
254impl<F: Field> SubAssign<&Self> for Matrix<F> {
255	fn sub_assign(&mut self, rhs: &Self) {
256		assert_eq!(self.dim(), rhs.dim());
257		for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
258			*a_ij -= b_ij;
259		}
260	}
261}
262
#[cfg(test)]
mod tests {
	use binius_field::BinaryField32b;
	use proptest::prelude::*;
	use rand::{rngs::StdRng, SeedableRng};

	use super::*;

	type F = BinaryField32b;

	/// A fixed-seed RNG so every test case draws reproducible matrices.
	fn seeded_rng() -> StdRng {
		StdRng::seed_from_u64(0)
	}

	/// Returns `a * b` in a freshly allocated matrix.
	fn product(a: &Matrix<F>, b: &Matrix<F>) -> Matrix<F> {
		let mut out = Matrix::zeros(a.m(), b.n());
		Matrix::mul_into(a, b, &mut out);
		out
	}

	/// Returns the inverse of `a`, panicking if `a` is singular.
	fn inverse(a: &Matrix<F>) -> Matrix<F> {
		let (rows, cols) = a.dim();
		let mut out = Matrix::zeros(rows, cols);
		a.inverse_into(&mut out).unwrap();
		out
	}

	proptest! {
		#[test]
		fn test_left_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
			let mut rng = seeded_rng();
			let a0 = Matrix::<F>::random(c_m, a_n, &mut rng);
			let a1 = Matrix::<F>::random(c_m, a_n, &mut rng);
			let b = Matrix::<F>::random(a_n, c_n, &mut rng);

			// (a0 + a1) * b == a0 * b + a1 * b
			let lhs = product(&(&a0 + &a1), &b);
			let rhs = &product(&a0, &b) + &product(&a1, &b);
			assert_eq!(lhs, rhs);
		}

		#[test]
		fn test_right_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
			let mut rng = seeded_rng();
			let a = Matrix::<F>::random(c_m, a_n, &mut rng);
			let b0 = Matrix::<F>::random(a_n, c_n, &mut rng);
			let b1 = Matrix::<F>::random(a_n, c_n, &mut rng);

			// a * (b0 + b1) == a * b0 + a * b1
			let lhs = product(&a, &(&b0 + &b1));
			let rhs = &product(&a, &b0) + &product(&a, &b1);
			assert_eq!(lhs, rhs);
		}

		#[test]
		fn test_double_inverse(n in 0..8usize) {
			let a = Matrix::<F>::random(n, n, &mut seeded_rng());
			assert_eq!(inverse(&inverse(&a)), a);
		}

		#[test]
		fn test_inverse(n in 0..8usize) {
			let a = Matrix::<F>::random(n, n, &mut seeded_rng());
			let a_inv = inverse(&a);
			let identity = Matrix::<F>::identity(n);

			assert_eq!(product(&a, &a_inv), identity);
			assert_eq!(product(&a_inv, &a), identity);
		}
	}
}