// binius_fast_compute/memory.rs

// Copyright 2025 Irreducible Inc.

use std::{
	marker::PhantomData,
	ops::{Bound, RangeBounds},
};

use binius_compute::memory::{ComputeMemory, SizedSlice};
use binius_field::{PackedField, packed::iter_packed_slice_with_offset};
10
/// A packed memory implementation that uses slices of packed fields.
///
/// Zero-sized marker type: all functionality lives in the [`ComputeMemory`]
/// impl below; the `PhantomData` only fixes the packed field type `P`.
pub struct PackedMemory<P>(PhantomData<P>);
13
impl<P: PackedField> ComputeMemory<P::Scalar> for PackedMemory<P> {
	// Scalar-index bounds passed to the slicing methods must be multiples of
	// the packed width; the sole exception is the full range.
	const ALIGNMENT: usize = P::WIDTH;

	type FSlice<'a> = PackedMemorySlice<'a, P>;

	type FSliceMut<'a> = PackedMemorySliceMut<'a, P>;

	// Reborrows a mutable slice as an immutable one. `SmallOwnedChunk` is
	// `Copy`, so the owned variant is simply duplicated.
	fn as_const<'a>(data: &'a Self::FSliceMut<'_>) -> Self::FSlice<'a> {
		match data {
			PackedMemorySliceMut::Slice(slice) => PackedMemorySlice::Slice(slice),
			PackedMemorySliceMut::Owned(chunk) => PackedMemorySlice::Owned(*chunk),
		}
	}

	fn slice(data: Self::FSlice<'_>, range: impl std::ops::RangeBounds<usize>) -> Self::FSlice<'_> {
		let (start, end) = Self::to_packed_range(data.len(), range);
		// `to_packed_range` returns `(0, data.len())` (in scalar units) only
		// for the full range, which is valid for both variants — including
		// `Owned` chunks shorter than `P::WIDTH`.
		if start == 0 && end == data.len() {
			return data;
		}

		// A proper subrange uses packed-element indices, which an `Owned`
		// chunk (shorter than one packed element) cannot support.
		let PackedMemorySlice::Slice(slice) = data else {
			panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
		};
		PackedMemorySlice::Slice(&slice[start..end])
	}

	fn slice_mut<'a>(
		data: &'a mut Self::FSliceMut<'_>,
		range: impl std::ops::RangeBounds<usize>,
	) -> Self::FSliceMut<'a> {
		let (start, end) = Self::to_packed_range(data.len(), range);
		// Full range: reborrow the whole slice (works for `Owned` too).
		if start == 0 && end == data.len() {
			return Self::to_owned_mut(data);
		}

