binius_utils/
random_access_sequence.rs

1// Copyright 2025 Irreducible Inc.
2
/// Read-only, by-value indexed access to a collection of `T`.
///
/// The trait abstracts over different kinds of collections: scalar slices, slices of packed
/// field elements, and subranges of other collections.
pub trait RandomAccessSequence<T: Copy> {
	/// Returns the number of elements in the sequence.
	fn len(&self) -> usize;

	/// Returns `true` when the sequence contains no elements.
	#[inline(always)]
	fn is_empty(&self) -> bool {
		matches!(self.len(), 0)
	}

	/// Returns a copy of the element at `index`.
	///
	/// # Panics
	/// The default implementation panics with "Index out of bounds" when
	/// `index >= self.len()`.
	#[inline(always)]
	fn get(&self, index: usize) -> T {
		let in_bounds = index < self.len();
		assert!(in_bounds, "Index out of bounds");
		// SAFETY: the assertion above guarantees `index < self.len()`.
		unsafe { self.get_unchecked(index) }
	}

	/// Returns a copy of the element at `index` without checking the bounds.
	///
	/// # Safety
	/// The caller must ensure that `index < self.len()`.
	unsafe fn get_unchecked(&self, index: usize) -> T;
}
26
27/// A trait for a mutable access to a collection of scalars.
28pub trait RandomAccessSequenceMut<T: Copy>: RandomAccessSequence<T> {
29	#[inline(always)]
30	fn set(&mut self, index: usize, value: T) {
31		assert!(index < self.len(), "Index out of bounds");
32		unsafe { self.set_unchecked(index, value) }
33	}
34
35	/// Sets the element at the given index to the given value.
36	///
37	/// # Safety
38	/// The caller must ensure that the `index` < `self.len()`.
39	unsafe fn set_unchecked(&mut self, index: usize, value: T);
40}
41
42impl<T: Copy> RandomAccessSequence<T> for &[T] {
43	#[inline(always)]
44	fn len(&self) -> usize {
45		<[T]>::len(self)
46	}
47
48	#[inline(always)]
49	fn get(&self, index: usize) -> T {
50		self[index]
51	}
52
53	#[inline(always)]
54	unsafe fn get_unchecked(&self, index: usize) -> T {
55		unsafe { *<[T]>::get_unchecked(self, index) }
56	}
57}
58
59impl<T: Copy> RandomAccessSequence<T> for &mut [T] {
60	#[inline(always)]
61	fn len(&self) -> usize {
62		<[T]>::len(self)
63	}
64
65	#[inline(always)]
66	fn get(&self, index: usize) -> T {
67		self[index]
68	}
69
70	#[inline(always)]
71	unsafe fn get_unchecked(&self, index: usize) -> T {
72		unsafe { *<[T]>::get_unchecked(self, index) }
73	}
74}
75
76impl<T: Copy> RandomAccessSequenceMut<T> for &mut [T] {
77	#[inline(always)]
78	fn set(&mut self, index: usize, value: T) {
79		self[index] = value;
80	}
81
82	#[inline(always)]
83	unsafe fn set_unchecked(&mut self, index: usize, value: T) {
84		unsafe {
85			*<[T]>::get_unchecked_mut(self, index) = value;
86		}
87	}
88}
89
90/// A subrange adapter of a collection of scalars.
91#[derive(Clone)]
92pub struct SequenceSubrange<'a, T: Copy, Inner: RandomAccessSequence<T>> {
93	inner: &'a Inner,
94	offset: usize,
95	len: usize,
96	_marker: std::marker::PhantomData<T>,
97}
98
99impl<'a, T: Copy, Inner: RandomAccessSequence<T>> SequenceSubrange<'a, T, Inner> {
100	#[inline(always)]
101	pub fn new(inner: &'a Inner, offset: usize, len: usize) -> Self {
102		assert!(offset + len <= inner.len(), "subrange out of bounds");
103
104		Self {
105			inner,
106			offset,
107			len,
108			_marker: std::marker::PhantomData,
109		}
110	}
111}
112
113impl<T: Copy, Inner: RandomAccessSequence<T>> RandomAccessSequence<T>
114	for SequenceSubrange<'_, T, Inner>
115{
116	#[inline(always)]
117	fn len(&self) -> usize {
118		self.len
119	}
120
121	#[inline(always)]
122	unsafe fn get_unchecked(&self, index: usize) -> T {
123		unsafe { self.inner.get_unchecked(index + self.offset) }
124	}
125}
126
127/// A subrange adapter of a mutable collection of scalars.
128pub struct SequenceSubrangeMut<'a, T: Copy, Inner: RandomAccessSequenceMut<T>> {
129	inner: &'a mut Inner,
130	offset: usize,
131	len: usize,
132	_marker: std::marker::PhantomData<&'a T>,
133}
134
135impl<'a, T: Copy, Inner: RandomAccessSequenceMut<T>> SequenceSubrangeMut<'a, T, Inner> {
136	#[inline(always)]
137	pub fn new(inner: &'a mut Inner, offset: usize, len: usize) -> Self {
138		assert!(offset + len <= inner.len(), "subrange out of bounds");
139
140		Self {
141			inner,
142			offset,
143			len,
144			_marker: std::marker::PhantomData,
145		}
146	}
147}
148impl<T: Copy, Inner: RandomAccessSequenceMut<T>> RandomAccessSequence<T>
149	for SequenceSubrangeMut<'_, T, Inner>
150{
151	#[inline(always)]
152	fn len(&self) -> usize {
153		self.len
154	}
155
156	#[inline(always)]
157	unsafe fn get_unchecked(&self, index: usize) -> T {
158		unsafe { self.inner.get_unchecked(index + self.offset) }
159	}
160}
161impl<T: Copy, Inner: RandomAccessSequenceMut<T>> RandomAccessSequenceMut<T>
162	for SequenceSubrangeMut<'_, T, Inner>
163{
164	#[inline(always)]
165	unsafe fn set_unchecked(&mut self, index: usize, value: T) {
166		unsafe {
167			self.inner.set_unchecked(index + self.offset, value);
168		}
169	}
170}
171
172/// Power-of-two aligned vertical slice of a sequence when viewed as a row-major matrix.
173/// This is useful access pattern for algorithms like 4-step NTT or switchover.
174#[derive(Clone)]
175pub struct MatrixVertSliceSubrange<'a, T: Copy, Inner: RandomAccessSequence<T>> {
176	inner: &'a Inner,
177	log_cols: usize,
178	log_slice: usize,
179	slice_index: usize,
180	len: usize,
181	_marker: std::marker::PhantomData<T>,
182}
183
184impl<'a, T: Copy, Inner: RandomAccessSequence<T>> MatrixVertSliceSubrange<'a, T, Inner> {
185	/// Rearrange the `inner` sequence into a row-major matrix with sides `2^log_rows` and
186	/// `2^log_cols`, then take an aligned vertical slice of size `2^log_slice` with index
187	/// `slice_index` (which ranges from 0 to `2^(log_cols - log_slice)`, non-inclusive), and
188	/// present that slice as a view.
189	#[inline(always)]
190	pub fn new(
191		inner: &'a Inner,
192		log_rows: usize,
193		log_cols: usize,
194		log_slice: usize,
195		slice_index: usize,
196	) -> Self {
197		assert_eq!(
198			1 << (log_rows + log_cols),
199			inner.len(),
200			"matrix dimensions do not match inner sequence"
201		);
202		assert!(log_slice <= log_cols && slice_index < 1 << (log_cols - log_slice));
203
204		let len = 1 << (log_slice + log_rows);
205
206		Self {
207			inner,
208			log_cols,
209			log_slice,
210			slice_index,
211			len,
212			_marker: std::marker::PhantomData,
213		}
214	}
215}
216
217impl<T: Copy, Inner: RandomAccessSequence<T>> RandomAccessSequence<T>
218	for MatrixVertSliceSubrange<'_, T, Inner>
219{
220	#[inline(always)]
221	fn len(&self) -> usize {
222		self.len
223	}
224
225	#[inline(always)]
226	unsafe fn get_unchecked(&self, index: usize) -> T {
227		let row = index >> self.log_slice;
228		let col = index ^ (row << self.log_slice);
229		let inner_index = row << self.log_cols | self.slice_index << self.log_slice | col;
230		unsafe { self.inner.get_unchecked(inner_index) }
231	}
232}
233
#[cfg(test)]
mod tests {
	use std::fmt::Debug;

	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;

