binius_utils/
strided_array.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use core::slice;
4use std::ops::{Index, IndexMut, Range};
5
6use binius_maybe_rayon::prelude::*;
7
/// Errors produced when constructing a [`StridedArray2DViewMut`].
#[derive(Debug, thiserror::Error)]
pub enum Error {
	/// The requested `width * height` does not equal the length of the provided data slice.
	#[error("dimensions do not match data size")]
	DimensionMismatch,
}
13
/// A mutable view of a 2D array in row-major order that allows for parallel processing of
/// vertical slices.
///
/// A view always spans all `height` rows of the underlying array but only the columns in
/// `cols`. Element `(i, j)` of the view lives at `data[i * data_width + cols.start + j]`.
#[derive(Debug)]
pub struct StridedArray2DViewMut<'a, T> {
	/// Backing storage of the whole underlying array, in row-major order.
	data: &'a mut [T],
	/// Number of elements per row of the underlying array; used as the row stride when
	/// computing element positions.
	data_width: usize,
	/// Number of rows.
	height: usize,
	/// Columns of the underlying array covered by this view.
	cols: Range<usize>,
}
23
24impl<'a, T> StridedArray2DViewMut<'a, T> {
25	/// Create a single-piece view of the data.
26	pub const fn without_stride(
27		data: &'a mut [T],
28		height: usize,
29		width: usize,
30	) -> Result<Self, Error> {
31		if width * height != data.len() {
32			return Err(Error::DimensionMismatch);
33		}
34		Ok(Self {
35			data,
36			data_width: width,
37			height,
38			cols: 0..width,
39		})
40	}
41
42	/// Returns a reference to the data at the given indices without bounds checking.
43	/// # Safety
44	/// The caller must ensure that `i < self.height` and `j < self.width()`.
45	pub unsafe fn get_unchecked_ref(&self, i: usize, j: usize) -> &T {
46		debug_assert!(i < self.height);
47		debug_assert!(j < self.width());
48		unsafe {
49			self.data
50				.get_unchecked(i * self.data_width + j + self.cols.start)
51		}
52	}
53
54	/// Returns a mutable reference to the data at the given indices without bounds checking.
55	/// # Safety
56	/// The caller must ensure that `i < self.height` and `j < self.width()`.
57	pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
58		debug_assert!(i < self.height);
59		debug_assert!(j < self.width());
60		unsafe {
61			self.data
62				.get_unchecked_mut(i * self.data_width + j + self.cols.start)
63		}
64	}
65
66	pub const fn height(&self) -> usize {
67		self.height
68	}
69
70	pub const fn width(&self) -> usize {
71		self.cols.end - self.cols.start
72	}
73
74	/// Iterate over the mutable references to the elements in the specified column.
75	pub fn iter_column_mut(&mut self, col: usize) -> impl Iterator<Item = &mut T> + '_ {
76		assert!(col < self.width());
77		let start = col + self.cols.start;
78		let data_ptr = self.data.as_mut_ptr();
79		(0..self.height).map(move |i|
80				// Safety: 
81				// - `data_ptr` points to the start of the data slice.
82				// - `col` is within bounds of the width.
83				// - different iterator values do not overlap.
84				unsafe { &mut *data_ptr.add(i * self.data_width + start) })
85	}
86
87	/// Returns iterator over vertical slices of the data for the given stride.
88	#[allow(dead_code)]
89	pub fn into_strides(self, stride: usize) -> impl Iterator<Item = Self> + 'a {
90		let Self {
91			data,
92			data_width,
93			height,
94			cols,
95		} = self;
96
97		cols.clone().step_by(stride).map(move |start| {
98			let end = (start + stride).min(cols.end);
99			Self {
100				// Safety: different instances of StridedArray2DViewMut created with the same data
101				// slice do not access overlapping indices.
102				data: unsafe { slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) },
103				data_width,
104				height,
105				cols: start..end,
106			}
107		})
108	}
109
110	/// Returns parallel iterator over vertical slices of the data for the given stride.
111	pub fn into_par_strides(self, stride: usize) -> impl ParallelIterator<Item = Self> + 'a
112	where
113		T: Send + Sync,
114	{
115		self.cols
116			.clone()
117			.into_par_iter()
118			.step_by(stride)
119			.map(move |start| {
120				let end = (start + stride).min(self.cols.end);
121				// We are setting the same lifetime as `self` captures.
122				Self {
123					// Safety: different instances of StridedArray2DViewMut created with the same
124					// data slice do not access overlapping indices.
125					data: unsafe {
126						slice::from_raw_parts_mut(self.data.as_ptr() as *mut T, self.data.len())
127					},
128					data_width: self.data_width,
129					height: self.height,
130					cols: start..end,
131				}
132			})
133	}
134}
135
136impl<T> Index<(usize, usize)> for StridedArray2DViewMut<'_, T> {
137	type Output = T;
138
139	fn index(&self, (i, j): (usize, usize)) -> &T {
140		assert!(i < self.height());
141		assert!(j < self.width());
142		unsafe { self.get_unchecked_ref(i, j) }
143	}
144}
145
146impl<T> IndexMut<(usize, usize)> for StridedArray2DViewMut<'_, T> {
147	fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
148		assert!(i < self.height());
149		assert!(j < self.width());
150		unsafe { self.get_unchecked_mut(i, j) }
151	}
152}
153
#[cfg(test)]
mod tests {
	use std::array;