		// As in `slice`, proper subranges require the `Slice` variant;
		// `start`/`end` are packed-element indices at this point.
		let PackedMemorySliceMut::Slice(slice) = data else {
			panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
		};
		PackedMemorySliceMut::Slice(&mut slice[start..end])
	}

	// Splits into two disjoint mutable parts at scalar index `mid`, which
	// must be `P::WIDTH`-aligned.
	fn split_at_mut(
		data: Self::FSliceMut<'_>,
		mid: usize,
	) -> (Self::FSliceMut<'_>, Self::FSliceMut<'_>) {
		assert_eq!(mid % P::WIDTH, 0, "mid must be a multiple of {}", P::WIDTH);
		// Convert the scalar index to a packed-element index.
		let mid = mid >> P::LOG_WIDTH;
		let PackedMemorySliceMut::Slice(slice) = data else {
			panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
		};
		let (left, right) = slice.split_at_mut(mid);
		(PackedMemorySliceMut::Slice(left), PackedMemorySliceMut::Slice(right))
	}

	// Shortens the lifetime of an immutable slice.
	fn narrow<'a>(data: &'a Self::FSlice<'_>) -> Self::FSlice<'a> {
		match data {
			PackedMemorySlice::Slice(slice) => PackedMemorySlice::Slice(slice),
			PackedMemorySlice::Owned(chunk) => PackedMemorySlice::Owned(*chunk),
		}
	}

	// Lifetime narrowing is implicit here thanks to the `'b: 'a` bound.
	fn narrow_mut<'a, 'b: 'a>(data: Self::FSliceMut<'b>) -> Self::FSliceMut<'a> {
		data
	}

	// Reborrows a mutable slice with a (possibly) shorter lifetime.
	fn to_owned_mut<'a>(data: &'a mut Self::FSliceMut<'_>) -> Self::FSliceMut<'a> {
		match data {
			PackedMemorySliceMut::Slice(slice) => PackedMemorySliceMut::Slice(slice),
			PackedMemorySliceMut::Owned(chunk) => PackedMemorySliceMut::Owned(*chunk),
		}
	}

	// Iterates over disjoint mutable chunks of `chunk_len` scalars each.
	// `chunk_len` must be width-aligned and divide `data.len()` evenly.
	fn slice_chunks_mut<'a>(
		data: Self::FSliceMut<'a>,
		chunk_len: usize,
	) -> impl Iterator<Item = Self::FSliceMut<'a>> {
		assert_eq!(chunk_len % P::WIDTH, 0, "chunk_len must be a multiple of {}", P::WIDTH);
		assert_eq!(data.len() % chunk_len, 0, "data.len() must be a multiple of chunk_len");

		// Scalars -> packed elements.
		let chunk_len = chunk_len >> P::LOG_WIDTH;

		let PackedMemorySliceMut::Slice(slice) = data else {
			panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
		};

		slice
			.chunks_mut(chunk_len)
			.map(|chunk| Self::FSliceMut::new_slice(chunk))
	}

	// Splits into two equal halves. Unlike `slice`/`split_at_mut`, this also
	// supports splits below one packed element by falling back to `Owned`
	// chunks (copies of the scalars).
	fn split_half<'a>(data: Self::FSlice<'a>) -> (Self::FSlice<'a>, Self::FSlice<'a>) {
		assert!(
			data.len().is_power_of_two() && data.len() > 1,
			"data.len() must be a power of two greater than 1"
		);

		match data {
			PackedMemorySlice::Slice(slice) => match slice.len() {
				// More than one packed element: split the slice itself.
				len if len > 1 => {
					let mid = slice.len() / 2;
					let left = &slice[..mid];
					let right = &slice[mid..];
					(PackedMemorySlice::Slice(left), PackedMemorySlice::Slice(right))
				}
				// Exactly one packed element: copy each half of its scalars
				// into a small owned chunk.
				1 => (
					PackedMemorySlice::new_owned(slice, 0, P::WIDTH / 2),
					PackedMemorySlice::new_owned(slice, P::WIDTH / 2, P::WIDTH / 2),
				),
				// The empty slice is ruled out by the length assert above.
				_ => {
					unreachable!()
				}
			},
			PackedMemorySlice::Owned(chunk) => {
				let mid = chunk.len / 2;
				let left = chunk.subrange(0, mid);
				let right = chunk.subrange(mid, chunk.len);
				(PackedMemorySlice::Owned(left), PackedMemorySlice::Owned(right))
			}
		}
	}

	// Mutable analogue of `split_half`. NOTE: the `Owned` results hold copied
	// scalars, so writes to them do not propagate back to the source slice.
	fn split_half_mut<'a>(data: Self::FSliceMut<'a>) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
		assert!(
			data.len().is_power_of_two() && data.len() > 1,
			"data.len() must be a power of two greater than 1"
		);

		match data {
			PackedMemorySliceMut::Slice(slice) => match slice.len() {
				// More than one packed element: split the slice in place.
				len if len > 1 => {
					let mid = slice.len() / 2;
					let (left, right) = slice.split_at_mut(mid);
					(PackedMemorySliceMut::Slice(left), PackedMemorySliceMut::Slice(right))
				}
				// Exactly one packed element: fall back to owned copies.
				1 => (
					PackedMemorySliceMut::new_owned(slice, 0, P::WIDTH / 2),
					PackedMemorySliceMut::new_owned(slice, P::WIDTH / 2, P::WIDTH / 2),
				),
				// The empty slice is ruled out by the length assert above.
				_ => {
					unreachable!()
				}
			},
			PackedMemorySliceMut::Owned(chunk) => {
				let mid = chunk.len / 2;
				let left = chunk.subrange(0, mid);
				let right = chunk.subrange(mid, chunk.len);
				(PackedMemorySliceMut::Owned(left), PackedMemorySliceMut::Owned(right))
			}
		}
	}
}
165
166impl<P: PackedField> PackedMemory<P> {
167	fn to_packed_range(len: usize, range: impl RangeBounds<usize>) -> (usize, usize) {
168		let start = match range.start_bound() {
169			Bound::Included(&start) => start,
170			Bound::Excluded(&start) => start + P::WIDTH,
171			Bound::Unbounded => 0,
172		};
173		let end = match range.end_bound() {
174			Bound::Included(&end) => end + P::WIDTH,
175			Bound::Excluded(&end) => end,
176			Bound::Unbounded => len,
177		};
178
179		if (start, end) == (0, len) {
180			(0, len)
181		} else {
182			assert_eq!(start % P::WIDTH, 0, "start must be a multiple of {}", P::WIDTH);
183			assert_eq!(end % P::WIDTH, 0, "end must be a multiple of {}", P::WIDTH);
184
185			(start >> P::LOG_WIDTH, end >> P::LOG_WIDTH)
186		}
187	}
188}
189
/// An in-place storage for the chunk of elements smaller than `P::WIDTH`.
#[derive(Clone, Copy, Debug)]
pub struct SmallOwnedChunk<P: PackedField> {
	// Packed element holding the scalars; positions at index >= `len` carry
	// unspecified values and must be ignored by readers.
	data: P,
	// Number of meaningful scalars in `data`; kept below `P::WIDTH`
	// (debug-asserted in `new_from_slice`).
	len: usize,
}
196
197impl<P: PackedField> SmallOwnedChunk<P> {
198	#[inline(always)]
199	fn new_from_slice(data: &[P], offset: usize, len: usize) -> Self {
200		debug_assert!(len < P::WIDTH, "len must be less than {}", P::WIDTH);
201
202		let iter = iter_packed_slice_with_offset(data, offset);
203		let data = P::from_scalars(iter.take(len));
204		Self { data, len }
205	}
206
207	#[inline]
208	fn subrange(&self, start: usize, end: usize) -> Self {
209		assert!(end <= self.len, "range out of bounds");
210
211		let data = if start == 0 {
212			self.data
213		} else {
214			P::from_scalars(self.data.iter().skip(start).take(end - start))
215		};
216		Self {
217			data,
218			len: end - start,
219		}
220	}
221
222	/// Used for tests only
223	#[cfg(test)]
224	fn iter_scalars(&self) -> impl Iterator<Item = P::Scalar> {
225		self.data.iter().take(self.len)
226	}
227}
228
/// Memory slice that can be either a borrowed slice or an owned small chunk (with length <
/// `P::WIDTH`).
#[derive(Clone, Copy, Debug)]
pub enum PackedMemorySlice<'a, P: PackedField> {
	/// Borrowed slice of whole packed elements.
	Slice(&'a [P]),
	/// In-place copy of fewer than `P::WIDTH` scalars.
	Owned(SmallOwnedChunk<P>),
}
236
237impl<'a, P: PackedField> PackedMemorySlice<'a, P> {
238	#[inline(always)]
239	pub fn new_slice(data: &'a [P]) -> Self {
240		Self::Slice(data)
241	}
242
243	#[inline(always)]
244	pub fn new_owned(data: &[P], offset: usize, len: usize) -> Self {
245		let chunk = SmallOwnedChunk::new_from_slice(data, offset, len);
246		Self::Owned(chunk)
247	}
248
249	#[inline(always)]
250	pub fn as_slice(&'a self) -> &'a [P] {
251		match self {
252			Self::Slice(data) => data,
253			Self::Owned(chunk) => std::slice::from_ref(&chunk.data),
254		}
255	}
256
257	/// Used for tests only
258	#[cfg(test)]
259	fn iter_scalars(&self) -> impl Iterator<Item = P::Scalar> {
260		use itertools::Either;
261
262		match self {
263			Self::Slice(data) => Either::Left(data.iter().flat_map(|p| p.iter())),
264			Self::Owned(chunk) => Either::Right(chunk.iter_scalars()),
265		}
266	}
267}
268
269impl<'a, P: PackedField> SizedSlice for PackedMemorySlice<'a, P> {
270	#[inline(always)]
271	fn is_empty(&self) -> bool {
272		match self {
273			Self::Slice(data) => data.is_empty(),
274			Self::Owned(chunk) => chunk.len == 0,
275		}
276	}
277
278	#[inline(always)]
279	fn len(&self) -> usize {
280		match self {
281			Self::Slice(data) => data.len() << P::LOG_WIDTH,
282			Self::Owned(chunk) => chunk.len,
283		}
284	}
285}
286
/// Mutable counterpart of `PackedMemorySlice`: either a mutable borrowed
/// slice of packed elements or an owned small chunk (with length <
/// `P::WIDTH`).
pub enum PackedMemorySliceMut<'a, P: PackedField> {
	/// Mutable borrowed slice of whole packed elements.
	Slice(&'a mut [P]),
	/// In-place copy of fewer than `P::WIDTH` scalars.
	Owned(SmallOwnedChunk<P>),
}
291
292impl<'a, P: PackedField> PackedMemorySliceMut<'a, P> {
293	#[inline(always)]
294	pub fn new_slice(data: &'a mut [P]) -> Self {
295		Self::Slice(data)
296	}
297
298	#[inline(always)]
299	pub fn new_owned(data: &mut [P], offset: usize, len: usize) -> Self {
300		let chunk = SmallOwnedChunk::new_from_slice(data, offset, len);
301		Self::Owned(chunk)
302	}
303
304	#[inline(always)]
305	pub fn as_const(&self) -> PackedMemorySlice<'_, P> {
306		match self {
307			Self::Slice(data) => PackedMemorySlice::Slice(data),
308			Self::Owned(chunk) => PackedMemorySlice::Owned(*chunk),
309		}
310	}
311
312	#[inline(always)]
313	pub fn as_slice(&'a self) -> &'a [P] {
314		match self {
315			Self::Slice(data) => data,
316			Self::Owned(chunk) => std::slice::from_ref(&chunk.data),
317		}
318	}
319
320	#[inline(always)]
321	pub fn as_slice_mut(&mut self) -> &mut [P] {
322		match self {
323			Self::Slice(data) => data,
324			Self::Owned(chunk) => std::slice::from_mut(&mut chunk.data),
325		}
326	}
327}
328
329impl<'a, P: PackedField> SizedSlice for PackedMemorySliceMut<'a, P> {
330	#[inline(always)]
331	fn is_empty(&self) -> bool {
332		match self {
333			Self::Slice(data) => data.is_empty(),
334			Self::Owned(chunk) => chunk.len == 0,
335		}
336	}
337
338	#[inline(always)]
339	fn len(&self) -> usize {
340		match self {
341			Self::Slice(data) => data.len() << P::LOG_WIDTH,
342			Self::Owned(chunk) => chunk.len,
343		}
344	}
345}
346
#[cfg(test)]
mod tests {
	use binius_field::PackedBinaryField4x32b;
	use itertools::Itertools;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	// 4 scalars per packed element, so `Packed::WIDTH == 4`.
	type Packed = PackedBinaryField4x32b;

	// Deterministic test data: fixed seed so failures are reproducible.
	fn make_random_vec(len: usize) -> Vec<Packed> {
		let mut rnd = StdRng::seed_from_u64(0);

		(0..len)
			.map(|_| PackedBinaryField4x32b::random(&mut rnd))
			.collect()
	}

	#[test]
	fn test_try_slice_on_mem_slice() {
		let data = make_random_vec(3);
		let data_clone = data.clone();
		let memory = PackedMemorySlice::new_slice(&data);

		// Width-aligned ranges (in scalar units) map to packed subslices.
		assert_eq!(PackedMemory::slice(memory, 0..2 * Packed::WIDTH).as_slice(), &data_clone[0..2]);
		assert_eq!(PackedMemory::slice(memory, ..2 * Packed::WIDTH).as_slice(), &data_clone[..2]);
		assert_eq!(PackedMemory::slice(memory, Packed::WIDTH..).as_slice(), &data_clone[1..]);
		assert_eq!(PackedMemory::slice(memory, ..).as_slice(), &data_clone[..]);

		// check panic on non-aligned slice
		let result = std::panic::catch_unwind(|| {
			PackedMemory::slice(memory, 0..1);
		});
		assert!(result.is_err());
		let result = std::panic::catch_unwind(|| {
			PackedMemory::slice(memory, ..1);
		});
		assert!(result.is_err());
		let result = std::panic::catch_unwind(|| {
			PackedMemory::slice(memory, 1..Packed::WIDTH);
		});
		assert!(result.is_err());
		let result = std::panic::catch_unwind(|| {
			PackedMemory::slice(memory, 1..);
		});
		assert!(result.is_err());

		// check panic on owned slice
		let memory_owned = PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH - 1);
		let result = std::panic::catch_unwind(|| {
			PackedMemory::slice(memory_owned, 0..1);
		});
		assert!(result.is_err());
	}

	#[test]
	fn test_convert_mut_mem_slice_to_const() {
		let mut data = make_random_vec(3);
		let data_clone = data.clone();
		let memory = PackedMemorySliceMut::new_slice(&mut data);

		assert_eq!(PackedMemory::as_const(&memory).as_slice(), &data_clone[..]);

		// Owned chunks compare by their meaningful scalars, not raw packed data.
		let owned_memory = PackedMemorySliceMut::new_owned(&mut data, 0, Packed::WIDTH - 1);
		assert_eq!(
			PackedMemory::as_const(&owned_memory)
				.iter_scalars()
				.collect_vec(),
			PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH - 1)
				.iter_scalars()
				.collect_vec()
		);
	}

	#[test]
	fn test_slice_on_mut_mem_slice() {
		let mut data = make_random_vec(3);
		let data_clone = data.clone();
		let mut memory = PackedMemorySliceMut::new_slice(&mut data);

		assert_eq!(
			PackedMemory::slice_mut(&mut memory, 0..2 * Packed::WIDTH).as_slice(),
			&data_clone[0..2]
		);
		assert_eq!(
			PackedMemory::slice_mut(&mut memory, ..2 * Packed::WIDTH).as_slice(),
			&data_clone[..2]
		);
		assert_eq!(
			PackedMemory::slice_mut(&mut memory, Packed::WIDTH..).as_slice(),
			&data_clone[1..]
		);
		assert_eq!(PackedMemory::slice_mut(&mut memory, ..).as_slice(), &data_clone[..]);
	}

	#[test]
	#[should_panic]
	fn test_slice_mut_on_mem_slice_panic_1() {
		let mut data = make_random_vec(3);
		let mut memory = PackedMemorySliceMut::new_slice(&mut data);

		// `&mut T` can't cross the catch unwind boundary, so we have to use several tests
		// to test the panic cases.
		PackedMemory::slice_mut(&mut memory, 0..1);
	}

	#[test]
	#[should_panic]
	fn test_slice_mut_on_mem_slice_panic_2() {
		let mut data = make_random_vec(3);
		let mut memory = PackedMemorySliceMut::new_slice(&mut data);

		PackedMemory::slice_mut(&mut memory, ..1);
	}

	#[test]
	#[should_panic]
	fn test_slice_mut_on_mem_slice_panic_3() {
		let mut data = make_random_vec(3);
		let mut memory = PackedMemorySliceMut::new_slice(&mut data);

		PackedMemory::slice_mut(&mut memory, 1..Packed::WIDTH);
	}

	#[test]
	#[should_panic]
	fn test_slice_mut_on_mem_slice_panic_4() {
		let mut data = make_random_vec(3);
		let mut memory = PackedMemorySliceMut::new_slice(&mut data);

		PackedMemory::slice_mut(&mut memory, 1..);
	}

	#[test]
	#[should_panic]
	fn test_slice_mut_on_mem_slice_panic_5() {
		let mut data = make_random_vec(3);
		let mut memory = PackedMemorySliceMut::new_owned(&mut data, 0, Packed::WIDTH - 1);

		PackedMemory::slice_mut(&mut memory, 1..);
	}

	#[test]
	fn test_split_at_mut() {
		let mut data = make_random_vec(3);
		let data_clone = data.clone();
		let memory = PackedMemorySliceMut::new_slice(&mut data);

		let (left, right) = PackedMemory::split_at_mut(memory, 2 * Packed::WIDTH);
		assert_eq!(left.as_slice(), &data_clone[0..2]);
		assert_eq!(right.as_slice(), &data_clone[2..]);
	}

	#[test]
	#[should_panic]
	fn test_split_at_mut_panic_1() {
		let mut data = make_random_vec(3);
		let memory = PackedMemorySliceMut::new_slice(&mut data);

		// `&mut T` can't cross the catch unwind boundary, so we have to use several tests
		// to test the panic cases.
		PackedMemory::split_at_mut(memory, 1);
	}

	#[test]
	#[should_panic]
	fn test_split_at_mut_panic_2() {
		let mut data = make_random_vec(3);
		let memory = PackedMemorySliceMut::new_owned(&mut data, 0, Packed::WIDTH - 1);

		// `&mut T` can't cross the catch unwind boundary, so we have to use several tests
		// to test the panic cases.
		PackedMemory::split_at_mut(memory, 1);
	}

	#[test]
	fn test_split_half() {
		let data = make_random_vec(2);
		let data_clone = data.clone();
		let memory = PackedMemorySlice::new_slice(&data);

		// Two packed elements: halves stay borrowed slices.
		let (left, right) = PackedMemory::split_half(memory);
		assert_eq!(left.as_slice(), &data_clone[0..1]);
		assert_eq!(right.as_slice(), &data_clone[1..]);

		// One packed element: halves become owned sub-width chunks.
		let memory = PackedMemorySlice::new_slice(&data[0..1]);
		let (left, right) = PackedMemory::split_half(memory);
		assert_eq!(
			left.iter_scalars().collect_vec(),
			PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH / 2)
				.iter_scalars()
				.collect_vec()
		);
		assert_eq!(
			right.iter_scalars().collect_vec(),
			PackedMemorySlice::new_owned(&data, Packed::WIDTH / 2, Packed::WIDTH / 2)
				.iter_scalars()
				.collect_vec()
		);

		// An owned chunk splits into two smaller owned chunks.
		let memory = PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH / 2);
		let (left, right) = PackedMemory::split_half(memory);
		assert_eq!(
			left.iter_scalars().collect_vec(),
			PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH / 4)
				.iter_scalars()
				.collect_vec()
		);
		assert_eq!(
			right.iter_scalars().collect_vec(),
			PackedMemorySlice::new_owned(&data, Packed::WIDTH / 4, Packed::WIDTH / 4)
				.iter_scalars()
				.collect_vec()
		);
	}
}
562}