1use core::slice;
4use std::ops::{Index, IndexMut, Range};
5
6use crate::rayon::prelude::*;
7
8#[derive(Debug, thiserror::Error)]
9pub enum Error {
10 #[error("dimensions do not match data size")]
11 DimensionMismatch,
12}
13
/// Mutable, possibly column-restricted view over a row-major 2D buffer.
///
/// Element `(i, j)` of the view lives at `data[i * data_width + j + cols.start]`.
#[derive(Debug)]
pub struct StridedArray2DViewMut<'a, T> {
    /// Backing buffer of `height` rows, each `data_width` elements long.
    data: &'a mut [T],
    /// Row stride of `data` (elements per full row).
    data_width: usize,
    /// Number of rows.
    height: usize,
    /// Absolute columns of `data` that belong to this view.
    cols: Range<usize>,
}
23
24impl<'a, T> StridedArray2DViewMut<'a, T> {
25 pub const fn without_stride(
27 data: &'a mut [T],
28 height: usize,
29 width: usize,
30 ) -> Result<Self, Error> {
31 if width * height != data.len() {
32 return Err(Error::DimensionMismatch);
33 }
34 Ok(Self {
35 data,
36 data_width: width,
37 height,
38 cols: 0..width,
39 })
40 }
41
42 pub unsafe fn get_unchecked_ref(&self, i: usize, j: usize) -> &T {
46 debug_assert!(i < self.height);
47 debug_assert!(j < self.width());
48 unsafe {
49 self.data
50 .get_unchecked(i * self.data_width + j + self.cols.start)
51 }
52 }
53
54 pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
58 debug_assert!(i < self.height);
59 debug_assert!(j < self.width());
60 unsafe {
61 self.data
62 .get_unchecked_mut(i * self.data_width + j + self.cols.start)
63 }
64 }
65
66 pub const fn height(&self) -> usize {
67 self.height
68 }
69
70 pub const fn width(&self) -> usize {
71 self.cols.end - self.cols.start
72 }
73
74 pub fn iter_column_mut(&mut self, col: usize) -> impl Iterator<Item = &mut T> + '_ {
76 assert!(col < self.width());
77 let start = col + self.cols.start;
78 let data_ptr = self.data.as_mut_ptr();
79 (0..self.height).map(move |i|
80 unsafe { &mut *data_ptr.add(i * self.data_width + start) })
85 }
86
87 #[allow(dead_code)]
89 pub fn into_strides(self, stride: usize) -> impl Iterator<Item = Self> + 'a {
90 let Self {
91 data,
92 data_width,
93 height,
94 cols,
95 } = self;
96
97 cols.clone().step_by(stride).map(move |start| {
98 let end = (start + stride).min(cols.end);
99 Self {
100 data: unsafe { slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) },
103 data_width,
104 height,
105 cols: start..end,
106 }
107 })
108 }
109
110 pub fn into_par_strides(self, stride: usize) -> impl ParallelIterator<Item = Self> + 'a
112 where
113 T: Send + Sync,
114 {
115 self.cols
116 .clone()
117 .into_par_iter()
118 .step_by(stride)
119 .map(move |start| {
120 let end = (start + stride).min(self.cols.end);
121 Self {
123 data: unsafe {
126 slice::from_raw_parts_mut(self.data.as_ptr() as *mut T, self.data.len())
127 },
128 data_width: self.data_width,
129 height: self.height,
130 cols: start..end,
131 }
132 })
133 }
134
135 pub fn iter_cols(&mut self) -> impl Iterator<Item = StridedArray2DColMut<'_, T>> + '_ {
137 let data_ptr = self.data.as_mut_ptr();
138 let data_len = self.data.len();
139 self.cols.clone().map(move |col| StridedArray2DColMut {
140 data: unsafe { slice::from_raw_parts_mut(data_ptr, data_len) },
143 data_width: self.data_width,
144 height: self.height,
145 col,
146 })
147 }
148
149 pub fn par_iter_cols(
151 &mut self,
152 ) -> impl IndexedParallelIterator<Item = StridedArray2DColMut<'_, T>> + '_
153 where
154 T: Send + Sync,
155 {
156 let data_ptr = SendPtr(self.data.as_mut_ptr());
157 let data_len = self.data.len();
158 let data_width = self.data_width;
159 let height = self.height;
160 self.cols.clone().into_par_iter().map(move |col| {
161 StridedArray2DColMut {
162 data: unsafe { slice::from_raw_parts_mut(data_ptr.as_ptr(), data_len) },
165 data_width,
166 height,
167 col,
168 }
169 })
170 }
171}
172
/// Mutable handle to a single column of a row-major 2D buffer.
///
/// Row `i` of the column lives at `data[i * data_width + col]`.
#[derive(Debug)]
pub struct StridedArray2DColMut<'a, T> {
    /// Backing buffer; only the elements of column `col` are addressed.
    data: &'a mut [T],
    /// Row stride of `data`.
    data_width: usize,
    /// Number of rows.
    height: usize,
    /// Absolute column index within `data`.
    col: usize,
}
181
182impl<'a, T> StridedArray2DColMut<'a, T> {
183 pub const fn height(&self) -> usize {
184 self.height
185 }
186
187 pub unsafe fn get_unchecked_ref(&self, i: usize) -> &T {
191 debug_assert!(i < self.height);
192 unsafe { self.data.get_unchecked(i * self.data_width + self.col) }
193 }
194
195 pub unsafe fn get_unchecked_mut(&mut self, i: usize) -> &mut T {
199 debug_assert!(i < self.height);
200 unsafe { self.data.get_unchecked_mut(i * self.data_width + self.col) }
201 }
202}
203
204impl<T> Index<usize> for StridedArray2DColMut<'_, T> {
205 type Output = T;
206
207 fn index(&self, i: usize) -> &T {
208 assert!(i < self.height());
209 unsafe { self.get_unchecked_ref(i) }
210 }
211}
212
213impl<T> IndexMut<usize> for StridedArray2DColMut<'_, T> {
214 fn index_mut(&mut self, i: usize) -> &mut Self::Output {
215 assert!(i < self.height());
216 unsafe { self.get_unchecked_mut(i) }
217 }
218}
219
220impl<T> Index<(usize, usize)> for StridedArray2DViewMut<'_, T> {
221 type Output = T;
222
223 fn index(&self, (i, j): (usize, usize)) -> &T {
224 assert!(i < self.height());
225 assert!(j < self.width());
226 unsafe { self.get_unchecked_ref(i, j) }
227 }
228}
229
230impl<T> IndexMut<(usize, usize)> for StridedArray2DViewMut<'_, T> {
231 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
232 assert!(i < self.height());
233 assert!(j < self.width());
234 unsafe { self.get_unchecked_mut(i, j) }
235 }
236}
237
/// Wrapper that lets a `*mut T` be moved into rayon closures.
struct SendPtr<T>(*mut T);

