binius_utils/
strided_array.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use core::slice;
4use std::ops::{Index, IndexMut, Range};
5
6use crate::rayon::prelude::*;
7
8#[derive(Debug, thiserror::Error)]
9pub enum Error {
10	#[error("dimensions do not match data size")]
11	DimensionMismatch,
12}
13
/// A mutable view of a 2D array in row-major order that allows for parallel processing of
/// vertical slices (groups of adjacent columns).
///
/// A view always spans every row of the backing array but only the columns in its column
/// range. Views obtained by splitting another view share the same backing slice and differ only
/// in which columns they expose.
#[derive(Debug)]
pub struct StridedArray2DViewMut<'a, T> {
	// NOTE(review): views split from one another each hold a `&mut [T]` over the *whole*
	// backing slice and rely on touching only their own columns. Overlapping `&mut` references
	// are likely rejected by Miri's Stacked Borrows model even when used disjointly; storing a
	// raw pointer plus `PhantomData<&'a mut [T]>` would sidestep that. TODO confirm with
	// `cargo miri test`.
	/// Backing storage: `height * data_width` elements in row-major order.
	data: &'a mut [T],
	/// Number of elements in one row of the backing storage (the row stride).
	data_width: usize,
	/// Number of rows.
	height: usize,
	/// Half-open range of backing-storage columns visible through this view; always a
	/// sub-range of `0..data_width`.
	cols: Range<usize>,
}
23
24impl<'a, T> StridedArray2DViewMut<'a, T> {
25	/// Create a single-piece view of the data.
26	pub const fn without_stride(
27		data: &'a mut [T],
28		height: usize,
29		width: usize,
30	) -> Result<Self, Error> {
31		if width * height != data.len() {
32			return Err(Error::DimensionMismatch);
33		}
34		Ok(Self {
35			data,
36			data_width: width,
37			height,
38			cols: 0..width,
39		})
40	}
41
42	/// Returns a reference to the data at the given indices without bounds checking.
43	/// # Safety
44	/// The caller must ensure that `i < self.height` and `j < self.width()`.
45	pub unsafe fn get_unchecked_ref(&self, i: usize, j: usize) -> &T {
46		debug_assert!(i < self.height);
47		debug_assert!(j < self.width());
48		unsafe {
49			self.data
50				.get_unchecked(i * self.data_width + j + self.cols.start)
51		}
52	}
53
54	/// Returns a mutable reference to the data at the given indices without bounds checking.
55	/// # Safety
56	/// The caller must ensure that `i < self.height` and `j < self.width()`.
57	pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
58		debug_assert!(i < self.height);
59		debug_assert!(j < self.width());
60		unsafe {
61			self.data
62				.get_unchecked_mut(i * self.data_width + j + self.cols.start)
63		}
64	}
65
66	pub const fn height(&self) -> usize {
67		self.height
68	}
69
70	pub const fn width(&self) -> usize {
71		self.cols.end - self.cols.start
72	}
73
74	/// Iterate over the mutable references to the elements in the specified column.
75	pub fn iter_column_mut(&mut self, col: usize) -> impl Iterator<Item = &mut T> + '_ {
76		assert!(col < self.width());
77		let start = col + self.cols.start;
78		let data_ptr = self.data.as_mut_ptr();
79		(0..self.height).map(move |i|
80				// Safety:
81				// - `data_ptr` points to the start of the data slice.
82				// - `col` is within bounds of the width.
83				// - different iterator values do not overlap.
84				unsafe { &mut *data_ptr.add(i * self.data_width + start) })
85	}
86
87	/// Returns iterator over vertical slices of the data for the given stride.
88	#[allow(dead_code)]
89	pub fn into_strides(self, stride: usize) -> impl Iterator<Item = Self> + 'a {
90		let Self {
91			data,
92			data_width,
93			height,
94			cols,
95		} = self;
96
97		cols.clone().step_by(stride).map(move |start| {
98			let end = (start + stride).min(cols.end);
99			Self {
100				// Safety: different instances of StridedArray2DViewMut created with the same data
101				// slice do not access overlapping indices.
102				data: unsafe { slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) },
103				data_width,
104				height,
105				cols: start..end,
106			}
107		})
108	}
109
110	/// Returns parallel iterator over vertical slices of the data for the given stride.
111	pub fn into_par_strides(self, stride: usize) -> impl ParallelIterator<Item = Self> + 'a
112	where
113		T: Send + Sync,
114	{
115		self.cols
116			.clone()
117			.into_par_iter()
118			.step_by(stride)
119			.map(move |start| {
120				let end = (start + stride).min(self.cols.end);
121				// We are setting the same lifetime as `self` captures.
122				Self {
123					// Safety: different instances of StridedArray2DViewMut created with the same
124					// data slice do not access overlapping indices.
125					data: unsafe {
126						slice::from_raw_parts_mut(self.data.as_ptr() as *mut T, self.data.len())
127					},
128					data_width: self.data_width,
129					height: self.height,
130					cols: start..end,
131				}
132			})
133	}
134
135	/// Returns iterator over single-column mutable views of the data.
136	pub fn iter_cols(&mut self) -> impl Iterator<Item = StridedArray2DColMut<'_, T>> + '_ {
137		let data_ptr = self.data.as_mut_ptr();
138		let data_len = self.data.len();
139		self.cols.clone().map(move |col| StridedArray2DColMut {
140			// Safety: different instances of StridedArray2DColMut created with the same data
141			// slice do not access overlapping indices since each accesses a different column.
142			data: unsafe { slice::from_raw_parts_mut(data_ptr, data_len) },
143			data_width: self.data_width,
144			height: self.height,
145			col,
146		})
147	}
148
149	/// Returns parallel iterator over single-column mutable views of the data.
150	pub fn par_iter_cols(
151		&mut self,
152	) -> impl IndexedParallelIterator<Item = StridedArray2DColMut<'_, T>> + '_
153	where
154		T: Send + Sync,
155	{
156		let data_ptr = SendPtr(self.data.as_mut_ptr());
157		let data_len = self.data.len();
158		let data_width = self.data_width;
159		let height = self.height;
160		self.cols.clone().into_par_iter().map(move |col| {
161			StridedArray2DColMut {
162				// Safety: different instances of StridedArray2DColMut created with the same data
163				// slice do not access overlapping indices since each accesses a different column.
164				data: unsafe { slice::from_raw_parts_mut(data_ptr.as_ptr(), data_len) },
165				data_width,
166				height,
167				col,
168			}
169		})
170	}
171}
172
/// Mutable access to one column of a 2D array stored in row-major order.
///
/// The column is identified by its absolute index within a row of the backing storage, so the
/// element in row `i` lives at offset `i * data_width + col`.
#[derive(Debug)]
pub struct StridedArray2DColMut<'a, T> {
	/// Backing storage shared with the view this column was taken from.
	data: &'a mut [T],
	/// Elements per row of the backing storage.
	data_width: usize,
	/// Number of rows, i.e. the length of the column.
	height: usize,
	/// Absolute column index within a row.
	col: usize,
}
181
182impl<'a, T> StridedArray2DColMut<'a, T> {
183	pub const fn height(&self) -> usize {
184		self.height
185	}
186
187	/// Returns a reference to the data at the given row index without bounds checking.
188	/// # Safety
189	/// The caller must ensure that `i < self.height`.
190	pub unsafe fn get_unchecked_ref(&self, i: usize) -> &T {
191		debug_assert!(i < self.height);
192		unsafe { self.data.get_unchecked(i * self.data_width + self.col) }
193	}
194
195	/// Returns a mutable reference to the data at the given row index without bounds checking.
196	/// # Safety
197	/// The caller must ensure that `i < self.height`.
198	pub unsafe fn get_unchecked_mut(&mut self, i: usize) -> &mut T {
199		debug_assert!(i < self.height);
200		unsafe { self.data.get_unchecked_mut(i * self.data_width + self.col) }
201	}
202}
203
204impl<T> Index<usize> for StridedArray2DColMut<'_, T> {
205	type Output = T;
206
207	fn index(&self, i: usize) -> &T {
208		assert!(i < self.height());
209		unsafe { self.get_unchecked_ref(i) }
210	}
211}
212
213impl<T> IndexMut<usize> for StridedArray2DColMut<'_, T> {
214	fn index_mut(&mut self, i: usize) -> &mut Self::Output {
215		assert!(i < self.height());
216		unsafe { self.get_unchecked_mut(i) }
217	}
218}
219
220impl<T> Index<(usize, usize)> for StridedArray2DViewMut<'_, T> {
221	type Output = T;
222
223	fn index(&self, (i, j): (usize, usize)) -> &T {
224		assert!(i < self.height());
225		assert!(j < self.width());
226		unsafe { self.get_unchecked_ref(i, j) }
227	}
228}
229
230impl<T> IndexMut<(usize, usize)> for StridedArray2DViewMut<'_, T> {
231	fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
232		assert!(i < self.height());
233		assert!(j < self.width());
234		unsafe { self.get_unchecked_mut(i, j) }
235	}
236}
237
/// Newtype around a raw pointer so it can be moved into, and shared between, rayon worker
/// closures.
///
/// Raw pointers are neither `Send` nor `Sync`; this wrapper opts back in through `unsafe impl`s
/// so that column views can be constructed inside parallel iterators.
///
/// # Safety
/// Whoever dereferences the wrapped pointer must ensure it is valid, and that accesses made
/// through different copies of the wrapper never race with one another.
struct SendPtr<T>(*mut T);

