1use std::{
4 iter::repeat_with,
5 ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
6};
7
8use binius_field::{ExtensionField, Field};
9use bytemuck::zeroed_slice_box;
10use getset::CopyGetters;
11use rand::RngCore;
12
13use super::error::Error;
14
/// A dense `m × n` matrix over the field `F`.
///
/// Entries are stored contiguously in row-major order: entry `(i, j)` (row `i`,
/// column `j`) lives at index `i * n + j` of `elements`.
#[derive(Debug, Clone, PartialEq, Eq, CopyGetters)]
pub struct Matrix<F: Field> {
    /// Number of rows.
    #[getset(get_copy = "pub")]
    m: usize,
    /// Number of columns.
    #[getset(get_copy = "pub")]
    n: usize,
    /// Row-major entries; the length is always `m * n`.
    elements: Box<[F]>,
}
26
27impl<F: Field> Matrix<F> {
28 pub fn new(m: usize, n: usize, elements: &[F]) -> Result<Self, Error> {
29 if elements.len() != m * n {
30 return Err(Error::IncorrectArgumentLength {
31 arg: "elements".into(),
32 expected: m * n,
33 });
34 }
35 Ok(Self {
36 m,
37 n,
38 elements: elements.into(),
39 })
40 }
41
42 pub fn zeros(m: usize, n: usize) -> Self {
43 Self {
44 m,
45 n,
46 elements: zeroed_slice_box(m * n),
47 }
48 }
49
50 pub fn identity(n: usize) -> Self {
51 let mut out = Self::zeros(n, n);
52 for i in 0..n {
53 out[(i, i)] = F::ONE;
54 }
55 out
56 }
57
58 fn fill_identity(&mut self) {
59 assert_eq!(self.m, self.n);
60 self.elements.fill(F::ZERO);
61 for i in 0..self.n {
62 self[(i, i)] = F::ONE;
63 }
64 }
65
66 pub const fn elements(&self) -> &[F] {
67 &self.elements
68 }
69
70 pub fn random(m: usize, n: usize, mut rng: impl RngCore) -> Self {
71 Self {
72 m,
73 n,
74 elements: repeat_with(|| F::random(&mut rng)).take(m * n).collect(),
75 }
76 }
77
78 pub const fn dim(&self) -> (usize, usize) {
79 (self.m, self.n)
80 }
81
82 pub fn copy_from(&mut self, other: &Self) {
83 assert_eq!(self.dim(), other.dim());
84 self.elements.copy_from_slice(&other.elements);
85 }
86
87 pub fn mul_into(a: &Self, b: &Self, c: &mut Self) {
88 assert_eq!(a.n(), b.m());
89 assert_eq!(a.m(), c.m());
90 assert_eq!(b.n(), c.n());
91
92 for i in 0..c.m() {
93 for j in 0..c.n() {
94 c[(i, j)] = (0..a.n()).map(|k| a[(i, k)] * b[(k, j)]).sum();
95 }
96 }
97 }
98
99 pub fn mul_vec_into<FE: ExtensionField<F>>(&self, x: &[FE], y: &mut [FE]) {
100 assert_eq!(self.n(), x.len());
101 assert_eq!(self.m(), y.len());
102
103 for i in 0..y.len() {
104 y[i] = (0..self.n()).map(|j| x[j] * self[(i, j)]).sum();
105 }
106 }
107
108 pub fn inverse_into(&self, out: &mut Self) -> Result<(), Error> {
119 assert_eq!(self.dim(), out.dim());
120
121 if self.m != self.n {
122 return Err(Error::MatrixNotSquare);
123 }
124
125 let n = self.n;
126
127 let mut tmp = self.clone();
128 out.fill_identity();
129
130 let mut row_buffer = vec![F::ZERO; n];
131
132 for i in 0..n {
133 let pivot = (i..n)
135 .find(|&pivot| tmp[(pivot, i)] != F::ZERO)
136 .ok_or(Error::MatrixIsSingular)?;
137 if pivot != i {
138 tmp.swap_rows(i, pivot, &mut row_buffer);
139 out.swap_rows(i, pivot, &mut row_buffer);
140 }
141
142 let scalar = tmp[(i, i)]
144 .invert()
145 .expect("pivot is checked to be non-zero above");
146 tmp.scale_row(i, scalar);
147 out.scale_row(i, scalar);
148
149 for j in (0..i).chain(i + 1..n) {
151 let scalar = tmp[(j, i)];
152 tmp.sub_pivot_row(j, i, scalar);
153 out.sub_pivot_row(j, i, scalar);
154 }
155 }
156
157 debug_assert_eq!(tmp, Self::identity(n));
158
159 Ok(())
160 }
161
162 fn row_ref(&self, i: usize) -> &[F] {
163 assert!(i < self.m);
164 &self.elements[i * self.n..(i + 1) * self.n]
165 }
166
167 fn row_mut(&mut self, i: usize) -> &mut [F] {
168 assert!(i < self.m);
169 &mut self.elements[i * self.n..(i + 1) * self.n]
170 }
171
172 fn swap_rows(&mut self, i0: usize, i1: usize, buffer: &mut [F]) {
173 assert!(i0 < self.m);
174 assert!(i1 < self.m);
175 assert_eq!(buffer.len(), self.n);
176
177 if i0 == i1 {
178 return;
179 }
180
181 buffer.copy_from_slice(self.row_ref(i1));
182 self.elements
183 .copy_within(i0 * self.n..(i0 + 1) * self.n, i1 * self.n);
184 self.row_mut(i0).copy_from_slice(buffer);
185 }
186
187 fn scale_row(&mut self, i: usize, scalar: F) {
188 for x in self.row_mut(i) {
189 *x *= scalar;
190 }
191 }
192
193 fn sub_pivot_row(&mut self, i0: usize, i1: usize, scalar: F) {
194 assert!(i0 < self.m);
195 assert!(i1 < self.m);
196
197 for j in 0..self.n {
198 let x = self[(i1, j)];
199 self[(i0, j)] -= x * scalar;
200 }
201 }
202}
203
204impl<F: Field> Index<(usize, usize)> for Matrix<F> {
205 type Output = F;
206
207 fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
208 assert!(i < self.m);
209 assert!(j < self.n);
210 &self.elements[i * self.n + j]
211 }
212}
213
214impl<F: Field> IndexMut<(usize, usize)> for Matrix<F> {
215 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
216 assert!(i < self.m);
217 assert!(j < self.n);
218 &mut self.elements[i * self.n + j]
219 }
220}
221
222impl<F: Field> Add<Self> for &Matrix<F> {
223 type Output = Matrix<F>;
224
225 fn add(self, rhs: Self) -> Matrix<F> {
226 let mut out = self.clone();
227 out += rhs;
228 out
229 }
230}
231
232impl<F: Field> Sub<Self> for &Matrix<F> {
233 type Output = Matrix<F>;
234
235 fn sub(self, rhs: Self) -> Matrix<F> {
236 let mut out = self.clone();
237 out -= rhs;
238 out
239 }
240}
241
242impl<F: Field> AddAssign<&Self> for Matrix<F> {
243 fn add_assign(&mut self, rhs: &Self) {
244 assert_eq!(self.dim(), rhs.dim());
245 for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
246 *a_ij += b_ij;
247 }
248 }
249}
250
251impl<F: Field> SubAssign<&Self> for Matrix<F> {
252 fn sub_assign(&mut self, rhs: &Self) {
253 assert_eq!(self.dim(), rhs.dim());
254 for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
255 *a_ij -= b_ij;
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::test_utils::B128;

    type F = B128;

    proptest! {
        #[test]
        fn test_left_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a0 = Matrix::<F>::random(c_m, a_n, &mut rng);
            let a1 = Matrix::<F>::random(c_m, a_n, &mut rng);
            let b = Matrix::<F>::random(a_n, c_n, &mut rng);
            let mut c0 = Matrix::<F>::zeros(c_m, c_n);
            let mut c1 = Matrix::<F>::zeros(c_m, c_n);

            let a0p1 = &a0 + &a1;
            let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

            Matrix::mul_into(&a0, &b, &mut c0);
            Matrix::mul_into(&a1, &b, &mut c1);
            Matrix::mul_into(&a0p1, &b, &mut c0p1);

            assert_eq!(c0p1, &c0 + &c1);
        }

        #[test]
        fn test_right_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(c_m, a_n, &mut rng);
            let b0 = Matrix::<F>::random(a_n, c_n, &mut rng);
            let b1 = Matrix::<F>::random(a_n, c_n, &mut rng);
            let mut c0 = Matrix::<F>::zeros(c_m, c_n);
            let mut c1 = Matrix::<F>::zeros(c_m, c_n);

            let b0p1 = &b0 + &b1;
            let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

            Matrix::mul_into(&a, &b0, &mut c0);
            Matrix::mul_into(&a, &b1, &mut c1);
            Matrix::mul_into(&a, &b0p1, &mut c0p1);

            assert_eq!(c0p1, &c0 + &c1);
        }

        #[test]
        fn test_double_inverse(n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(n, n, &mut rng);
            let mut a_inv = Matrix::<F>::zeros(n, n);
            let mut a_inv_inv = Matrix::<F>::zeros(n, n);

            a.inverse_into(&mut a_inv).unwrap();
            a_inv.inverse_into(&mut a_inv_inv).unwrap();
            assert_eq!(a_inv_inv, a);
        }

        #[test]
        fn test_inverse(n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(n, n, &mut rng);
            let mut a_inv = Matrix::<F>::zeros(n, n);
            let mut prod = Matrix::<F>::zeros(n, n);

            a.inverse_into(&mut a_inv).unwrap();

            Matrix::mul_into(&a, &a_inv, &mut prod);
            assert_eq!(prod, Matrix::<F>::identity(n));

            Matrix::mul_into(&a_inv, &a, &mut prod);
            assert_eq!(prod, Matrix::<F>::identity(n));
        }
    }

    #[test]
    fn test_new_rejects_wrong_length() {
        let result = Matrix::<F>::new(2, 2, &[F::ONE; 3]);
        assert!(matches!(
            result,
            Err(Error::IncorrectArgumentLength { .. })
        ));
    }

    #[test]
    fn test_inverse_with_row_swaps() {
        // A cyclic permutation matrix has zeros on its diagonal, so the pivot search must
        // swap rows. Random matrices almost never exercise that path. The inverse of a
        // permutation matrix is its transpose.
        let (zero, one) = (F::ZERO, F::ONE);
        let a = Matrix::<F>::new(3, 3, &[zero, one, zero, zero, zero, one, one, zero, zero])
            .unwrap();
        let expected =
            Matrix::<F>::new(3, 3, &[zero, zero, one, one, zero, zero, zero, one, zero])
                .unwrap();

        let mut a_inv = Matrix::<F>::zeros(3, 3);
        a.inverse_into(&mut a_inv).unwrap();
        assert_eq!(a_inv, expected);

        let mut prod = Matrix::<F>::zeros(3, 3);
        Matrix::mul_into(&a, &a_inv, &mut prod);
        assert_eq!(prod, Matrix::<F>::identity(3));
    }

    #[test]
    fn test_inverse_singular() {
        // Two identical rows: the matrix has rank 1 and no inverse.
        let a = Matrix::<F>::new(2, 2, &[F::ONE; 4]).unwrap();
        let mut out = Matrix::<F>::zeros(2, 2);
        assert!(matches!(
            a.inverse_into(&mut out),
            Err(Error::MatrixIsSingular)
        ));
    }

    #[test]
    fn test_inverse_not_square() {
        let a = Matrix::<F>::zeros(2, 3);
        let mut out = Matrix::<F>::zeros(2, 3);
        assert!(matches!(
            a.inverse_into(&mut out),
            Err(Error::MatrixNotSquare)
        ));
    }
}