binius_ntt/
strided_array.rs1use core::slice;
4use std::ops::{Index, IndexMut, Range};
5
6use binius_maybe_rayon::prelude::*;
7
8#[derive(Debug, thiserror::Error)]
9pub enum Error {
10 #[error("dimensions do not match data size")]
11 DimensionMismatch,
12}
13
14#[derive(Debug)]
17pub struct StridedArray2DViewMut<'a, T> {
18 data: &'a mut [T],
19 data_width: usize,
20 height: usize,
21 cols: Range<usize>,
22}
23
24impl<'a, T> StridedArray2DViewMut<'a, T> {
25 pub const fn without_stride(
27 data: &'a mut [T],
28 height: usize,
29 width: usize,
30 ) -> Result<Self, Error> {
31 if width * height != data.len() {
32 return Err(Error::DimensionMismatch);
33 }
34 Ok(Self {
35 data,
36 data_width: width,
37 height,
38 cols: 0..width,
39 })
40 }
41
42 pub unsafe fn get_unchecked_ref(&self, i: usize, j: usize) -> &T {
46 debug_assert!(i < self.height);
47 debug_assert!(j < self.width());
48 self.data
49 .get_unchecked(i * self.data_width + j + self.cols.start)
50 }
51
52 pub unsafe fn get_unchecked_mut(&mut self, i: usize, j: usize) -> &mut T {
56 debug_assert!(i < self.height);
57 debug_assert!(j < self.width());
58 self.data
59 .get_unchecked_mut(i * self.data_width + j + self.cols.start)
60 }
61
62 pub const fn height(&self) -> usize {
63 self.height
64 }
65
66 pub const fn width(&self) -> usize {
67 self.cols.end - self.cols.start
68 }
69
70 #[allow(dead_code)]
72 pub fn into_strides(self, stride: usize) -> impl Iterator<Item = Self> + 'a {
73 let Self {
74 data,
75 data_width,
76 height,
77 cols,
78 } = self;
79
80 cols.clone().step_by(stride).map(move |start| {
81 let end = (start + stride).min(cols.end);
82 Self {
83 data: unsafe { slice::from_raw_parts_mut(data.as_mut_ptr(), data.len()) },
86 data_width,
87 height,
88 cols: start..end,
89 }
90 })
91 }
92
93 pub fn into_par_strides(self, stride: usize) -> impl ParallelIterator<Item = Self> + 'a
95 where
96 T: Send + Sync,
97 {
98 self.cols
99 .clone()
100 .into_par_iter()
101 .step_by(stride)
102 .map(move |start| {
103 let end = (start + stride).min(self.cols.end);
104 Self {
106 data: unsafe {
109 slice::from_raw_parts_mut(self.data.as_ptr() as *mut T, self.data.len())
110 },
111 data_width: self.data_width,
112 height: self.height,
113 cols: start..end,
114 }
115 })
116 }
117}
118
119impl<T> Index<(usize, usize)> for StridedArray2DViewMut<'_, T> {
120 type Output = T;
121
122 fn index(&self, (i, j): (usize, usize)) -> &T {
123 assert!(i < self.height());
124 assert!(j < self.width());
125 unsafe { self.get_unchecked_ref(i, j) }
126 }
127}
128
129impl<T> IndexMut<(usize, usize)> for StridedArray2DViewMut<'_, T> {
130 fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
131 assert!(i < self.height());
132 assert!(j < self.width());
133 unsafe { self.get_unchecked_mut(i, j) }
134 }
135}
136
#[cfg(test)]
mod tests {
    use std::array;

    use super::*;

    #[test]
    fn test_indexing() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let mut arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();
        assert_eq!(arr[(3, 1)], 10);
        arr[(2, 2)] = 88;
        assert_eq!(data[8], 88);
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        // Too many rows for the slice.
        assert!(matches!(
            StridedArray2DViewMut::without_stride(&mut data, 5, 3),
            Err(Error::DimensionMismatch)
        ));
        // Too few elements requested.
        assert!(StridedArray2DViewMut::without_stride(&mut data, 3, 3).is_err());
    }

    #[test]
    fn test_strides() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut strides = arr.into_strides(2);
            let mut stride0 = strides.next().unwrap();
            let mut stride1 = strides.next().unwrap();
            assert!(strides.next().is_none());

            assert_eq!(stride0.width(), 2);
            assert_eq!(stride1.width(), 1);

            stride0[(0, 0)] = 88;
            stride1[(1, 0)] = 99;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[5], 99);
    }

    #[test]
    fn test_stride_wider_than_view() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        let mut strides = arr.into_strides(5);
        let only = strides.next().unwrap();
        assert!(strides.next().is_none());
        assert_eq!(only.width(), 3);
        assert_eq!(only.height(), 4);
        assert_eq!(only[(3, 2)], 11);
    }

    #[test]
    #[should_panic]
    fn test_zero_stride_panics() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();
        let _ = arr.into_strides(0);
    }

    #[test]
    fn test_parallel_strides() {
        let mut data = array::from_fn::<_, 12, _>(|i| i);
        let arr = StridedArray2DViewMut::without_stride(&mut data, 4, 3).unwrap();

        {
            let mut strides: Vec<_> = arr.into_par_strides(2).collect();
            assert_eq!(strides.len(), 2);
            assert_eq!(strides[0].width(), 2);
            assert_eq!(strides[1].width(), 1);

            strides[0][(0, 0)] = 88;
            strides[1][(1, 0)] = 99;
        }

        assert_eq!(data[0], 88);
        assert_eq!(data[5], 99);
    }
}