	/// Checks that `collection` has the same length and elements as `expected`, reading every
	/// element through both `get` and `get_unchecked`.
	fn check_collection<T: Copy + Eq + Debug>(
		collection: &impl RandomAccessSequence<T>,
		expected: &[T],
	) {
		assert_eq!(collection.len(), expected.len());

		for (i, v) in expected.iter().enumerate() {
			assert_eq!(&collection.get(i), v);
			assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
		}
	}

	/// Overwrites every element of `collection` with a value produced by `random` and checks
	/// that the value can be read back through both `get` and `get_unchecked`.
	fn check_collection_get_set<T: Eq + Copy + Debug>(
		collection: &mut impl RandomAccessSequenceMut<T>,
		random: &mut impl FnMut() -> T,
	) {
		for i in 0..collection.len() {
			let value = random();
			collection.set(i, value);
			assert_eq!(collection.get(i), value);
			assert_eq!(unsafe { collection.get_unchecked(i) }, value);
		}
	}

	#[test]
	fn check_slice() {
		let slice: &[usize] = &[];
		check_collection::<usize>(&slice, slice);

		let slice: &[usize] = &[1usize, 2, 3];
		check_collection(&slice, slice);
	}

	#[test]
	fn check_slice_mut() {
		let mut rng = StdRng::seed_from_u64(0);
		let mut random = || -> usize { rng.random::<u64>() as usize };

		let mut slice: &mut [usize] = &mut [];

		check_collection(&slice, slice);
		check_collection_get_set(&mut slice, &mut random);

		let mut slice: &mut [usize] = &mut [1, 2, 3];
		check_collection(&slice, slice);
		check_collection_get_set(&mut slice, &mut random);
	}

