1use std::{
4 iter::repeat_with,
5 ops::{Add, AddAssign, Index, IndexMut, Sub, SubAssign},
6};
7
8use binius_field::{ExtensionField, Field};
9use bytemuck::zeroed_slice_box;
10use getset::CopyGetters;
11use rand::RngCore;
12
/// A dense `m × n` matrix with entries in the field `F`.
///
/// Entries are stored contiguously in row-major order: entry `(i, j)` lives at
/// index `i * n + j` of `elements`.
#[derive(Debug, Clone, PartialEq, Eq, CopyGetters)]
pub struct Matrix<F: Field> {
    /// Number of rows.
    #[getset(get_copy = "pub")]
    m: usize,
    /// Number of columns.
    #[getset(get_copy = "pub")]
    n: usize,
    /// The `m * n` entries, row-major.
    elements: Box<[F]>,
}
24
25impl<F: Field> Matrix<F> {
26 pub fn new(m: usize, n: usize, elements: &[F]) -> Self {
27 assert_eq!(elements.len(), m * n, "precondition: elements length must equal m * n");
28 Self {
29 m,
30 n,
31 elements: elements.into(),
32 }
33 }
34
35 pub fn zeros(m: usize, n: usize) -> Self {
36 Self {
37 m,
38 n,
39 elements: zeroed_slice_box(m * n),
40 }
41 }
42
43 pub fn identity(n: usize) -> Self {
44 let mut out = Self::zeros(n, n);
45 for i in 0..n {
46 out[(i, i)] = F::ONE;
47 }
48 out
49 }
50
51 fn fill_identity(&mut self) {
52 assert_eq!(self.m, self.n);
53 self.elements.fill(F::ZERO);
54 for i in 0..self.n {
55 self[(i, i)] = F::ONE;
56 }
57 }
58
59 pub const fn elements(&self) -> &[F] {
60 &self.elements
61 }
62
63 pub fn random(m: usize, n: usize, mut rng: impl RngCore) -> Self {
64 Self {
65 m,
66 n,
67 elements: repeat_with(|| F::random(&mut rng)).take(m * n).collect(),
68 }
69 }
70
71 pub const fn dim(&self) -> (usize, usize) {
72 (self.m, self.n)
73 }
74
75 pub fn copy_from(&mut self, other: &Self) {
76 assert_eq!(self.dim(), other.dim());
77 self.elements.copy_from_slice(&other.elements);
78 }
79
80 pub fn mul_into(a: &Self, b: &Self, c: &mut Self) {
81 assert_eq!(a.n(), b.m());
82 assert_eq!(a.m(), c.m());
83 assert_eq!(b.n(), c.n());
84
85 for i in 0..c.m() {
86 for j in 0..c.n() {
87 c[(i, j)] = (0..a.n()).map(|k| a[(i, k)] * b[(k, j)]).sum();
88 }
89 }
90 }
91
92 pub fn mul_vec_into<FE: ExtensionField<F>>(&self, x: &[FE], y: &mut [FE]) {
93 assert_eq!(self.n(), x.len());
94 assert_eq!(self.m(), y.len());
95
96 for i in 0..y.len() {
97 y[i] = (0..self.n()).map(|j| x[j] * self[(i, j)]).sum();
98 }
99 }
100
101 pub fn inverse_into(&self, out: &mut Self) {
109 assert_eq!(self.dim(), out.dim());
110 assert_eq!(self.m, self.n, "precondition: matrix must be square");
111
112 let n = self.n;
113
114 let mut tmp = self.clone();
115 out.fill_identity();
116
117 let mut row_buffer = vec![F::ZERO; n];
118
119 for i in 0..n {
120 let pivot = (i..n)
122 .find(|&pivot| tmp[(pivot, i)] != F::ZERO)
123 .expect("precondition: matrix must be non-singular");
124 if pivot != i {
125 tmp.swap_rows(i, pivot, &mut row_buffer);
126 out.swap_rows(i, pivot, &mut row_buffer);
127 }
128
129 let scalar = tmp[(i, i)]
131 .invert()
132 .expect("pivot is checked to be non-zero above");
133 tmp.scale_row(i, scalar);
134 out.scale_row(i, scalar);
135
136 for j in (0..i).chain(i + 1..n) {
138 let scalar = tmp[(j, i)];
139 tmp.sub_pivot_row(j, i, scalar);
140 out.sub_pivot_row(j, i, scalar);
141 }
142 }
143
144 debug_assert_eq!(tmp, Self::identity(n));
145 }
146
147 fn row_ref(&self, i: usize) -> &[F] {
148 assert!(i < self.m);
149 &self.elements[i * self.n..(i + 1) * self.n]
150 }
151
152 fn row_mut(&mut self, i: usize) -> &mut [F] {
153 assert!(i < self.m);
154 &mut self.elements[i * self.n..(i + 1) * self.n]
155 }
156
157 fn swap_rows(&mut self, i0: usize, i1: usize, buffer: &mut [F]) {
158 assert!(i0 < self.m);
159 assert!(i1 < self.m);
160 assert_eq!(buffer.len(), self.n);
161
162 if i0 == i1 {
163 return;
164 }
165
166 buffer.copy_from_slice(self.row_ref(i1));
167 self.elements
168 .copy_within(i0 * self.n..(i0 + 1) * self.n, i1 * self.n);
169 self.row_mut(i0).copy_from_slice(buffer);
170 }
171
172 fn scale_row(&mut self, i: usize, scalar: F) {
173 for x in self.row_mut(i) {
174 *x *= scalar;
175 }
176 }
177
178 fn sub_pivot_row(&mut self, i0: usize, i1: usize, scalar: F) {
179 assert!(i0 < self.m);
180 assert!(i1 < self.m);
181
182 for j in 0..self.n {
183 let x = self[(i1, j)];
184 self[(i0, j)] -= x * scalar;
185 }
186 }
187}
188
189impl<F: Field> Index<(usize, usize)> for Matrix<F> {
190 type Output = F;
191
192 fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
193 assert!(i < self.m);
194 assert!(j < self.n);
195 &self.elements[i * self.n + j]
196 }
197}
198
199impl<F: Field> IndexMut<(usize, usize)> for Matrix<F> {
200 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
201 assert!(i < self.m);
202 assert!(j < self.n);
203 &mut self.elements[i * self.n + j]
204 }
205}
206
207impl<F: Field> Add<Self> for &Matrix<F> {
208 type Output = Matrix<F>;
209
210 fn add(self, rhs: Self) -> Matrix<F> {
211 let mut out = self.clone();
212 out += rhs;
213 out
214 }
215}
216
217impl<F: Field> Sub<Self> for &Matrix<F> {
218 type Output = Matrix<F>;
219
220 fn sub(self, rhs: Self) -> Matrix<F> {
221 let mut out = self.clone();
222 out -= rhs;
223 out
224 }
225}
226
227impl<F: Field> AddAssign<&Self> for Matrix<F> {
228 fn add_assign(&mut self, rhs: &Self) {
229 assert_eq!(self.dim(), rhs.dim());
230 for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
231 *a_ij += b_ij;
232 }
233 }
234}
235
236impl<F: Field> SubAssign<&Self> for Matrix<F> {
237 fn sub_assign(&mut self, rhs: &Self) {
238 assert_eq!(self.dim(), rhs.dim());
239 for (a_ij, &b_ij) in self.elements.iter_mut().zip(rhs.elements.iter()) {
240 *a_ij -= b_ij;
241 }
242 }
243}
244
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::test_utils::B128;

    type F = B128;

    proptest! {
        #[test]
        fn test_left_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a0 = Matrix::<F>::random(c_m, a_n, &mut rng);
            let a1 = Matrix::<F>::random(c_m, a_n, &mut rng);
            let b = Matrix::<F>::random(a_n, c_n, &mut rng);
            let mut c0 = Matrix::<F>::zeros(c_m, c_n);
            let mut c1 = Matrix::<F>::zeros(c_m, c_n);

            let a0p1 = &a0 + &a1;
            let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

            Matrix::mul_into(&a0, &b, &mut c0);
            Matrix::mul_into(&a1, &b, &mut c1);
            Matrix::mul_into(&a0p1, &b, &mut c0p1);

            assert_eq!(c0p1, &c0 + &c1);
        }

        #[test]
        fn test_right_linearity(c_m in 0..8usize, c_n in 0..8usize, a_n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(c_m, a_n, &mut rng);
            let b0 = Matrix::<F>::random(a_n, c_n, &mut rng);
            let b1 = Matrix::<F>::random(a_n, c_n, &mut rng);
            let mut c0 = Matrix::<F>::zeros(c_m, c_n);
            let mut c1 = Matrix::<F>::zeros(c_m, c_n);

            let b0p1 = &b0 + &b1;
            let mut c0p1 = Matrix::<F>::zeros(c_m, c_n);

            Matrix::mul_into(&a, &b0, &mut c0);
            Matrix::mul_into(&a, &b1, &mut c1);
            Matrix::mul_into(&a, &b0p1, &mut c0p1);

            assert_eq!(c0p1, &c0 + &c1);
        }

        #[test]
        fn test_double_inverse(n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(n, n, &mut rng);
            let mut a_inv = Matrix::<F>::zeros(n, n);
            let mut a_inv_inv = Matrix::<F>::zeros(n, n);

            a.inverse_into(&mut a_inv);
            a_inv.inverse_into(&mut a_inv_inv);
            assert_eq!(a_inv_inv, a);
        }

        #[test]
        fn test_inverse(n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(n, n, &mut rng);
            let mut a_inv = Matrix::<F>::zeros(n, n);
            let mut prod = Matrix::<F>::zeros(n, n);

            a.inverse_into(&mut a_inv);

            Matrix::mul_into(&a, &a_inv, &mut prod);
            assert_eq!(prod, Matrix::<F>::identity(n));

            Matrix::mul_into(&a_inv, &a, &mut prod);
            assert_eq!(prod, Matrix::<F>::identity(n));
        }

        // `mul_vec_into` must agree with `mul_into` applied to an `n × 1` column matrix.
        #[test]
        fn test_mul_vec_matches_mul(m in 0..8usize, n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(m, n, &mut rng);
            let x = Matrix::<F>::random(n, 1, &mut rng);

            let mut expected = Matrix::<F>::zeros(m, 1);
            Matrix::mul_into(&a, &x, &mut expected);

            let mut y = vec![F::ZERO; m];
            a.mul_vec_into(x.elements(), &mut y);
            assert_eq!(y.as_slice(), expected.elements());
        }

        // Multiplying by the identity on either side leaves a matrix unchanged.
        #[test]
        fn test_identity_is_neutral(m in 0..8usize, n in 0..8usize) {
            let mut rng = StdRng::seed_from_u64(0);
            let a = Matrix::<F>::random(m, n, &mut rng);
            let mut prod = Matrix::<F>::zeros(m, n);

            Matrix::mul_into(&Matrix::<F>::identity(m), &a, &mut prod);
            assert_eq!(prod, a);

            Matrix::mul_into(&a, &Matrix::<F>::identity(n), &mut prod);
            assert_eq!(prod, a);
        }
    }
}