1use std::ops::{AddAssign, Deref, DerefMut, Index, IndexMut};
4
5use bytemuck::{allocation::zeroed_vec, Zeroable};
6
/// A dense two-dimensional array stored in row-major order.
///
/// Element `(i, j)` lives at flat offset `i * cols + j` of `data`. `Data` is
/// any slice-like container: an owned `Vec<T>` by default, or a borrowed
/// `&[T]` / `&mut [T]` for the views returned by `reshape` / `reshape_mut`.
///
/// The constructors are meant to keep `data.len() == rows * cols`; the
/// accessors (in particular the unchecked ones) depend on that invariant.
#[derive(Debug)]
pub struct Array2D<T, Data = Vec<T>>
where
    Data: Deref<Target = [T]>,
{
    data: Data,
    rows: usize,
    cols: usize,
}
14
15impl<T: Default + Clone> Array2D<T> {
16 pub fn new(rows: usize, cols: usize) -> Self {
18 Self {
19 data: vec![T::default(); rows * cols],
20 rows,
21 cols,
22 }
23 }
24
25 pub fn zeroes(rows: usize, cols: usize) -> Self
27 where
28 T: Zeroable,
29 {
30 Self {
31 data: zeroed_vec(rows * cols),
32 rows,
33 cols,
34 }
35 }
36}
37
38impl<T, Data: Deref<Target = [T]>> Array2D<T, Data> {
39 pub fn rows(&self) -> usize {
41 self.data.len() / self.cols
42 }
43
44 pub const fn cols(&self) -> usize {
46 self.cols
47 }
48
49 pub fn get_row(&self, i: usize) -> &[T] {
51 let start = i * self.cols;
52 &self.data[start..start + self.cols]
53 }
54
55 pub fn iter_rows(&self) -> impl Iterator<Item = &[T]> {
57 (0..self.rows).map(move |i| self.get_row(i))
58 }
59
60 pub unsafe fn get_unchecked(&self, i: usize, j: usize) -> &T {
64 self.data.get_unchecked(i * self.cols + j)
65 }
66
67 pub fn reshape(&self, rows: usize, cols: usize) -> Option<Array2D<T, &[T]>> {
69 if rows * cols != self.data.len() {
70 return None;
71 }
72
73 Some(Array2D {
74 data: &self.data,
75 rows,
76 cols,
77 })
78 }
79}
80
81impl<T, Data: DerefMut<Target = [T]>> Array2D<T, Data> {
82 pub fn get_row_mut(&mut self, i: usize) -> &mut [T] {
84 let start = i * self.cols;
85 &mut self.data[start..start + self.cols]
86 }
87
88 pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
92 self.data.get_unchecked_mut(i * self.cols + j)
93 }
94
95 pub fn reshape_mut(&mut self, rows: usize, cols: usize) -> Option<Array2D<T, &mut [T]>> {
97 if rows * cols != self.data.len() {
98 return None;
99 }
100
101 Some(Array2D {
102 data: self.data.deref_mut(),
103 rows,
104 cols,
105 })
106 }
107}
108
109impl<T, Data: Deref<Target = [T]>> Index<(usize, usize)> for Array2D<T, Data> {
110 type Output = T;
111
112 fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
113 &self.data[i * self.cols + j]
114 }
115}
116
117impl<T, Data: DerefMut<Target = [T]>> IndexMut<(usize, usize)> for Array2D<T, Data> {
118 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
119 &mut self.data[i * self.cols + j]
120 }
121}
122
123impl<T: Default + Clone + AddAssign, Data: Deref<Target = [T]>> Array2D<T, Data> {
124 pub fn sum_rows(&self) -> Vec<T> {
126 let mut sum = vec![T::default(); self.cols];
127
128 for row in self.iter_rows() {
129 for (i, elem) in row.iter().enumerate() {
130 sum[i] += elem.clone();
131 }
132 }
133
134 sum
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 2x3 fixture `[[1, 2, 3], [4, 5, 6]]` by writing through `IndexMut`.
    fn sample() -> Array2D<i32> {
        let mut arr = Array2D::new(2, 3);
        arr[(0, 0)] = 1;
        arr[(0, 1)] = 2;
        arr[(0, 2)] = 3;
        arr[(1, 0)] = 4;
        arr[(1, 1)] = 5;
        arr[(1, 2)] = 6;
        arr
    }

    #[test]
    fn test_get_set() {
        let arr = sample();

        assert_eq!(arr[(0, 0)], 1);
        assert_eq!(arr[(0, 1)], 2);
        assert_eq!(arr[(0, 2)], 3);
        assert_eq!(arr[(1, 0)], 4);
        assert_eq!(arr[(1, 1)], 5);
        assert_eq!(arr[(1, 2)], 6);
    }

    #[test]
    fn test_unchecked_access() {
        let mut arr = Array2D::new(2, 3);
        // SAFETY: every index below is within the 2x3 bounds.
        unsafe {
            *arr.get_unchecked_mut(0, 0) = 1;
            *arr.get_unchecked_mut(0, 1) = 2;
            *arr.get_unchecked_mut(0, 2) = 3;
            *arr.get_unchecked_mut(1, 0) = 4;
            *arr.get_unchecked_mut(1, 1) = 5;
            *arr.get_unchecked_mut(1, 2) = 6;
        }

        // SAFETY: every index below is within the 2x3 bounds.
        unsafe {
            assert_eq!(*arr.get_unchecked(0, 0), 1);
            assert_eq!(*arr.get_unchecked(0, 1), 2);
            assert_eq!(*arr.get_unchecked(0, 2), 3);
            assert_eq!(*arr.get_unchecked(1, 0), 4);
            assert_eq!(*arr.get_unchecked(1, 1), 5);
            assert_eq!(*arr.get_unchecked(1, 2), 6);
        }
    }

    #[test]
    fn test_get_row() {
        let mut arr = sample();

        assert_eq!(arr.get_row(0), &[1, 2, 3]);
        assert_eq!(arr.get_row_mut(0), &mut [1, 2, 3]);
        assert_eq!(arr.get_row(1), &[4, 5, 6]);
        assert_eq!(arr.get_row_mut(1), &mut [4, 5, 6]);
    }

    #[test]
    fn test_dimensions() {
        let arr = sample();
        assert_eq!(arr.rows(), 2);
        assert_eq!(arr.cols(), 3);
    }

    #[test]
    fn test_iter_rows() {
        let arr = sample();
        let rows: Vec<&[i32]> = arr.iter_rows().collect();
        assert_eq!(rows, vec![&[1, 2, 3][..], &[4, 5, 6][..]]);
    }

    #[test]
    fn test_reshape() {
        let arr = sample();

        let view = arr.reshape(3, 2).expect("6 elements reshape to 3x2");
        assert_eq!(view.rows(), 3);
        assert_eq!(view.cols(), 2);
        assert_eq!(view.get_row(2), &[5, 6]);
        assert_eq!(view[(1, 0)], 3);

        assert!(arr.reshape(4, 2).is_none());
    }

    #[test]
    fn test_reshape_mut() {
        let mut arr = sample();
        {
            let mut view = arr.reshape_mut(3, 2).expect("6 elements reshape to 3x2");
            view[(2, 1)] = 60;
        }
        assert_eq!(arr[(1, 2)], 60);
        assert!(arr.reshape_mut(1, 5).is_none());
    }

    #[test]
    fn test_zeroes() {
        let arr: Array2D<u32> = Array2D::zeroes(2, 3);
        assert_eq!(arr.rows(), 2);
        assert_eq!(arr.cols(), 3);
        assert!(arr.iter_rows().all(|row| row.iter().all(|&x| x == 0)));
    }

    #[test]
    fn test_sum_rows() {
        assert_eq!(sample().sum_rows(), vec![5, 7, 9]);
    }
}