impl<T> SendPtr<T> {
    /// Returns the wrapped pointer.
    ///
    /// Takes `self` by value so a closure calling this method captures the
    /// whole wrapper; with edition-2021 disjoint closure captures, naming `.0`
    /// directly would capture only the bare (non-`Send`) raw pointer.
    fn as_ptr(self) -> *mut T {
        let Self(raw) = self;
        raw
    }
}
250
251impl<T> Clone for SendPtr<T> {
252 fn clone(&self) -> Self {
253 *self
254 }
255}
256
257impl<T> Copy for SendPtr<T> {}
258
259unsafe impl<T: Send> Send for SendPtr<T> {}
261unsafe impl<T: Sync> Sync for SendPtr<T> {}
262
#[cfg(test)]
mod tests {
    use std::array;

    use super::*;

    // Every fixture is the 4-row x 3-column grid
    //    0  1  2
    //    3  4  5
    //    6  7  8
    //    9 10 11

    #[test]
    fn test_indexing() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();
        assert_eq!(arr[(3, 1)], 10);
        arr[(2, 2)] = 88;
        assert_eq!(data[8], 88);
    }

    #[test]
    fn test_dimension_mismatch() {
        // 12 elements can form neither a 5x3 nor a 4x4 grid.
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        assert!(matches!(
            StridedArray2DViewMut::without_stride(&mut data, 5, 3),
            Err(Error::DimensionMismatch)
        ));
        assert!(matches!(
            StridedArray2DViewMut::without_stride(&mut data, 4, 4),
            Err(Error::DimensionMismatch)
        ));
    }

    #[test]
    fn test_zero_height() {
        let mut data: [u8; 0] = [];
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 0, 5).unwrap();
        assert_eq!(arr.height(), 0);
        assert_eq!(arr.width(), 5);
        assert_eq!(arr.iter_column_mut(4).count(), 0);
        assert_eq!(arr.iter_cols().count(), 5);
    }

    #[test]
    fn test_strides() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut strides = arr.into_strides(2);
            let mut stride0 = strides.next().unwrap();
            let mut stride1 = strides.next().unwrap();
            assert!(strides.next().is_none());

            assert_eq!(stride0.width(), 2);
            assert_eq!(stride1.width(), 1);

            stride0[(0, 0)] = 88;
            stride1[(1, 0)] = 99;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[5], 99);
    }

    #[test]
    fn test_parallel_strides() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut strides: Vec<_> = arr.into_par_strides(2).collect();
            assert_eq!(strides.len(), 2);
            assert_eq!(strides[0].width(), 2);
            assert_eq!(strides[1].width(), 1);

            strides[0][(0, 0)] = 88;
            strides[1][(1, 0)] = 99;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[5], 99);
    }

    #[test]
    fn test_iter_column_mut() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let data_clone = data;
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        let mut col_iter = arr.iter_column_mut(1);
        assert_eq!(col_iter.next().copied(), Some(data_clone[1]));
        assert_eq!(col_iter.next().copied(), Some(data_clone[4]));
        assert_eq!(col_iter.next().copied(), Some(data_clone[7]));
        assert_eq!(col_iter.next().copied(), Some(data_clone[10]));
        assert_eq!(col_iter.next(), None);
    }

    #[test]
    fn test_col_mut_indexing() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        let mut cols: Vec<_> = arr.iter_cols().collect();
        assert_eq!(cols.len(), 3);

        assert_eq!(cols[1][0], 1);
        assert_eq!(cols[1][1], 4);
        assert_eq!(cols[1][2], 7);
        assert_eq!(cols[1][3], 10);

        cols[0][2] = 88;
        cols[2][1] = 99;

        assert_eq!(data[6], 88);
        assert_eq!(data[5], 99);
    }

    #[test]
    fn test_iter_cols() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut cols = arr.iter_cols();
            let mut col0 = cols.next().unwrap();
            let mut col1 = cols.next().unwrap();
            let mut col2 = cols.next().unwrap();
            assert!(cols.next().is_none());

            assert_eq!(col0.height(), 4);
            assert_eq!(col1.height(), 4);
            assert_eq!(col2.height(), 4);

            col0[0] = 88;
            col1[1] = 99;
            col2[3] = 77;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[4], 99);
        assert_eq!(data[11], 77);
    }

    #[test]
    fn test_par_iter_cols() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut cols: Vec<_> = arr.par_iter_cols().collect();
            assert_eq!(cols.len(), 3);

            cols[0][0] = 88;
            cols[1][1] = 99;
            cols[2][3] = 77;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[4], 99);
        assert_eq!(data[11], 77);
    }
}