use std::ops::{AddAssign, Deref, DerefMut, Index, IndexMut};
use bytemuck::{allocation::zeroed_vec, Zeroable};
/// A dense two-dimensional array stored in row-major order.
///
/// Element `(i, j)` lives at flat index `i * cols + j` of `data`. `Data` is the
/// backing storage: an owned `Vec<T>` by default, or a borrowed slice
/// (`&[T]` / `&mut [T]`) for the views produced by `reshape` / `reshape_mut`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Array2D<T, Data: Deref<Target = [T]> = Vec<T>> {
    /// Backing storage; expected to hold exactly `rows * cols` elements.
    data: Data,
    /// Number of rows.
    rows: usize,
    /// Number of columns, i.e. the stride between consecutive rows in `data`.
    cols: usize,
}
impl<T: Default + Clone> Array2D<T> {
    /// Creates a `rows` x `cols` array with every element set to `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub fn new(rows: usize, cols: usize) -> Self {
        // An unchecked product could wrap in release builds, leaving the stored
        // `rows`/`cols` inconsistent with the allocated length.
        let len = rows
            .checked_mul(cols)
            .expect("Array2D::new: rows * cols overflows usize");
        Self {
            data: vec![T::default(); len],
            rows,
            cols,
        }
    }
}

impl<T: Zeroable> Array2D<T> {
    /// Creates a `rows` x `cols` array whose storage comes from `zeroed_vec`.
    ///
    /// Only `T: Zeroable` is required here; `Default`/`Clone` are not needed.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub fn zeroes(rows: usize, cols: usize) -> Self {
        // Same overflow guard as `new`: keep the dimensions consistent with the length.
        let len = rows
            .checked_mul(cols)
            .expect("Array2D::zeroes: rows * cols overflows usize");
        Self {
            data: zeroed_vec(len),
            rows,
            cols,
        }
    }
}
impl<T, Data: Deref<Target = [T]>> Array2D<T, Data> {
    /// Returns the number of rows.
    pub fn rows(&self) -> usize {
        // Report the stored dimension rather than recomputing `data.len() / cols`:
        // that division panics when `cols == 0`, and the stored value is the one
        // `iter_rows` already iterates over.
        self.rows
    }

    /// Returns the number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Returns row `i` as a slice of `cols()` elements.
    ///
    /// # Panics
    ///
    /// Panics if the span `i * cols .. i * cols + cols` falls outside the
    /// backing storage.
    pub fn get_row(&self, i: usize) -> &[T] {
        let start = i * self.cols;
        &self.data[start..start + self.cols]
    }

    /// Iterates over rows `0..rows` in order, yielding each as a slice.
    pub fn iter_rows(&self) -> impl Iterator<Item = &[T]> {
        (0..self.rows).map(move |i| self.get_row(i))
    }

    /// Returns a reference to element `(i, j)` without bounds checking.
    ///
    /// # Safety
    ///
    /// The flat index `i * self.cols() + j` must lie within the backing storage;
    /// otherwise the behavior is undefined. Keeping `i < self.rows()` and
    /// `j < self.cols()` is sufficient when the storage holds `rows * cols`
    /// elements. Debug builds assert these row/column bounds.
    pub unsafe fn get_unchecked(&self, i: usize, j: usize) -> &T {
        debug_assert!(
            i < self.rows && j < self.cols,
            "get_unchecked: index ({}, {}) out of bounds for {}x{} Array2D",
            i,
            j,
            self.rows,
            self.cols
        );
        // SAFETY: the caller upholds the in-bounds contract documented above.
        unsafe { self.data.get_unchecked(i * self.cols + j) }
    }

    /// Returns a read-only view of the same elements with a different shape.
    ///
    /// Returns `None` unless `rows * cols` equals the number of stored elements,
    /// including when that product overflows `usize`.
    pub fn reshape(&self, rows: usize, cols: usize) -> Option<Array2D<T, &[T]>> {
        // `checked_mul` prevents a wrapped product from spuriously matching the length.
        if rows.checked_mul(cols) != Some(self.data.len()) {
            return None;
        }
        Some(Array2D {
            data: self.data.deref(),
            rows,
            cols,
        })
    }
}
impl<T, Data: DerefMut<Target = [T]>> Array2D<T, Data> {
    /// Returns row `i` as a mutable slice of `cols()` elements.
    ///
    /// # Panics
    ///
    /// Panics if the span `i * cols .. i * cols + cols` falls outside the
    /// backing storage.
    pub fn get_row_mut(&mut self, i: usize) -> &mut [T] {
        let start = i * self.cols;
        &mut self.data[start..start + self.cols]
    }

