1use std::{
4 iter::repeat_with,
5 ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
6};
7
8use binius_field::{ExtensionField, Field};
9use binius_utils::bail;
10use bytemuck::zeroed_slice_box;
11use getset::CopyGetters;
12use rand::RngCore;
13
14use super::error::Error;
15
/// A dense `m × n` matrix over the field `F`.
///
/// Elements are stored in row-major order: element `(i, j)` lives at index `i * n + j`
/// of the backing slice, which always holds exactly `m * n` elements.
#[derive(Debug, Clone, PartialEq, Eq, CopyGetters)]
pub struct Matrix<F: Field> {
    /// Number of rows.
    #[getset(get_copy = "pub")]
    m: usize,
    /// Number of columns.
    #[getset(get_copy = "pub")]
    n: usize,
    /// The `m * n` elements in row-major order.
    elements: Box<[F]>,
}
27
28impl<F: Field> Matrix<F> {
29 pub fn new(m: usize, n: usize, elements: &[F]) -> Result<Self, Error> {
30 if elements.len() != m * n {
31 bail!(Error::IncorrectArgumentLength {
32 arg: "elements".into(),
33 expected: m * n,
34 });
35 }
36 Ok(Self {
37 m,
38 n,
39 elements: elements.into(),
40 })
41 }
42
43 pub fn zeros(m: usize, n: usize) -> Self {
44 Self {
45 m,
46 n,
47 elements: zeroed_slice_box(m * n),
48 }
49 }
50
51 pub fn identity(n: usize) -> Self {
52 let mut out = Self::zeros(n, n);
53 for i in 0..n {
54 out[(i, i)] = F::ONE;
55 }
56 out
57 }
58
59 fn fill_identity(&mut self) {
60 assert_eq!(self.m, self.n);
61 self.elements.fill(F::ZERO);
62 for i in 0..self.n {
63 self[(i, i)] = F::ONE;
64 }
65 }
66
67 pub const fn elements(&self) -> &[F] {
68 &self.elements
69 }
70
71 pub fn random(m: usize, n: usize, mut rng: impl RngCore) -> Self {
72 Self {
73 m,
74 n,
75 elements: repeat_with(|| F::random(&mut rng)).take(m * n).collect(),
76 }
77 }
78
79 pub const fn dim(&self) -> (usize, usize) {
80 (self.m, self.n)
81 }
82
83 pub fn copy_from(&mut self, other: &Self) {
84 assert_eq!(self.dim(), other.dim());
85 self.elements.copy_from_slice(&other.elements);
86 }
87
88 pub fn mul_into(a: &Self, b: &Self, c: &mut Self) {
89 assert_eq!(a.n(), b.m());
90 assert_eq!(a.m(), c.m());
91 assert_eq!(b.n(), c.n());
92
93 for i in 0..c.m() {
94 for j in 0..c.n() {
95 c[(i, j)] = (0..a.n()).map(|k| a[(i, k)] * b[(k, j)]).sum();
96 }
97 }
98 }
99
100 pub fn mul_vec_into<FE: ExtensionField<F>>(&self, x: &[FE], y: &mut [FE]) {
101 assert_eq!(self.n(), x.len());
102 assert_eq!(self.m(), y.len());
103
104 for i in 0..y.len() {
105 y[i] = (0..self.n()).map(|j| x[j] * self[(i, j)]).sum();
106 }
107 }
108
109 pub fn inverse_into(&self, out: &mut Self) -> Result<(), Error> {
120 assert_eq!(self.dim(), out.dim());
121
122 if self.m != self.n {
123 bail!(Error::MatrixNotSquare);
124 }
125
126 let n = self.n;
127
128 let mut tmp = self.clone();
129 out.fill_identity();
130
131 let mut row_buffer = vec![F::ZERO; n];
132
133 for i in 0..n {
134 let pivot = (i..n)
136 .find(|&pivot| tmp[(pivot, i)] != F::ZERO)
137 .ok_or(Error::MatrixIsSingular)?;
138 if pivot != i {
139 tmp.swap_rows(i, pivot, &mut row_buffer);
140 out.swap_rows(i, pivot, &mut row_buffer);
141 }
142
143 let scalar = tmp[(i, i)]
145 .invert()
146 .expect("pivot is checked to be non-zero above");
147 tmp.scale_row(i, scalar);
148 out.scale_row(i, scalar);
149
150 for j in (0..i).chain(i + 1..n) {
152 let scalar = tmp[(j, i)];
153 tmp.sub_pivot_row(j, i, scalar);
154 out.sub_pivot_row(j, i, scalar);
155 }
156 }
157
158 debug_assert_eq!(tmp, Self::identity(n));
159
160 Ok(())
161 }
162
163 fn row_ref(&self, i: usize) -> &[F] {
164 assert!(i < self.m);
165 &self.elements[i * self.n..(i + 1) * self.n]
166 }
167
168 fn row_mut(&mut self, i: usize) -> &mut [F] {
169 assert!(i < self.m);
170 &mut self.elements[i * self.n..(i + 1) * self.n]
171 }
172
173 fn swap_rows(&mut self, i0: usize, i1: usize, buffer: &mut [F]) {
174 assert!(i0 < self.m);
175 assert!(i1 < self.m);
176 assert_eq!(buffer.len(), self.n);
177
178 if i0 == i1 {
179 return;
180 }
181
182 buffer.copy_from_slice(self.row_ref(i1));
183 self.elements
184 .copy_within(i0 * self.n..(i0 + 1) * self.n, i1 * self.n);
185 self.row_mut(i0).copy_from_slice(buffer);
186 }
187
188 fn scale_row(&mut self, i: usize, scalar: F) {
189 for x in self.row_mut(i) {
190 *x *= scalar;
191 }
192 }
193
194 fn sub_pivot_row(&mut self, i0: usize, i1: usize, scalar: F) {
195 assert!(i0 < self.m);
196 assert!(i1 < self.m);
197
198 for j in 0..self.n {
199 let x = self[(i1, j)];
200 self[(i0, j)] -= x * scalar;
201 }
202 }
203}
204
205impl<F: Field> Index<(usize, usize)> for Matrix<F> {
206 type Output = F;
207
208 fn index(&self, index: (usize, usize)) -> &Self::Output {
209 let (i, j) = index;
210 assert!(i < self.m);
211 assert!(j < self.n);
212 &self.elements[i * self.n + j]
213 }
214}
215
216impl<F: Field> IndexMut<(usize, usize)> for Matrix<F> {
217 fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
218 let (i, j) = index;
219 assert!(i < self.m);
220 assert!(j < self.n);
221 &mut self.elements[i * self.n + j]
222 }
223}
224
225impl<F: Field> Add<Self> for &Matrix<F> {
226 type Output = Matrix<F>;
227
228 fn add(self, rhs: Self) -> Matrix<F> {
229 let mut out = self.clone();
230 out += rhs;
231 out
232 }
233}
234
235impl<F: Field> Sub<Self> for &Matrix<F> {
236 type Output = Matrix<F>;
237
238 fn sub(self, rhs: Self) -> Matrix<F> {
239 let mut out = self.clone();
240 out -= rhs;
241 out
242 }
243}
244
245impl<F: Field> AddAssign<&Self> for Matrix<F> {
246 fn add_assign(&mut self, rhs: &Self) {
247 assert_eq!(self.dim(), rhs.dim());
248 for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
249 *a_ij += b_ij;
250 }
251 }
252}
253
254impl<F: Field> SubAssign<&Self> for Matrix<F> {
255 fn sub_assign(&mut self, rhs: &Self) {
256 assert_eq!(self.dim(), rhs.dim());
257 for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
258 *a_ij -= b_ij;
259 }
260 }
261}
262
#[cfg(test)]
mod tests {
    use binius_field::BinaryField32b;
    use proptest::prelude::*;
    use rand::{rngs::StdRng, SeedableRng};