impl<T> SendPtr<T> {
	/// Unwraps the raw pointer.
	///
	/// Taking `self` by value through a method (rather than writing `ptr.0`) makes a closure
	/// capture the whole `SendPtr`; with edition-2021 disjoint captures, a `.0` field access
	/// inside a closure would capture only the bare, non-`Send` `*mut T`.
	fn as_ptr(self) -> *mut T {
		let Self(raw) = self;
		raw
	}
}
249}
250
251impl<T> Clone for SendPtr<T> {
252	fn clone(&self) -> Self {
253		*self
254	}
255}
256
257impl<T> Copy for SendPtr<T> {}
258
259// Safety: SendPtr is only used internally where we ensure non-overlapping access
260unsafe impl<T: Send> Send for SendPtr<T> {}
261unsafe impl<T: Sync> Sync for SendPtr<T> {}
262
#[cfg(test)]
mod tests {
	use std::array;

	use super::*;

	#[test]
	fn test_indexing() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();
		assert_eq!(arr[(3, 1)], 10);
		arr[(2, 2)] = 88;
		assert_eq!(data[8], 88);
	}

	/// A shape that does not multiply out to the slice length must be rejected.
	#[test]
	fn test_dimension_mismatch() {
		let mut data = [0u32; 12];
		assert!(matches!(
			StridedArray2DViewMut::without_stride(&mut data, 5, 3),
			Err(Error::DimensionMismatch)
		));
		assert!(matches!(
			StridedArray2DViewMut::without_stride(&mut data, 4, 2),
			Err(Error::DimensionMismatch)
		));
		assert!(StridedArray2DViewMut::without_stride(&mut data, 4, 3).is_ok());
	}

	#[test]
	fn test_strides() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		{
			let mut strides = arr.into_strides(2);
			let mut stride0 = strides.next().unwrap();
			let mut stride1 = strides.next().unwrap();
			assert!(strides.next().is_none());

			assert_eq!(stride0.width(), 2);
			assert_eq!(stride1.width(), 1);

			stride0[(0, 0)] = 88;
			stride1[(1, 0)] = 99;
		}

		assert_eq!(data[0], 88);
		assert_eq!(data[5], 99);
	}

	/// Splitting a view whose column range does not start at 0 must stay inside that range.
	#[test]
	fn test_nested_strides() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		{
			let mut outer = arr.into_strides(2);
			let _left = outer.next().unwrap();
			let right = outer.next().unwrap();

			let mut inner: Vec<_> = right.into_strides(5).collect();
			assert_eq!(inner.len(), 1);
			assert_eq!(inner[0].width(), 1);
			inner[0][(3, 0)] = 42;
		}

		assert_eq!(data[11], 42); // row 3, col 2
	}

	#[test]
	fn test_parallel_strides() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		{
			let mut strides: Vec<_> = arr.into_par_strides(2).collect();
			assert_eq!(strides.len(), 2);
			assert_eq!(strides[0].width(), 2);
			assert_eq!(strides[1].width(), 1);

			strides[0][(0, 0)] = 88;
			strides[1][(1, 0)] = 99;
		}

		assert_eq!(data[0], 88);
		assert_eq!(data[5], 99);
	}

	#[test]
	fn test_iter_column_mut() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let data_clone = data;
		let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		let mut col_iter = arr.iter_column_mut(1);
		assert_eq!(col_iter.next().copied(), Some(data_clone[1]));
		assert_eq!(col_iter.next().copied(), Some(data_clone[4]));
		assert_eq!(col_iter.next().copied(), Some(data_clone[7]));
		assert_eq!(col_iter.next().copied(), Some(data_clone[10]));
		assert_eq!(col_iter.next(), None);
	}

	/// Writes through `iter_column_mut` must land only in the selected column.
	#[test]
	fn test_iter_column_mut_writes() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		for value in arr.iter_column_mut(2) {
			*value = 0;
		}

		assert_eq!(data, [0, 1, 0, 3, 4, 0, 6, 7, 0, 9, 10, 0]);
	}

	#[test]
	fn test_col_mut_indexing() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		let mut cols: Vec<_> = arr.iter_cols().collect();
		assert_eq!(cols.len(), 3);

		// Test reading - column 1 contains elements at indices 1, 4, 7, 10
		assert_eq!(cols[1][0], 1);
		assert_eq!(cols[1][1], 4);
		assert_eq!(cols[1][2], 7);
		assert_eq!(cols[1][3], 10);

		// Test writing
		cols[0][2] = 88;
		cols[2][1] = 99;

		assert_eq!(data[6], 88); // row 2, col 0
		assert_eq!(data[5], 99); // row 1, col 2
	}

	#[test]
	fn test_iter_cols() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		{
			let mut cols = arr.iter_cols();
			let mut col0 = cols.next().unwrap();
			let mut col1 = cols.next().unwrap();
			let mut col2 = cols.next().unwrap();
			assert!(cols.next().is_none());

			assert_eq!(col0.height(), 4);
			assert_eq!(col1.height(), 4);
			assert_eq!(col2.height(), 4);

			col0[0] = 88;
			col1[1] = 99;
			col2[3] = 77;
		}

		assert_eq!(data[0], 88); // row 0, col 0
		assert_eq!(data[4], 99); // row 1, col 1
		assert_eq!(data[11], 77); // row 3, col 2
	}

	#[test]
	fn test_par_iter_cols() {
		let mut data = array::from_fn::<_, 12, _>(|i| i);
		let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

		{
			let mut cols: Vec<_> = arr.par_iter_cols().collect();
			assert_eq!(cols.len(), 3);

			cols[0][0] = 88;
			cols[1][1] = 99;
			cols[2][3] = 77;
		}

		assert_eq!(data[0], 88); // row 0, col 0
		assert_eq!(data[4], 99); // row 1, col 1
		assert_eq!(data[11], 77); // row 3, col 2
	}
}
399}