    /// Returns a mutable reference to element `(i, j)` without bounds checking.
    ///
    /// # Safety
    ///
    /// The flat index `i * self.cols() + j` must lie within the backing storage;
    /// otherwise the behavior is undefined. Keeping `i < self.rows()` and
    /// `j < self.cols()` is sufficient when the storage holds `rows * cols`
    /// elements. Debug builds assert these row/column bounds.
    pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
        debug_assert!(
            i < self.rows && j < self.cols,
            "get_unchecked_mut: index ({}, {}) out of bounds for {}x{} Array2D",
            i,
            j,
            self.rows,
            self.cols
        );
        let flat = i * self.cols + j;
        // SAFETY: the caller upholds the in-bounds contract documented above.
        unsafe { self.data.get_unchecked_mut(flat) }
    }

    /// Returns a mutable view of the same elements with a different shape.
    ///
    /// Returns `None` unless `rows * cols` equals the number of stored elements,
    /// including when that product overflows `usize`.
    pub fn reshape_mut(&mut self, rows: usize, cols: usize) -> Option<Array2D<T, &mut [T]>> {
        // `checked_mul` prevents a wrapped product from spuriously matching the length.
        if rows.checked_mul(cols) != Some(self.data.len()) {
            return None;
        }
        Some(Array2D {
            data: self.data.deref_mut(),
            rows,
            cols,
        })
    }
}
impl<T, Data: Deref<Target = [T]>> Index<(usize, usize)> for Array2D<T, Data> {
    type Output = T;

    /// Returns a reference to the element at row `i`, column `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= rows` or `j >= cols`. Without the explicit column check,
    /// an out-of-range `j` would silently address an element of the next row.
    fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
        assert!(
            i < self.rows && j < self.cols,
            "index ({}, {}) out of bounds for {}x{} Array2D",
            i,
            j,
            self.rows,
            self.cols
        );
        &self.data[i * self.cols + j]
    }
}
impl<T, Data: DerefMut<Target = [T]>> IndexMut<(usize, usize)> for Array2D<T, Data> {
    /// Returns a mutable reference to the element at row `i`, column `j`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= rows` or `j >= cols`. Without the explicit column check,
    /// an out-of-range `j` would silently write into the following row.
    fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
        assert!(
            i < self.rows && j < self.cols,
            "index ({}, {}) out of bounds for {}x{} Array2D",
            i,
            j,
            self.rows,
            self.cols
        );
        &mut self.data[i * self.cols + j]
    }
}
impl<T: Default + Clone + AddAssign, Data: Deref<Target = [T]>> Array2D<T, Data> {
    /// Adds all rows together element-wise.
    ///
    /// The result has one entry per column. Entry `c` is the total of column `c`
    /// across every row, accumulated from `T::default()` in row order.
    pub fn sum_rows(&self) -> Vec<T> {
        let initial = vec![T::default(); self.cols];
        self.iter_rows().fold(initial, |mut totals, row| {
            for (total, value) in totals.iter_mut().zip(row) {
                *total += value.clone();
            }
            totals
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference contents shared by every test, laid out as `EXPECTED[row][col]`.
    const EXPECTED: [[i32; 3]; 2] = [[1, 2, 3], [4, 5, 6]];

    /// Builds a 2x3 array holding `EXPECTED`, writing each cell through `IndexMut`.
    fn filled() -> Array2D<i32> {
        let mut grid = Array2D::new(2, 3);
        for (r, row) in EXPECTED.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                grid[(r, c)] = value;
            }
        }
        grid
    }

    #[test]
    fn test_get_set() {
        let grid = filled();
        for (r, row) in EXPECTED.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert_eq!(grid[(r, c)], value);
            }
        }
    }

    #[test]
    fn test_unchecked_access() {
        let mut grid: Array2D<i32> = Array2D::new(2, 3);
        for (r, row) in EXPECTED.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                // SAFETY: r < 2 and c < 3, inside the 2x3 shape passed to `new`.
                unsafe { *grid.get_unchecked_mut(r, c) = value };
            }
        }
        for (r, row) in EXPECTED.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                // SAFETY: same in-bounds indices as the writes above.
                assert_eq!(unsafe { *grid.get_unchecked(r, c) }, value);
            }
        }
    }

    #[test]
    fn test_get_row() {
        let mut grid = filled();
        for (r, expected) in EXPECTED.iter().enumerate() {
            assert_eq!(grid.get_row(r), &expected[..]);
            assert_eq!(&*grid.get_row_mut(r), &expected[..]);
        }
    }

    #[test]
    fn test_sum_rows() {
        let totals = filled().sum_rows();
        assert_eq!(totals, vec![5, 7, 9]);
    }
}