	#[test]
	fn test_subrange() {
		let slice: &[usize] = &[1, 2, 3, 4, 5];
		let subrange = SequenceSubrange::new(&slice, 1, 3);
		check_collection(&subrange, &[2, 3, 4]);
	}

	#[test]
	#[should_panic(expected = "subrange out of bounds")]
	fn test_subrange_out_of_bounds() {
		let slice: &[usize] = &[1, 2, 3, 4, 5];
		let _ = SequenceSubrange::<usize, _>::new(&slice, 3, 3);
	}

	#[test]
	fn test_subrange_mut() {
		let mut rng = StdRng::seed_from_u64(0);
		let mut random = || -> usize { rng.random::<u64>() as usize };

		let mut slice: &mut [usize] = &mut [1, 2, 3, 4, 5];
		let values = slice[1..4].to_vec();
		let mut subrange = SequenceSubrangeMut::new(&mut slice, 1, 3);
		check_collection(&subrange, &values);
		check_collection_get_set(&mut subrange, &mut random);
	}

	#[test]
	fn test_matrix_vert_slice() {
		// 4x4 row-major matrix holding 0..16.
		let values: Vec<usize> = (0..16).collect();
		let slice: &[usize] = &values;

		// Two-column slice with index 1: columns 2 and 3 of every row.
		let right_half = MatrixVertSliceSubrange::new(&slice, 2, 2, 1, 1);
		check_collection(&right_half, &[2, 3, 6, 7, 10, 11, 14, 15]);

		// One-column slice with index 0: the first column.
		let first_column = MatrixVertSliceSubrange::new(&slice, 2, 2, 0, 0);
		check_collection(&first_column, &[0, 4, 8, 12]);

		// A slice as wide as the matrix is the whole sequence.
		let full_width = MatrixVertSliceSubrange::new(&slice, 2, 2, 2, 0);
		check_collection(&full_width, &values[..]);
	}

	#[test]
	#[should_panic(expected = "matrix dimensions do not match inner sequence")]
	fn test_matrix_vert_slice_dimension_mismatch() {
		let slice: &[usize] = &[1, 2, 3, 4, 5];
		let _ = MatrixVertSliceSubrange::<usize, _>::new(&slice, 1, 1, 0, 0);
	}
}