1use core::slice;
4use std::ops::{Index, IndexMut, Range};
5
6use binius_maybe_rayon::prelude::*;
7
8#[derive(Debug, thiserror::Error)]
9pub enum Error {
10 #[error("dimensions do not match data size")]
11 DimensionMismatch,
12}
13
/// A mutable view of a contiguous range of columns of a row-major matrix stored in a flat
/// slice.
///
/// Element `(i, j)` of the view (row `i`, view-relative column `j`) lives at
/// `data[i * data_width + j + cols.start]`.
///
/// NOTE(review): `into_strides` / `into_par_strides` produce several views whose `data` fields
/// are `&mut` references to the *same* slice. Soundness relies on each view touching only the
/// columns in its own `cols`. Overlapping `&mut` references are rejected by Miri's Stacked
/// Borrows model, so storing a raw pointer plus `PhantomData<&'a mut [T]>` may be worth
/// considering.
#[derive(Debug)]
pub struct StridedArray2DViewMut<'a, T> {
    /// Backing storage for the whole underlying matrix (all rows, all `data_width` columns).
    data: &'a mut [T],
    /// Row length of the underlying matrix, i.e. the distance between vertically adjacent
    /// elements in `data`.
    data_width: usize,
    /// Number of rows in the view (and in the underlying matrix).
    height: usize,
    /// Half-open range of underlying columns visible through this view.
    cols: Range<usize>,
}
23
24impl<'a, T> StridedArray2DViewMut<'a, T> {
25 pub const fn without_stride(
27 data: &'a mut [T],
28 height: usize,
29 width: usize,
30 ) -> Result<Self, Error> {
31 if width * height != data.len() {
32 return Err(Error::DimensionMismatch);
33 }
34 Ok(Self {
35 data,
36 data_width: width,
37 height,
38 cols: 0..width,
39 })
40 }
41
42 pub unsafe fn get_unchecked_ref(&self, i: usize, j: usize) -> &T {
46 debug_assert!(i < self.height);
47 debug_assert!(j < self.width());
48 unsafe {
49 self.data
50 .get_unchecked(i * self.data_width + j + self.cols.start)
51 }
52 }
53
54 pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
58 debug_assert!(i < self.height);
59 debug_assert!(j < self.width());
60 unsafe {
61 self.data
62 .get_unchecked_mut(i * self.data_width + j + self.cols.start)
63 }
64 }
65
66 pub const fn height(&self) -> usize {
67 self.height
68 }
69
70 pub const fn width(&self) -> usize {
71 self.cols.end - self.cols.start
72 }
73
74 pub fn iter_column_mut(&mut self, col: usize) -> impl Iterator<Item = &mut T> + '_ {
76 assert!(col < self.width());
77 let start = col + self.cols.start;
78 let data_ptr = self.data.as_mut_ptr();
79 (0..self.height).map(move |i|
80 unsafe { &mut *data_ptr.add(i * self.data_width + start) })
85 }
86
87 #[allow(dead_code)]
89 pub fn into_strides(self, stride: usize) -> impl Iterator<Item = Self> + 'a {
90 let Self {
91 data,
92 data_width,
93 height,
94 cols,
95 } = self;
96
97 cols.clone().step_by(stride).map(move |start| {
98 let end = (start + stride).min(cols.end);
99 Self {
100 data: unsafe { slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) },
103 data_width,
104 height,
105 cols: start..end,
106 }
107 })
108 }
109
110 pub fn into_par_strides(self, stride: usize) -> impl ParallelIterator<Item = Self> + 'a
112 where
113 T: Send + Sync,
114 {
115 self.cols
116 .clone()
117 .into_par_iter()
118 .step_by(stride)
119 .map(move |start| {
120 let end = (start + stride).min(self.cols.end);
121 Self {
123 data: unsafe {
126 slice::from_raw_parts_mut(self.data.as_ptr() as *mut T, self.data.len())
127 },
128 data_width: self.data_width,
129 height: self.height,
130 cols: start..end,
131 }
132 })
133 }
134}
135
136impl<T> Index<(usize, usize)> for StridedArray2DViewMut<'_, T> {
137 type Output = T;
138
139 fn index(&self, (i, j): (usize, usize)) -> &T {
140 assert!(i < self.height());
141 assert!(j < self.width());
142 unsafe { self.get_unchecked_ref(i, j) }
143 }
144}
145
146impl<T> IndexMut<(usize, usize)> for StridedArray2DViewMut<'_, T> {
147 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
148 assert!(i < self.height());
149 assert!(j < self.width());
150 unsafe { self.get_unchecked_mut(i, j) }
151 }
152}
153
#[cfg(test)]
mod tests {
    use std::array;

    use super::*;

    #[test]
    fn test_indexing() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();
        assert_eq!(arr[(3, 1)], 10);
        arr[(2, 2)] = 88;
        assert_eq!(data[8], 88);
    }

    /// The constructor must reject dimensions whose product differs from the slice length.
    #[test]
    fn test_dimension_mismatch() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        assert!(matches!(
            StridedArray2DViewMut::without_stride(&mut data, 5, 3),
            Err(Error::DimensionMismatch)
        ));
        assert!(StridedArray2DViewMut::without_stride(&mut data, 3, 3).is_err());
    }

    #[test]
    fn test_strides() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut strides = arr.into_strides(2);
            let mut stride0 = strides.next().unwrap();
            let mut stride1 = strides.next().unwrap();
            assert!(strides.next().is_none());

            assert_eq!(stride0.width(), 2);
            assert_eq!(stride1.width(), 1);

            stride0[(0, 0)] = 88;
            stride1[(1, 0)] = 99;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[5], 99);
    }

    #[test]
    fn test_parallel_strides() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut strides: Vec<_> = arr.into_par_strides(2).collect();
            assert_eq!(strides.len(), 2);
            assert_eq!(strides[0].width(), 2);
            assert_eq!(strides[1].width(), 1);

            strides[0][(0, 0)] = 88;
            strides[1][(1, 0)] = 99;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[5], 99);
    }

    #[test]
    fn test_iter_column_mut() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let data_clone = data;
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        let mut col_iter = arr.iter_column_mut(1);
        assert_eq!(col_iter.next().copied(), Some(data_clone[1]));
        assert_eq!(col_iter.next().copied(), Some(data_clone[4]));
        assert_eq!(col_iter.next().copied(), Some(data_clone[7]));
        assert_eq!(col_iter.next().copied(), Some(data_clone[10]));
        assert_eq!(col_iter.next(), None);
    }

    /// `iter_column_mut` on a band whose columns start past 0 must honour the column offset
    /// and write through to the backing storage.
    #[test]
    fn test_iter_column_mut_on_stride() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut last = arr.into_strides(2).nth(1).unwrap();
            let column: Vec<usize> = last.iter_column_mut(0).map(|x| *x).collect();
            assert_eq!(column, [2, 5, 8, 11]);
            for x in last.iter_column_mut(0) {
                *x += 100;
            }
        }

        assert_eq!(data, [0, 1, 102, 3, 4, 105, 6, 7, 108, 9, 10, 111]);
    }
}