binius_ntt/
strided_array.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use core::slice;
4use std::ops::{Index, IndexMut, Range};
5
6use binius_maybe_rayon::prelude::*;
7
8#[derive(Debug, thiserror::Error)]
9pub enum Error {
10	#[error("dimensions do not match data size")]
11	DimensionMismatch,
12}
13
/// Mutable window onto a row-major 2D array, restricted to a contiguous range of columns.
///
/// A view can be cut into vertical slices (narrower column ranges over the same rows) so that
/// the slices can be processed independently, including in parallel.
#[derive(Debug)]
pub struct StridedArray2DViewMut<'a, T> {
	/// Backing storage of the whole array, laid out row by row.
	data: &'a mut [T],
	/// Number of elements per row of the backing array, i.e. the flat distance between two
	/// vertically adjacent cells.
	data_width: usize,
	/// Number of rows.
	height: usize,
	/// Columns of the backing array that this view exposes.
	cols: Range<usize>,
}
23
24impl<'a, T> StridedArray2DViewMut<'a, T> {
25	/// Create a single-piece view of the data.
26	pub const fn without_stride(
27		data: &'a mut [T],
28		height: usize,
29		width: usize,
30	) -> Result<Self, Error> {
31		if width * height != data.len() {
32			return Err(Error::DimensionMismatch);
33		}
34		Ok(Self {
35			data,
36			data_width: width,
37			height,
38			cols: 0..width,
39		})
40	}
41
42	/// Returns a reference to the data at the given indices without bounds checking.
43	/// # Safety
44	/// The caller must ensure that `i < self.height` and `j < self.width()`.
45	pub unsafe fn get_unchecked_ref(&self, i: usize, j: usize) -> &T {
46		debug_assert!(i < self.height);
47		debug_assert!(j < self.width());
48		self.data
49			.get_unchecked(i * self.data_width + j + self.cols.start)
50	}
51
52	/// Returns a mutable reference to the data at the given indices without bounds checking.
53	/// # Safety
54	/// The caller must ensure that `i < self.height` and `j < self.width()`.
55	pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
56		debug_assert!(i < self.height);
57		debug_assert!(j < self.width());
58		self.data
59			.get_unchecked_mut(i * self.data_width + j + self.cols.start)
60	}
61
62	pub const fn height(&self) -> usize {
63		self.height
64	}
65
66	pub const fn width(&self) -> usize {
67		self.cols.end - self.cols.start
68	}
69
70	/// Returns iterator over vertical slices of the data for the given stride.
71	#[allow(dead_code)]
72	pub fn into_strides(self, stride: usize) -> impl Iterator<Item = Self> + 'a {
73		let Self {
74			data,
75			data_width,
76			height,
77			cols,
78		} = self;
79
80		cols.clone().step_by(stride).map(move |start| {
81			let end = (start + stride).min(cols.end);
82			Self {
83				// Safety: different instances of StridedArray2DViewMut created with the same data slice
84				// do not access overlapping indices.
85				data: unsafe { slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) },
86				data_width,
87				height,
88				cols: start..end,
89			}
90		})
91	}
92
93	/// Returns parallel iterator over vertical slices of the data for the given stride.
94	pub fn into_par_strides(self, stride: usize) -> impl ParallelIterator<Item = Self> + 'a
95	where
96		T: Send + Sync,
97	{
98		self.cols
99			.clone()
100			.into_par_iter()
101			.step_by(stride)
102			.map(move |start| {
103				let end = (start + stride).min(self.cols.end);
104				// We are setting the same lifetime as `self` captures.
105				Self {
106					// Safety: different instances of StridedArray2DViewMut created with the same data slice
107					// do not access overlapping indices.
108					data: unsafe {
109						slice::from_raw_parts_mut(self.data.as_ptr() as *mut T, self.data.len())
110					},
111					data_width: self.data_width,
112					height: self.height,
113					cols: start..end,
114				}
115			})
116	}
117}
118
119impl<T> Index<(usize, usize)> for StridedArray2DViewMut<'_, T> {
120	type Output = T;
121
122	fn index(&self, (i, j): (usize, usize)) -> &T {
123		assert!(i < self.height());
124		assert!(j < self.width());
125		unsafe { self.get_unchecked_ref(i, j) }
126	}
127}
128
129impl<T> IndexMut<(usize, usize)> for StridedArray2DViewMut<'_, T> {
130	fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
131		assert!(i < self.height());
132		assert!(j < self.width());
133		unsafe { self.get_unchecked_mut(i, j) }
134	}
135}
136
#[cfg(test)]
mod tests {
	use std::array;

	use super::*;

	/// Twelve cells that each hold their own flat index; the tests view them as 4 rows of 3.
	fn flat_indices() -> [usize; 12] {
		array::from_fn(|idx| idx)
	}

	#[test]
	fn test_indexing() {
		let mut cells = flat_indices();
		let mut view = StridedArray2DViewMut::without_stride(&mut cells, 4, 3).unwrap();

		// Row 3, column 1 sits at flat offset 3 * 3 + 1.
		assert_eq!(view[(3, 1)], 10);

		// A write through the view must land in the backing array.
		view[(2, 2)] = 88;
		assert_eq!(cells[8], 88);
	}

	#[test]
	fn test_strides() {
		let mut cells = flat_indices();
		let view = StridedArray2DViewMut::without_stride(&mut cells, 4, 3).unwrap();

		// Scope the slices so the borrow of `cells` ends before it is inspected below.
		{
			let mut pieces = view.into_strides(2);
			let (mut left, mut right) = (pieces.next().unwrap(), pieces.next().unwrap());
			assert!(pieces.next().is_none());

			// Three columns cut by a stride of two: one full slice and a one-column tail.
			assert_eq!(left.width(), 2);
			assert_eq!(right.width(), 1);

			left[(0, 0)] = 88;
			right[(1, 0)] = 99;
		}

		assert_eq!(cells[0], 88);
		assert_eq!(cells[5], 99);
	}

	#[test]
	fn test_parallel_strides() {
		let mut cells = flat_indices();
		let view = StridedArray2DViewMut::without_stride(&mut cells, 4, 3).unwrap();

		// Scope the slices so the borrow of `cells` ends before it is inspected below.
		{
			let mut pieces = view.into_par_strides(2).collect::<Vec<_>>();
			assert_eq!(pieces.len(), 2);
			assert_eq!(pieces[0].width(), 2);
			assert_eq!(pieces[1].width(), 1);

			pieces[0][(0, 0)] = 88;
			pieces[1][(1, 0)] = 99;
		}

		assert_eq!(cells[0], 88);
		assert_eq!(cells[5], 99);
	}
}