1use std::{
4 marker::PhantomData,
5 ops::{Bound, RangeBounds},
6};
7
8use binius_compute::memory::{ComputeMemory, SizedSlice};
9use binius_field::{PackedField, packed::iter_packed_slice_with_offset};
10
/// Marker type implementing `ComputeMemory` over ordinary Rust slices of packed field elements
/// `P`.
///
/// It carries no data; the `PhantomData<P>` field only ties the memory model to `P`.
pub struct PackedMemory<P>(PhantomData<P>);
13
14impl<P: PackedField> ComputeMemory<P::Scalar> for PackedMemory<P> {
15 const ALIGNMENT: usize = P::WIDTH;
16
17 type FSlice<'a> = PackedMemorySlice<'a, P>;
18
19 type FSliceMut<'a> = PackedMemorySliceMut<'a, P>;
20
21 fn as_const<'a>(data: &'a Self::FSliceMut<'_>) -> Self::FSlice<'a> {
22 match data {
23 PackedMemorySliceMut::Slice(slice) => PackedMemorySlice::Slice(slice),
24 PackedMemorySliceMut::Owned(chunk) => PackedMemorySlice::Owned(*chunk),
25 }
26 }
27
28 fn slice(data: Self::FSlice<'_>, range: impl std::ops::RangeBounds<usize>) -> Self::FSlice<'_> {
29 let (start, end) = Self::to_packed_range(data.len(), range);
30 if start == 0 && end == data.len() {
31 return data;
32 }
33
34 let PackedMemorySlice::Slice(slice) = data else {
35 panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
36 };
37 PackedMemorySlice::Slice(&slice[start..end])
38 }
39
40 fn slice_mut<'a>(
41 data: &'a mut Self::FSliceMut<'_>,
42 range: impl std::ops::RangeBounds<usize>,
43 ) -> Self::FSliceMut<'a> {
44 let (start, end) = Self::to_packed_range(data.len(), range);
45 if start == 0 && end == data.len() {
46 return Self::to_owned_mut(data);
47 }
48
49 let PackedMemorySliceMut::Slice(slice) = data else {
50 panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
51 };
52 PackedMemorySliceMut::Slice(&mut slice[start..end])
53 }
54
55 fn split_at_mut(
56 data: Self::FSliceMut<'_>,
57 mid: usize,
58 ) -> (Self::FSliceMut<'_>, Self::FSliceMut<'_>) {
59 assert_eq!(mid % P::WIDTH, 0, "mid must be a multiple of {}", P::WIDTH);
60 let mid = mid >> P::LOG_WIDTH;
61 let PackedMemorySliceMut::Slice(slice) = data else {
62 panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
63 };
64 let (left, right) = slice.split_at_mut(mid);
65 (PackedMemorySliceMut::Slice(left), PackedMemorySliceMut::Slice(right))
66 }
67
68 fn narrow<'a>(data: &'a Self::FSlice<'_>) -> Self::FSlice<'a> {
69 match data {
70 PackedMemorySlice::Slice(slice) => PackedMemorySlice::Slice(slice),
71 PackedMemorySlice::Owned(chunk) => PackedMemorySlice::Owned(*chunk),
72 }
73 }
74
75 fn narrow_mut<'a, 'b: 'a>(data: Self::FSliceMut<'b>) -> Self::FSliceMut<'a> {
76 data
77 }
78
79 fn to_owned_mut<'a>(data: &'a mut Self::FSliceMut<'_>) -> Self::FSliceMut<'a> {
80 match data {
81 PackedMemorySliceMut::Slice(slice) => PackedMemorySliceMut::Slice(slice),
82 PackedMemorySliceMut::Owned(chunk) => PackedMemorySliceMut::Owned(*chunk),
83 }
84 }
85
86 fn slice_chunks_mut<'a>(
87 data: Self::FSliceMut<'a>,
88 chunk_len: usize,
89 ) -> impl Iterator<Item = Self::FSliceMut<'a>> {
90 assert_eq!(chunk_len % P::WIDTH, 0, "chunk_len must be a multiple of {}", P::WIDTH);
91 assert_eq!(data.len() % chunk_len, 0, "data.len() must be a multiple of chunk_len");
92
93 let chunk_len = chunk_len >> P::LOG_WIDTH;
94
95 let PackedMemorySliceMut::Slice(slice) = data else {
96 panic!("splitting slices of length less than `Self::ALIGNMENT` is not supported");
97 };
98
99 slice
100 .chunks_mut(chunk_len)
101 .map(|chunk| Self::FSliceMut::new_slice(chunk))
102 }
103
104 fn split_half<'a>(data: Self::FSlice<'a>) -> (Self::FSlice<'a>, Self::FSlice<'a>) {
105 assert!(
106 data.len().is_power_of_two() && data.len() > 1,
107 "data.len() must be a power of two greater than 1"
108 );
109
110 match data {
111 PackedMemorySlice::Slice(slice) => match slice.len() {
112 len if len > 1 => {
113 let mid = slice.len() / 2;
114 let left = &slice[..mid];
115 let right = &slice[mid..];
116 (PackedMemorySlice::Slice(left), PackedMemorySlice::Slice(right))
117 }
118 1 => (
119 PackedMemorySlice::new_owned(slice, 0, P::WIDTH / 2),
120 PackedMemorySlice::new_owned(slice, P::WIDTH / 2, P::WIDTH / 2),
121 ),
122 _ => {
123 unreachable!()
124 }
125 },
126 PackedMemorySlice::Owned(chunk) => {
127 let mid = chunk.len / 2;
128 let left = chunk.subrange(0, mid);
129 let right = chunk.subrange(mid, chunk.len);
130 (PackedMemorySlice::Owned(left), PackedMemorySlice::Owned(right))
131 }
132 }
133 }
134
135 fn split_half_mut<'a>(data: Self::FSliceMut<'a>) -> (Self::FSliceMut<'a>, Self::FSliceMut<'a>) {
136 assert!(
137 data.len().is_power_of_two() && data.len() > 1,
138 "data.len() must be a power of two greater than 1"
139 );
140
141 match data {
142 PackedMemorySliceMut::Slice(slice) => match slice.len() {
143 len if len > 1 => {
144 let mid = slice.len() / 2;
145 let (left, right) = slice.split_at_mut(mid);
146 (PackedMemorySliceMut::Slice(left), PackedMemorySliceMut::Slice(right))
147 }
148 1 => (
149 PackedMemorySliceMut::new_owned(slice, 0, P::WIDTH / 2),
150 PackedMemorySliceMut::new_owned(slice, P::WIDTH / 2, P::WIDTH / 2),
151 ),
152 _ => {
153 unreachable!()
154 }
155 },
156 PackedMemorySliceMut::Owned(chunk) => {
157 let mid = chunk.len / 2;
158 let left = chunk.subrange(0, mid);
159 let right = chunk.subrange(mid, chunk.len);
160 (PackedMemorySliceMut::Owned(left), PackedMemorySliceMut::Owned(right))
161 }
162 }
163 }
164}
165
166impl<P: PackedField> PackedMemory<P> {
167 fn to_packed_range(len: usize, range: impl RangeBounds<usize>) -> (usize, usize) {
168 let start = match range.start_bound() {
169 Bound::Included(&start) => start,
170 Bound::Excluded(&start) => start + P::WIDTH,
171 Bound::Unbounded => 0,
172 };
173 let end = match range.end_bound() {
174 Bound::Included(&end) => end + P::WIDTH,
175 Bound::Excluded(&end) => end,
176 Bound::Unbounded => len,
177 };
178
179 if (start, end) == (0, len) {
180 (0, len)
181 } else {
182 assert_eq!(start % P::WIDTH, 0, "start must be a multiple of {}", P::WIDTH);
183 assert_eq!(end % P::WIDTH, 0, "end must be a multiple of {}", P::WIDTH);
184
185 (start >> P::LOG_WIDTH, end >> P::LOG_WIDTH)
186 }
187 }
188}
189
/// A view of fewer than `P::WIDTH` scalars, stored by value in a single packed element.
///
/// Only lanes `0..len` of `data` belong to the view; any lanes at or beyond `len` are not part
/// of it.
#[derive(Clone, Copy, Debug)]
pub struct SmallOwnedChunk<P: PackedField> {
    // Packed element whose lanes `0..len` hold the view's scalars.
    data: P,
    // Number of scalars in the view.
    len: usize,
}
196
197impl<P: PackedField> SmallOwnedChunk<P> {
198 #[inline(always)]
199 fn new_from_slice(data: &[P], offset: usize, len: usize) -> Self {
200 debug_assert!(len < P::WIDTH, "len must be less than {}", P::WIDTH);
201
202 let iter = iter_packed_slice_with_offset(data, offset);
203 let data = P::from_scalars(iter.take(len));
204 Self { data, len }
205 }
206
207 #[inline]
208 fn subrange(&self, start: usize, end: usize) -> Self {
209 assert!(end <= self.len, "range out of bounds");
210
211 let data = if start == 0 {
212 self.data
213 } else {
214 P::from_scalars(self.data.iter().skip(start).take(end - start))
215 };
216 Self {
217 data,
218 len: end - start,
219 }
220 }
221
222 #[cfg(test)]
224 fn iter_scalars(&self) -> impl Iterator<Item = P::Scalar> {
225 self.data.iter().take(self.len)
226 }
227}
228
/// Immutable view of scalars stored in packed field elements.
///
/// Either borrows whole packed elements (`Slice`) or holds a by-value chunk of fewer than
/// `P::WIDTH` scalars (`Owned`), which is how sub-packed views are represented.
#[derive(Clone, Copy, Debug)]
pub enum PackedMemorySlice<'a, P: PackedField> {
    // Borrowed packed elements; the view covers `slice.len() * P::WIDTH` scalars.
    Slice(&'a [P]),
    // By-value chunk shorter than one packed element.
    Owned(SmallOwnedChunk<P>),
}
236
237impl<'a, P: PackedField> PackedMemorySlice<'a, P> {
238 #[inline(always)]
239 pub fn new_slice(data: &'a [P]) -> Self {
240 Self::Slice(data)
241 }
242
243 #[inline(always)]
244 pub fn new_owned(data: &[P], offset: usize, len: usize) -> Self {
245 let chunk = SmallOwnedChunk::new_from_slice(data, offset, len);
246 Self::Owned(chunk)
247 }
248
249 #[inline(always)]
250 pub fn as_slice(&'a self) -> &'a [P] {
251 match self {
252 Self::Slice(data) => data,
253 Self::Owned(chunk) => std::slice::from_ref(&chunk.data),
254 }
255 }
256
257 #[cfg(test)]
259 fn iter_scalars(&self) -> impl Iterator<Item = P::Scalar> {
260 use itertools::Either;
261
262 match self {
263 Self::Slice(data) => Either::Left(data.iter().flat_map(|p| p.iter())),
264 Self::Owned(chunk) => Either::Right(chunk.iter_scalars()),
265 }
266 }
267}
268
269impl<'a, P: PackedField> SizedSlice for PackedMemorySlice<'a, P> {
270 #[inline(always)]
271 fn is_empty(&self) -> bool {
272 match self {
273 Self::Slice(data) => data.is_empty(),
274 Self::Owned(chunk) => chunk.len == 0,
275 }
276 }
277
278 #[inline(always)]
279 fn len(&self) -> usize {
280 match self {
281 Self::Slice(data) => data.len() << P::LOG_WIDTH,
282 Self::Owned(chunk) => chunk.len,
283 }
284 }
285}
286
287pub enum PackedMemorySliceMut<'a, P: PackedField> {
288 Slice(&'a mut [P]),
289 Owned(SmallOwnedChunk<P>),
290}
291
292impl<'a, P: PackedField> PackedMemorySliceMut<'a, P> {
293 #[inline(always)]
294 pub fn new_slice(data: &'a mut [P]) -> Self {
295 Self::Slice(data)
296 }
297
298 #[inline(always)]
299 pub fn new_owned(data: &mut [P], offset: usize, len: usize) -> Self {
300 let chunk = SmallOwnedChunk::new_from_slice(data, offset, len);
301 Self::Owned(chunk)
302 }
303
304 #[inline(always)]
305 pub fn as_const(&self) -> PackedMemorySlice<'_, P> {
306 match self {
307 Self::Slice(data) => PackedMemorySlice::Slice(data),
308 Self::Owned(chunk) => PackedMemorySlice::Owned(*chunk),
309 }
310 }
311
312 #[inline(always)]
313 pub fn as_slice(&'a self) -> &'a [P] {
314 match self {
315 Self::Slice(data) => data,
316 Self::Owned(chunk) => std::slice::from_ref(&chunk.data),
317 }
318 }
319
320 #[inline(always)]
321 pub fn as_slice_mut(&mut self) -> &mut [P] {
322 match self {
323 Self::Slice(data) => data,
324 Self::Owned(chunk) => std::slice::from_mut(&mut chunk.data),
325 }
326 }
327}
328
329impl<'a, P: PackedField> SizedSlice for PackedMemorySliceMut<'a, P> {
330 #[inline(always)]
331 fn is_empty(&self) -> bool {
332 match self {
333 Self::Slice(data) => data.is_empty(),
334 Self::Owned(chunk) => chunk.len == 0,
335 }
336 }
337
338 #[inline(always)]
339 fn len(&self) -> usize {
340 match self {
341 Self::Slice(data) => data.len() << P::LOG_WIDTH,
342 Self::Owned(chunk) => chunk.len,
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use binius_field::PackedBinaryField4x32b;
    use itertools::Itertools;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    // Concrete packed type shared by every test below; ranges are expressed in multiples of
    // `Packed::WIDTH` scalars.
    type Packed = PackedBinaryField4x32b;

    /// Returns `len` packed elements drawn from an RNG seeded with a fixed value, so every run
    /// sees the same data.
    fn make_random_vec(len: usize) -> Vec<Packed> {
        let mut rnd = StdRng::seed_from_u64(0);

        (0..len)
            .map(|_| PackedBinaryField4x32b::random(&mut rnd))
            .collect()
    }

    /// `PackedMemory::slice` on a borrowed slice: aligned scalar ranges map onto the matching
    /// packed sub-slices, misaligned ranges panic, and owned chunks cannot be sub-sliced.
    #[test]
    fn test_try_slice_on_mem_slice() {
        let data = make_random_vec(3);
        let data_clone = data.clone();
        let memory = PackedMemorySlice::new_slice(&data);

        // `k * Packed::WIDTH` scalars correspond to `k` packed elements.
        assert_eq!(PackedMemory::slice(memory, 0..2 * Packed::WIDTH).as_slice(), &data_clone[0..2]);
        assert_eq!(PackedMemory::slice(memory, ..2 * Packed::WIDTH).as_slice(), &data_clone[..2]);
        assert_eq!(PackedMemory::slice(memory, Packed::WIDTH..).as_slice(), &data_clone[1..]);
        assert_eq!(PackedMemory::slice(memory, ..).as_slice(), &data_clone[..]);

        // Bounds that are not multiples of `Packed::WIDTH` must be rejected.
        let result = std::panic::catch_unwind(|| {
            PackedMemory::slice(memory, 0..1);
        });
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| {
            PackedMemory::slice(memory, ..1);
        });
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| {
            PackedMemory::slice(memory, 1..Packed::WIDTH);
        });
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| {
            PackedMemory::slice(memory, 1..);
        });
        assert!(result.is_err());

        // An owned chunk shorter than one packed element only supports the full range.
        let memory_owned = PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH - 1);
        let result = std::panic::catch_unwind(|| {
            PackedMemory::slice(memory_owned, 0..1);
        });
        assert!(result.is_err());
    }

    /// `as_const` preserves the contents of both borrowed and owned mutable views.
    #[test]
    fn test_convert_mut_mem_slice_to_const() {
        let mut data = make_random_vec(3);
        let data_clone = data.clone();
        let memory = PackedMemorySliceMut::new_slice(&mut data);

        assert_eq!(PackedMemory::as_const(&memory).as_slice(), &data_clone[..]);

        // Owned chunks are compared scalar-by-scalar over their `len` scalars only.
        let owned_memory = PackedMemorySliceMut::new_owned(&mut data, 0, Packed::WIDTH - 1);
        assert_eq!(
            PackedMemory::as_const(&owned_memory)
                .iter_scalars()
                .collect_vec(),
            PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH - 1)
                .iter_scalars()
                .collect_vec()
        );
    }

    /// `PackedMemory::slice_mut` on a borrowed slice maps aligned scalar ranges onto the
    /// matching packed sub-slices.
    #[test]
    fn test_slice_on_mut_mem_slice() {
        let mut data = make_random_vec(3);
        let data_clone = data.clone();
        let mut memory = PackedMemorySliceMut::new_slice(&mut data);

        assert_eq!(
            PackedMemory::slice_mut(&mut memory, 0..2 * Packed::WIDTH).as_slice(),
            &data_clone[0..2]
        );
        assert_eq!(
            PackedMemory::slice_mut(&mut memory, ..2 * Packed::WIDTH).as_slice(),
            &data_clone[..2]
        );
        assert_eq!(
            PackedMemory::slice_mut(&mut memory, Packed::WIDTH..).as_slice(),
            &data_clone[1..]
        );
        assert_eq!(PackedMemory::slice_mut(&mut memory, ..).as_slice(), &data_clone[..]);
    }

    /// Misaligned end bound on a borrowed mutable slice.
    #[test]
    #[should_panic]
    fn test_slice_mut_on_mem_slice_panic_1() {
        let mut data = make_random_vec(3);
        let mut memory = PackedMemorySliceMut::new_slice(&mut data);

        PackedMemory::slice_mut(&mut memory, 0..1);
    }

    /// Misaligned end bound with an unbounded start.
    #[test]
    #[should_panic]
    fn test_slice_mut_on_mem_slice_panic_2() {
        let mut data = make_random_vec(3);
        let mut memory = PackedMemorySliceMut::new_slice(&mut data);

        PackedMemory::slice_mut(&mut memory, ..1);
    }

    /// Misaligned start bound with an aligned end.
    #[test]
    #[should_panic]
    fn test_slice_mut_on_mem_slice_panic_3() {
        let mut data = make_random_vec(3);
        let mut memory = PackedMemorySliceMut::new_slice(&mut data);

        PackedMemory::slice_mut(&mut memory, 1..Packed::WIDTH);
    }

    /// Misaligned start bound with an unbounded end.
    #[test]
    #[should_panic]
    fn test_slice_mut_on_mem_slice_panic_4() {
        let mut data = make_random_vec(3);
        let mut memory = PackedMemorySliceMut::new_slice(&mut data);

        PackedMemory::slice_mut(&mut memory, 1..);
    }

    /// Sub-slicing an owned chunk (shorter than one packed element) is unsupported.
    #[test]
    #[should_panic]
    fn test_slice_mut_on_mem_slice_panic_5() {
        let mut data = make_random_vec(3);
        let mut memory = PackedMemorySliceMut::new_owned(&mut data, 0, Packed::WIDTH - 1);

        PackedMemory::slice_mut(&mut memory, 1..);
    }

    /// Splitting at an aligned scalar index yields the matching packed halves.
    #[test]
    fn test_split_at_mut() {
        let mut data = make_random_vec(3);
        let data_clone = data.clone();
        let memory = PackedMemorySliceMut::new_slice(&mut data);

        let (left, right) = PackedMemory::split_at_mut(memory, 2 * Packed::WIDTH);
        assert_eq!(left.as_slice(), &data_clone[0..2]);
        assert_eq!(right.as_slice(), &data_clone[2..]);
    }

    /// Splitting at a misaligned scalar index panics.
    #[test]
    #[should_panic]
    fn test_split_at_mut_panic_1() {
        let mut data = make_random_vec(3);
        let memory = PackedMemorySliceMut::new_slice(&mut data);

        PackedMemory::split_at_mut(memory, 1);
    }

    /// Splitting an owned chunk panics.
    #[test]
    #[should_panic]
    fn test_split_at_mut_panic_2() {
        let mut data = make_random_vec(3);
        let memory = PackedMemorySliceMut::new_owned(&mut data, 0, Packed::WIDTH - 1);

        PackedMemory::split_at_mut(memory, 1);
    }

    /// `split_half` on multi-element slices, on a single packed element (yielding owned
    /// halves), and on an owned chunk.
    #[test]
    fn test_split_half() {
        let data = make_random_vec(2);
        let data_clone = data.clone();
        let memory = PackedMemorySlice::new_slice(&data);

        // Two packed elements split into one packed element each.
        let (left, right) = PackedMemory::split_half(memory);
        assert_eq!(left.as_slice(), &data_clone[0..1]);
        assert_eq!(right.as_slice(), &data_clone[1..]);

        // One packed element splits into owned lower / upper halves of `WIDTH / 2` scalars.
        let memory = PackedMemorySlice::new_slice(&data[0..1]);
        let (left, right) = PackedMemory::split_half(memory);
        assert_eq!(
            left.iter_scalars().collect_vec(),
            PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH / 2)
                .iter_scalars()
                .collect_vec()
        );
        assert_eq!(
            right.iter_scalars().collect_vec(),
            PackedMemorySlice::new_owned(&data, Packed::WIDTH / 2, Packed::WIDTH / 2)
                .iter_scalars()
                .collect_vec()
        );

        // An owned chunk of `WIDTH / 2` scalars splits into two owned `WIDTH / 4` chunks.
        let memory = PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH / 2);
        let (left, right) = PackedMemory::split_half(memory);
        assert_eq!(
            left.iter_scalars().collect_vec(),
            PackedMemorySlice::new_owned(&data, 0, Packed::WIDTH / 4)
                .iter_scalars()
                .collect_vec()
        );
        assert_eq!(
            right.iter_scalars().collect_vec(),
            PackedMemorySlice::new_owned(&data, Packed::WIDTH / 4, Packed::WIDTH / 4)
                .iter_scalars()
                .collect_vec()
        );
    }
}