1use std::ops::{AddAssign, Deref, DerefMut, Index, IndexMut};
4
5use bytemuck::{allocation::zeroed_vec, Zeroable};
6
/// A dense two-dimensional array stored in row-major order.
///
/// Element `(i, j)` lives at flat offset `i * cols + j` of `data`. `Data` is
/// anything that dereferences to `[T]`: an owned `Vec<T>` by default, or a
/// borrowed `&[T]` / `&mut [T]` for the views returned by `reshape` and
/// `reshape_mut`.
#[derive(Debug)]
pub struct Array2D<T, Data = Vec<T>>
where
    Data: Deref<Target = [T]>,
{
    /// Backing storage, laid out row after row.
    data: Data,
    /// Number of rows.
    rows: usize,
    /// Number of columns, which is also the stride between consecutive rows.
    cols: usize,
}
14
15impl<T: Default + Clone> Array2D<T> {
16 pub fn new(rows: usize, cols: usize) -> Self {
18 Self {
19 data: vec![T::default(); rows * cols],
20 rows,
21 cols,
22 }
23 }
24
25 pub fn zeroes(rows: usize, cols: usize) -> Self
27 where
28 T: Zeroable,
29 {
30 Self {
31 data: zeroed_vec(rows * cols),
32 rows,
33 cols,
34 }
35 }
36}
37
38impl<T, Data: Deref<Target = [T]>> Array2D<T, Data> {
39 pub fn rows(&self) -> usize {
41 self.data.len() / self.cols
42 }
43
44 pub const fn cols(&self) -> usize {
46 self.cols
47 }
48
49 pub fn get_row(&self, i: usize) -> &[T] {
51 let start = i * self.cols;
52 &self.data[start..start + self.cols]
53 }
54
55 pub fn iter_rows(&self) -> impl Iterator<Item = &[T]> {
57 (0..self.rows).map(move |i| self.get_row(i))
58 }
59
60 pub unsafe fn get_unchecked(&self, i: usize, j: usize) -> &T {
65 self.data.get_unchecked(i * self.cols + j)
66 }
67
68 pub fn reshape(&self, rows: usize, cols: usize) -> Option<Array2D<T, &[T]>> {
70 if rows * cols != self.data.len() {
71 return None;
72 }
73
74 Some(Array2D {
75 data: &self.data,
76 rows,
77 cols,
78 })
79 }
80}
81
82impl<T, Data: DerefMut<Target = [T]>> Array2D<T, Data> {
83 pub fn get_row_mut(&mut self, i: usize) -> &mut [T] {
85 let start = i * self.cols;
86 &mut self.data[start..start + self.cols]
87 }
88
89 pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
94 self.data.get_unchecked_mut(i * self.cols + j)
95 }
96
97 pub fn reshape_mut(&mut self, rows: usize, cols: usize) -> Option<Array2D<T, &mut [T]>> {
99 if rows * cols != self.data.len() {
100 return None;
101 }
102
103 Some(Array2D {
104 data: self.data.deref_mut(),
105 rows,
106 cols,
107 })
108 }
109}
110
111impl<T, Data: Deref<Target = [T]>> Index<(usize, usize)> for Array2D<T, Data> {
112 type Output = T;
113
114 fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
115 &self.data[i * self.cols + j]
116 }
117}
118
119impl<T, Data: DerefMut<Target = [T]>> IndexMut<(usize, usize)> for Array2D<T, Data> {
120 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
121 &mut self.data[i * self.cols + j]
122 }
123}
124
125impl<T: Default + Clone + AddAssign, Data: Deref<Target = [T]>> Array2D<T, Data> {
126 pub fn sum_rows(&self) -> Vec<T> {
128 let mut sum = vec![T::default(); self.cols];
129
130 for row in self.iter_rows() {
131 for (i, elem) in row.iter().enumerate() {
132 sum[i] += elem.clone();
133 }
134 }
135
136 sum
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 2x3 fixture `[[1, 2, 3], [4, 5, 6]]` through `IndexMut`.
    fn sample() -> Array2D<i32> {
        let values = [[1, 2, 3], [4, 5, 6]];
        let mut arr = Array2D::new(2, 3);
        for (i, row) in values.iter().enumerate() {
            for (j, &v) in row.iter().enumerate() {
                arr[(i, j)] = v;
            }
        }
        arr
    }

    #[test]
    fn test_get_set() {
        let mut arr = Array2D::new(2, 3);
        arr[(0, 0)] = 1;
        arr[(0, 1)] = 2;
        arr[(0, 2)] = 3;
        arr[(1, 0)] = 4;
        arr[(1, 1)] = 5;
        arr[(1, 2)] = 6;

        assert_eq!(arr[(0, 0)], 1);
        assert_eq!(arr[(0, 1)], 2);
        assert_eq!(arr[(0, 2)], 3);
        assert_eq!(arr[(1, 0)], 4);
        assert_eq!(arr[(1, 1)], 5);
        assert_eq!(arr[(1, 2)], 6);
    }

    #[test]
    fn test_unchecked_access() {
        let mut arr = Array2D::new(2, 3);
        // SAFETY: every index is within the 2x3 shape.
        unsafe {
            *arr.get_unchecked_mut(0, 0) = 1;
            *arr.get_unchecked_mut(0, 1) = 2;
            *arr.get_unchecked_mut(0, 2) = 3;
            *arr.get_unchecked_mut(1, 0) = 4;
            *arr.get_unchecked_mut(1, 1) = 5;
            *arr.get_unchecked_mut(1, 2) = 6;
        }

        // SAFETY: every index is within the 2x3 shape.
        unsafe {
            assert_eq!(*arr.get_unchecked(0, 0), 1);
            assert_eq!(*arr.get_unchecked(0, 1), 2);
            assert_eq!(*arr.get_unchecked(0, 2), 3);
            assert_eq!(*arr.get_unchecked(1, 0), 4);
            assert_eq!(*arr.get_unchecked(1, 1), 5);
            assert_eq!(*arr.get_unchecked(1, 2), 6);
        }
    }

    #[test]
    fn test_get_row() {
        let mut arr = sample();

        assert_eq!(arr.get_row(0), &[1, 2, 3]);
        assert_eq!(arr.get_row_mut(0), &mut [1, 2, 3]);
        assert_eq!(arr.get_row(1), &[4, 5, 6]);
        assert_eq!(arr.get_row_mut(1), &mut [4, 5, 6]);
    }

    #[test]
    fn test_sum_rows() {
        assert_eq!(sample().sum_rows(), vec![5, 7, 9]);
    }

    #[test]
    fn test_dimensions() {
        let arr = sample();
        assert_eq!(arr.rows(), 2);
        assert_eq!(arr.cols(), 3);
    }

    #[test]
    fn test_iter_rows() {
        let arr = sample();
        let rows: Vec<&[i32]> = arr.iter_rows().collect();
        assert_eq!(rows, vec![&[1, 2, 3][..], &[4, 5, 6][..]]);
    }

    #[test]
    fn test_zeroes() {
        let arr = Array2D::<u32>::zeroes(2, 2);
        assert_eq!(arr.rows(), 2);
        assert_eq!(arr.cols(), 2);
        assert!(arr.iter_rows().all(|row| row.iter().all(|&x| x == 0)));
    }

    #[test]
    fn test_reshape() {
        let arr = sample();
        let view = arr.reshape(3, 2).expect("6 elements fit a 3x2 shape");
        assert_eq!(view.rows(), 3);
        assert_eq!(view.cols(), 2);
        assert_eq!(view.get_row(2), &[5, 6]);
        assert_eq!(view.sum_rows(), vec![9, 12]);
        assert!(arr.reshape(4, 2).is_none());
    }

    #[test]
    fn test_reshape_mut_writes_through() {
        let mut arr = sample();
        let mut view = arr.reshape_mut(3, 2).expect("6 elements fit a 3x2 shape");
        view[(2, 1)] = 60;
        view.get_row_mut(0)[0] = 10;

        assert_eq!(arr[(0, 0)], 10);
        assert_eq!(arr[(1, 2)], 60);
        assert!(arr.reshape_mut(1, 5).is_none());
    }

    #[test]
    fn test_empty_array() {
        let arr: Array2D<i32> = Array2D::new(0, 3);
        assert_eq!(arr.rows(), 0);
        assert_eq!(arr.iter_rows().count(), 0);
        assert_eq!(arr.sum_rows(), vec![0, 0, 0]);
    }
}