binius_utils/
array_2d.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{AddAssign, Deref, DerefMut, Index, IndexMut};
4
5use bytemuck::{allocation::zeroed_vec, Zeroable};
6
/// Dense two-dimensional array stored in row-major order.
///
/// Element `(i, j)` lives at flat offset `i * cols + j` of `data`. The backing
/// storage is any container that dereferences to a slice: an owned `Vec<T>` by
/// default, or a borrowed `&[T]` / `&mut [T]` for reshaped views.
#[derive(Debug)]
pub struct Array2D<T, Data = Vec<T>>
where
	Data: Deref<Target = [T]>,
{
	/// Flat element storage, rows laid out one after another.
	data: Data,
	/// Number of rows.
	rows: usize,
	/// Number of columns, which is also the stride between consecutive rows.
	cols: usize,
}
14
15impl<T: Default + Clone> Array2D<T> {
16	/// Create a new 2D array of the given size initialized with default values.
17	pub fn new(rows: usize, cols: usize) -> Self {
18		Self {
19			data: vec![T::default(); rows * cols],
20			rows,
21			cols,
22		}
23	}
24
25	/// Create a new 2D array of the given size initialized with zeroes.
26	pub fn zeroes(rows: usize, cols: usize) -> Self
27	where
28		T: Zeroable,
29	{
30		Self {
31			data: zeroed_vec(rows * cols),
32			rows,
33			cols,
34		}
35	}
36}
37
38impl<T, Data: Deref<Target = [T]>> Array2D<T, Data> {
39	/// Returns the number of rows in the array.
40	pub fn rows(&self) -> usize {
41		self.data.len() / self.cols
42	}
43
44	/// Returns the number of columns in the array.
45	pub const fn cols(&self) -> usize {
46		self.cols
47	}
48
49	/// Returns the row at the given index.
50	pub fn get_row(&self, i: usize) -> &[T] {
51		let start = i * self.cols;
52		&self.data[start..start + self.cols]
53	}
54
55	/// Returns an iterator over the rows of the array.
56	pub fn iter_rows(&self) -> impl Iterator<Item = &[T]> {
57		(0..self.rows).map(move |i| self.get_row(i))
58	}
59
60	/// Return the element at the given row and column without bounds checking.
61	/// # Safety
62	/// The caller must ensure that `i` and `j` are less than the number of rows and columns
63	/// respectively.
64	pub unsafe fn get_unchecked(&self, i: usize, j: usize) -> &T {
65		self.data.get_unchecked(i * self.cols + j)
66	}
67
68	/// View of the array in a different shape, underlying elements stay the same.
69	pub fn reshape(&self, rows: usize, cols: usize) -> Option<Array2D<T, &[T]>> {
70		if rows * cols != self.data.len() {
71			return None;
72		}
73
74		Some(Array2D {
75			data: &self.data,
76			rows,
77			cols,
78		})
79	}
80}
81
82impl<T, Data: DerefMut<Target = [T]>> Array2D<T, Data> {
83	/// Returns the mutable row at the given index.
84	pub fn get_row_mut(&mut self, i: usize) -> &mut [T] {
85		let start = i * self.cols;
86		&mut self.data[start..start + self.cols]
87	}
88
89	/// Return the mutable element at the given row and column without bounds checking.
90	/// # Safety
91	/// The caller must ensure that `i` and `j` are less than the number of rows and columns
92	/// respectively.
93	pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
94		self.data.get_unchecked_mut(i * self.cols + j)
95	}
96
97	/// Mutable view of the array in a different shape, underlying elements stay the same.
98	pub fn reshape_mut(&mut self, rows: usize, cols: usize) -> Option<Array2D<T, &mut [T]>> {
99		if rows * cols != self.data.len() {
100			return None;
101		}
102
103		Some(Array2D {
104			data: self.data.deref_mut(),
105			rows,
106			cols,
107		})
108	}
109}
110
111impl<T, Data: Deref<Target = [T]>> Index<(usize, usize)> for Array2D<T, Data> {
112	type Output = T;
113
114	fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
115		&self.data[i * self.cols + j]
116	}
117}
118
119impl<T, Data: DerefMut<Target = [T]>> IndexMut<(usize, usize)> for Array2D<T, Data> {
120	fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
121		&mut self.data[i * self.cols + j]
122	}
123}
124
125impl<T: Default + Clone + AddAssign, Data: Deref<Target = [T]>> Array2D<T, Data> {
126	/// Returns the sum of the elements in each row.
127	pub fn sum_rows(&self) -> Vec<T> {
128		let mut sum = vec![T::default(); self.cols];
129
130		for row in self.iter_rows() {
131			for (i, elem) in row.iter().enumerate() {
132				sum[i] += elem.clone();
133			}
134		}
135
136		sum
137	}
138}
139
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds the 2x3 array `[[1, 2, 3], [4, 5, 6]]` through the `IndexMut` API.
	fn sample() -> Array2D<i32> {
		let mut arr = Array2D::new(2, 3);
		let mut next = 1;
		for i in 0..2 {
			for j in 0..3 {
				arr[(i, j)] = next;
				next += 1;
			}
		}
		arr
	}

	#[test]
	fn test_get_set() {
		let mut arr = Array2D::new(2, 3);
		arr[(0, 0)] = 1;
		arr[(0, 1)] = 2;
		arr[(0, 2)] = 3;
		arr[(1, 0)] = 4;
		arr[(1, 1)] = 5;
		arr[(1, 2)] = 6;

		assert_eq!(arr[(0, 0)], 1);
		assert_eq!(arr[(0, 1)], 2);
		assert_eq!(arr[(0, 2)], 3);
		assert_eq!(arr[(1, 0)], 4);
		assert_eq!(arr[(1, 1)], 5);
		assert_eq!(arr[(1, 2)], 6);
	}

	#[test]
	fn test_unchecked_access() {
		let mut arr = Array2D::new(2, 3);
		unsafe {
			*arr.get_unchecked_mut(0, 0) = 1;
			*arr.get_unchecked_mut(0, 1) = 2;
			*arr.get_unchecked_mut(0, 2) = 3;
			*arr.get_unchecked_mut(1, 0) = 4;
			*arr.get_unchecked_mut(1, 1) = 5;
			*arr.get_unchecked_mut(1, 2) = 6;
		}

		unsafe {
			assert_eq!(*arr.get_unchecked(0, 0), 1);
			assert_eq!(*arr.get_unchecked(0, 1), 2);
			assert_eq!(*arr.get_unchecked(0, 2), 3);
			assert_eq!(*arr.get_unchecked(1, 0), 4);
			assert_eq!(*arr.get_unchecked(1, 1), 5);
			assert_eq!(*arr.get_unchecked(1, 2), 6);
		}
	}

	#[test]
	fn test_dimensions_and_default_fill() {
		let arr = Array2D::<i32>::new(2, 3);
		assert_eq!(arr.rows(), 2);
		assert_eq!(arr.cols(), 3);
		assert!(arr.iter_rows().flatten().all(|&x| x == 0));
	}

	#[test]
	fn test_zeroes() {
		let arr = Array2D::<u64>::zeroes(3, 2);
		assert_eq!(arr.rows(), 3);
		assert_eq!(arr.cols(), 2);
		assert_eq!(arr.iter_rows().count(), 3);
		assert!(arr.iter_rows().flatten().all(|&x| x == 0));
	}

	#[test]
	fn test_get_row() {
		let mut arr = sample();

		assert_eq!(arr.get_row(0), &[1, 2, 3]);
		assert_eq!(arr.get_row_mut(0), &mut [1, 2, 3]);
		assert_eq!(arr.get_row(1), &[4, 5, 6]);
		assert_eq!(arr.get_row_mut(1), &mut [4, 5, 6]);
	}

	#[test]
	fn test_iter_rows() {
		let arr = sample();
		let rows: Vec<&[i32]> = arr.iter_rows().collect();
		assert_eq!(rows, vec![&[1, 2, 3][..], &[4, 5, 6][..]]);
	}

	#[test]
	fn test_reshape() {
		let arr = sample();
		let view = arr.reshape(3, 2).expect("6 elements fit a 3x2 shape");
		assert_eq!(view.rows(), 3);
		assert_eq!(view.cols(), 2);
		assert_eq!(view.get_row(1), &[3, 4]);
		assert_eq!(view[(2, 1)], 6);
		assert!(arr.reshape(4, 2).is_none());
	}

	#[test]
	fn test_reshape_mut() {
		let mut arr = sample();
		{
			let mut view = arr.reshape_mut(3, 2).expect("6 elements fit a 3x2 shape");
			view[(2, 1)] = 60;
		}
		// Writes through the view land in the original storage.
		assert_eq!(arr[(1, 2)], 60);
		assert!(arr.reshape_mut(4, 2).is_none());
	}

	#[test]
	fn test_sum_rows() {
		let arr = sample();
		assert_eq!(arr.sum_rows(), vec![5, 7, 9]);
	}
}