	use super::*;

	/// Builds the array `[0, 1, ..., 11]`, used by every test as a 4x3 row-major matrix.
	fn sequential_data() -> [usize; 12] {
		array::from_fn(|i| i)
	}

	#[test]
	fn test_indexing() {
		let mut data = sequential_data();
		let mut view = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		// Row 3, column 1 of a 4x3 array sits at flat position 3 * 3 + 1 = 10.
		assert_eq!(view[(3, 1)], 10);

		// A write through the view must land at flat position 2 * 3 + 2 = 8.
		view[(2, 2)] = 88;
		assert_eq!(data[8], 88);
	}

	#[test]
	fn test_strides() {
		let mut data = sequential_data();
		let view = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		{
			// Three columns split by a stride of two: one piece of width 2, one of width 1.
			let mut pieces = view.into_strides(2);
			let mut wide = pieces.next().unwrap();
			let mut narrow = pieces.next().unwrap();
			assert!(pieces.next().is_none());

			assert_eq!(wide.width(), 2);
			assert_eq!(narrow.width(), 1);

			wide[(0, 0)] = 88;
			narrow[(1, 0)] = 99;
		}

		// (0, 0) of the first piece is flat position 0; (1, 0) of the second is 1 * 3 + 2 = 5.
		assert_eq!(data[0], 88);
		assert_eq!(data[5], 99);
	}

	#[test]
	fn test_parallel_strides() {
		let mut data = sequential_data();
		let view = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		{
			let mut pieces: Vec<_> = view.into_par_strides(2).collect();
			let widths: Vec<usize> = pieces.iter().map(|piece| piece.width()).collect();
			assert_eq!(widths, [2, 1]);

			pieces[0][(0, 0)] = 88;
			pieces[1][(1, 0)] = 99;
		}

		assert_eq!(data[0], 88);
		assert_eq!(data[5], 99);
	}

	#[test]
	fn test_iter_column_mut() {
		let mut data = sequential_data();
		let snapshot = data;
		let mut view = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		// Column 1 of a 4x3 row-major array occupies flat positions 1, 4, 7 and 10.
		let mut column = view.iter_column_mut(1);
		for flat_position in [1usize, 4, 7, 10] {
			assert_eq!(column.next().copied(), Some(snapshot[flat_position]));
		}
		assert_eq!(column.next(), None);
	}
}