binius_utils/
array_2d.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{AddAssign, Deref, DerefMut, Index, IndexMut};
4
5use bytemuck::{allocation::zeroed_vec, Zeroable};
6
/// Dense two-dimensional array stored in row-major order.
///
/// Element `(i, j)` lives at flat offset `i * cols + j` of the backing
/// storage. The constructors in this module are intended to keep
/// `data.len() == rows * cols`.
///
/// `Data` is the slice-like container holding the elements: an owned
/// `Vec<T>` by default, or a borrowed `&[T]` / `&mut [T]` for the views
/// produced by `reshape` / `reshape_mut`.
#[derive(Debug)]
pub struct Array2D<T, Data: Deref<Target = [T]> = Vec<T>> {
	/// Backing storage: row 0, then row 1, and so on.
	data: Data,
	/// Number of rows.
	rows: usize,
	/// Number of columns, i.e. the length of every row.
	cols: usize,
}
14
15impl<T: Default + Clone> Array2D<T> {
16	/// Create a new 2D array of the given size initialized with default values.
17	pub fn new(rows: usize, cols: usize) -> Self {
18		Self {
19			data: vec![T::default(); rows * cols],
20			rows,
21			cols,
22		}
23	}
24
25	/// Create a new 2D array of the given size initialized with zeroes.
26	pub fn zeroes(rows: usize, cols: usize) -> Self
27	where
28		T: Zeroable,
29	{
30		Self {
31			data: zeroed_vec(rows * cols),
32			rows,
33			cols,
34		}
35	}
36}
37
38impl<T, Data: Deref<Target = [T]>> Array2D<T, Data> {
39	/// Returns the number of rows in the array.
40	pub fn rows(&self) -> usize {
41		self.data.len() / self.cols
42	}
43
44	/// Returns the number of columns in the array.
45	pub const fn cols(&self) -> usize {
46		self.cols
47	}
48
49	/// Returns the row at the given index.
50	pub fn get_row(&self, i: usize) -> &[T] {
51		let start = i * self.cols;
52		&self.data[start..start + self.cols]
53	}
54
55	/// Returns an iterator over the rows of the array.
56	pub fn iter_rows(&self) -> impl Iterator<Item = &[T]> {
57		(0..self.rows).map(move |i| self.get_row(i))
58	}
59
60	/// Return the element at the given row and column without bounds checking.
61	/// # Safety
62	/// The caller must ensure that `i` and `j` are less than the number of rows and columns respectively.
63	pub unsafe fn get_unchecked(&self, i: usize, j: usize) -> &T {
64		self.data.get_unchecked(i * self.cols + j)
65	}
66
67	/// View of the array in a different shape, underlying elements stay the same.
68	pub fn reshape(&self, rows: usize, cols: usize) -> Option<Array2D<T, &[T]>> {
69		if rows * cols != self.data.len() {
70			return None;
71		}
72
73		Some(Array2D {
74			data: &self.data,
75			rows,
76			cols,
77		})
78	}
79}
80
81impl<T, Data: DerefMut<Target = [T]>> Array2D<T, Data> {
82	/// Returns the mutable row at the given index.
83	pub fn get_row_mut(&mut self, i: usize) -> &mut [T] {
84		let start = i * self.cols;
85		&mut self.data[start..start + self.cols]
86	}
87
88	/// Return the mutable element at the given row and column without bounds checking.
89	/// # Safety
90	/// The caller must ensure that `i` and `j` are less than the number of rows and columns respectively.
91	pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
92		self.data.get_unchecked_mut(i * self.cols + j)
93	}
94
95	/// Mutable view of the array in a different shape, underlying elements stay the same.
96	pub fn reshape_mut(&mut self, rows: usize, cols: usize) -> Option<Array2D<T, &mut [T]>> {
97		if rows * cols != self.data.len() {
98			return None;
99		}
100
101		Some(Array2D {
102			data: self.data.deref_mut(),
103			rows,
104			cols,
105		})
106	}
107}
108
109impl<T, Data: Deref<Target = [T]>> Index<(usize, usize)> for Array2D<T, Data> {
110	type Output = T;
111
112	fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
113		&self.data[i * self.cols + j]
114	}
115}
116
117impl<T, Data: DerefMut<Target = [T]>> IndexMut<(usize, usize)> for Array2D<T, Data> {
118	fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
119		&mut self.data[i * self.cols + j]
120	}
121}
122
123impl<T: Default + Clone + AddAssign, Data: Deref<Target = [T]>> Array2D<T, Data> {
124	/// Returns the sum of the elements in each row.
125	pub fn sum_rows(&self) -> Vec<T> {
126		let mut sum = vec![T::default(); self.cols];
127
128		for row in self.iter_rows() {
129			for (i, elem) in row.iter().enumerate() {
130				sum[i] += elem.clone();
131			}
132		}
133
134		sum
135	}
136}
137
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds the 2×3 fixture `[[1, 2, 3], [4, 5, 6]]` through `IndexMut`.
	fn sample() -> Array2D<i32> {
		let mut arr = Array2D::new(2, 3);
		let mut value = 1;
		for i in 0..2 {
			for j in 0..3 {
				arr[(i, j)] = value;
				value += 1;
			}
		}
		arr
	}

	#[test]
	fn test_get_set() {
		let mut arr = Array2D::new(2, 3);
		arr[(0, 0)] = 1;
		arr[(0, 1)] = 2;
		arr[(0, 2)] = 3;
		arr[(1, 0)] = 4;
		arr[(1, 1)] = 5;
		arr[(1, 2)] = 6;

		assert_eq!(arr[(0, 0)], 1);
		assert_eq!(arr[(0, 1)], 2);
		assert_eq!(arr[(0, 2)], 3);
		assert_eq!(arr[(1, 0)], 4);
		assert_eq!(arr[(1, 1)], 5);
		assert_eq!(arr[(1, 2)], 6);
	}

	#[test]
	fn test_dimensions() {
		let arr = sample();
		assert_eq!(arr.rows(), 2);
		assert_eq!(arr.cols(), 3);
	}

	#[test]
	fn test_unchecked_access() {
		let mut arr = Array2D::new(2, 3);
		// SAFETY: every (i, j) below satisfies i < 2 and j < 3.
		unsafe {
			*arr.get_unchecked_mut(0, 0) = 1;
			*arr.get_unchecked_mut(0, 1) = 2;
			*arr.get_unchecked_mut(0, 2) = 3;
			*arr.get_unchecked_mut(1, 0) = 4;
			*arr.get_unchecked_mut(1, 1) = 5;
			*arr.get_unchecked_mut(1, 2) = 6;
		}

		// SAFETY: every (i, j) below satisfies i < 2 and j < 3.
		unsafe {
			assert_eq!(*arr.get_unchecked(0, 0), 1);
			assert_eq!(*arr.get_unchecked(0, 1), 2);
			assert_eq!(*arr.get_unchecked(0, 2), 3);
			assert_eq!(*arr.get_unchecked(1, 0), 4);
			assert_eq!(*arr.get_unchecked(1, 1), 5);
			assert_eq!(*arr.get_unchecked(1, 2), 6);
		}
	}

	#[test]
	fn test_get_row() {
		let mut arr = sample();

		assert_eq!(arr.get_row(0), &[1, 2, 3]);
		assert_eq!(arr.get_row_mut(0), &mut [1, 2, 3]);
		assert_eq!(arr.get_row(1), &[4, 5, 6]);
		assert_eq!(arr.get_row_mut(1), &mut [4, 5, 6]);
	}

	#[test]
	fn test_iter_rows() {
		let arr = sample();
		let rows: Vec<&[i32]> = arr.iter_rows().collect();
		assert_eq!(rows, vec![&[1, 2, 3][..], &[4, 5, 6][..]]);
	}

	#[test]
	fn test_reshape() {
		let arr = sample();

		let view = arr.reshape(3, 2).expect("6 elements fit a 3x2 shape");
		assert_eq!(view.rows(), 3);
		assert_eq!(view.cols(), 2);
		assert_eq!(view.get_row(1), &[3, 4]);
		assert_eq!(view[(2, 1)], 6);

		assert!(arr.reshape(4, 2).is_none());
	}

	#[test]
	fn test_reshape_mut() {
		let mut arr = sample();

		{
			let mut view = arr.reshape_mut(3, 2).expect("6 elements fit a 3x2 shape");
			view[(2, 1)] = 60;
		}
		// Flat offset 5 is (2, 1) in the 3x2 view and (1, 2) in the 2x3 original.
		assert_eq!(arr[(1, 2)], 60);

		assert!(arr.reshape_mut(4, 2).is_none());
	}

	#[test]
	fn test_zeroes() {
		let arr = Array2D::<u32>::zeroes(2, 2);
		assert_eq!(arr.rows(), 2);
		assert!(arr.iter_rows().all(|row| row.iter().all(|&x| x == 0)));
	}

	#[test]
	fn test_sum_rows() {
		let arr = sample();
		assert_eq!(arr.sum_rows(), vec![5, 7, 9]);
	}
}