    use super::*;

    // Matrix contents come from a fixed-seed RNG; proptest only varies the dimensions.
    // `prop_assert_eq!` reports failures through proptest instead of panicking.
    proptest! {
        // (a0 + a1) · b == a0 · b + a1 · b
        #[test]
        fn test_left_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
            type F = BinaryField32b;

            let mut rng = StdRng::seed_from_u64(0);
            let a0 = Matrix::<F>::random(c_m, a_n, &mut rng);
            let a1 = Matrix::<F>::random(c_m, a_n, &mut rng);
            let b = Matrix::<F>::random(a_n, c_n, &mut rng);
            let mut c0 = Matrix::<F>::zeros(c_m, c_n);
            let mut c1 = Matrix::<F>::zeros(c_m, c_n);

            let a0p1 = &a0 + &a1;
            let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

            Matrix::mul_into(&a0, &b, &mut c0);
            Matrix::mul_into(&a1, &b, &mut c1);
            Matrix::mul_into(&a0p1, &b, &mut c0p1);

            prop_assert_eq!(c0p1, &c0 + &c1);
        }

        // a · (b0 + b1) == a · b0 + a · b1
        #[test]
        fn test_right_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
            type F = BinaryField32b;

            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(c_m, a_n, &mut rng);
            let b0 = Matrix::<F>::random(a_n, c_n, &mut rng);
            let b1 = Matrix::<F>::random(a_n, c_n, &mut rng);
            let mut c0 = Matrix::<F>::zeros(c_m, c_n);
            let mut c1 = Matrix::<F>::zeros(c_m, c_n);

            let b0p1 = &b0 + &b1;
            let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

            Matrix::mul_into(&a, &b0, &mut c0);
            Matrix::mul_into(&a, &b1, &mut c1);
            Matrix::mul_into(&a, &b0p1, &mut c0p1);

            prop_assert_eq!(c0p1, &c0 + &c1);
        }

        // (a⁻¹)⁻¹ == a; a random matrix over a 32-bit field is invertible with
        // overwhelming probability, and the seed is fixed.
        #[test]
        fn test_double_inverse(n in 0..8usize) {
            type F = BinaryField32b;

            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(n, n, &mut rng);
            let mut a_inv = Matrix::<F>::zeros(n, n);
            let mut a_inv_inv = Matrix::<F>::zeros(n, n);

            a.inverse_into(&mut a_inv).unwrap();
            a_inv.inverse_into(&mut a_inv_inv).unwrap();
            prop_assert_eq!(a_inv_inv, a);
        }

        // a · a⁻¹ == a⁻¹ · a == I
        #[test]
        fn test_inverse(n in 0..8usize) {
            type F = BinaryField32b;

            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(n, n, &mut rng);
            let mut a_inv = Matrix::<F>::zeros(n, n);
            let mut prod = Matrix::<F>::zeros(n, n);

            a.inverse_into(&mut a_inv).unwrap();

            Matrix::mul_into(&a, &a_inv, &mut prod);
            prop_assert_eq!(&prod, &Matrix::<F>::identity(n));

            Matrix::mul_into(&a_inv, &a, &mut prod);
            prop_assert_eq!(&prod, &Matrix::<F>::identity(n));
        }
    }
}