1use std::{
8	fmt::Debug,
9	iter::{self, Product, Sum},
10	ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign},
11};
12
13use binius_utils::{
14	iter::IterExtensions,
15	random_access_sequence::{RandomAccessSequence, RandomAccessSequenceMut},
16};
17use bytemuck::Zeroable;
18
19use super::{
20	Error, Random,
21	arithmetic_traits::{Broadcast, MulAlpha, Square},
22	binary_field_arithmetic::TowerFieldArithmetic,
23};
24use crate::{
25	BinaryField, Field, PackedExtension, arithmetic_traits::InvertOrZero,
26	is_packed_field_indexable, underlier::WithUnderlier,
27};
28
29pub trait PackedField:
35	Default
36	+ Debug
37	+ Clone
38	+ Copy
39	+ Eq
40	+ Sized
41	+ Add<Output = Self>
42	+ Sub<Output = Self>
43	+ Mul<Output = Self>
44	+ AddAssign
45	+ SubAssign
46	+ MulAssign
47	+ Add<Self::Scalar, Output = Self>
48	+ Sub<Self::Scalar, Output = Self>
49	+ Mul<Self::Scalar, Output = Self>
50	+ AddAssign<Self::Scalar>
51	+ SubAssign<Self::Scalar>
52	+ MulAssign<Self::Scalar>
53	+ Sum
55	+ Product
56	+ Send
57	+ Sync
58	+ Zeroable
59	+ Random
60	+ 'static
61{
62	type Scalar: Field;
63
64	const LOG_WIDTH: usize;
66
67	const WIDTH: usize = 1 << Self::LOG_WIDTH;
71
72	unsafe fn get_unchecked(&self, i: usize) -> Self::Scalar;
76
77	unsafe fn set_unchecked(&mut self, i: usize, scalar: Self::Scalar);
81
82	#[inline]
84	fn get_checked(&self, i: usize) -> Result<Self::Scalar, Error> {
85		(i < Self::WIDTH)
86			.then_some(unsafe { self.get_unchecked(i) })
87			.ok_or(Error::IndexOutOfRange {
88				index: i,
89				max: Self::WIDTH,
90			})
91	}
92
93	#[inline]
95	fn set_checked(&mut self, i: usize, scalar: Self::Scalar) -> Result<(), Error> {
96		(i < Self::WIDTH)
97			.then(|| unsafe { self.set_unchecked(i, scalar) })
98			.ok_or(Error::IndexOutOfRange {
99				index: i,
100				max: Self::WIDTH,
101			})
102	}
103
104	#[inline]
106	fn get(&self, i: usize) -> Self::Scalar {
107		self.get_checked(i).expect("index must be less than width")
108	}
109
110	#[inline]
112	fn set(&mut self, i: usize, scalar: Self::Scalar) {
113		self.set_checked(i, scalar).expect("index must be less than width")
114	}
115
116	#[inline]
117	fn into_iter(self) -> impl Iterator<Item=Self::Scalar> + Send + Clone {
118		(0..Self::WIDTH).map_skippable(move |i|
119			unsafe { self.get_unchecked(i) })
121	}
122
123	#[inline]
124	fn iter(&self) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
125		(0..Self::WIDTH).map_skippable(move |i|
126			unsafe { self.get_unchecked(i) })
128	}
129
130	#[inline]
131	fn iter_slice(slice: &[Self]) -> impl Iterator<Item=Self::Scalar> + Send + Clone + '_ {
132		slice.iter().flat_map(Self::iter)
133	}
134
135	#[inline]
136	fn zero() -> Self {
137		Self::broadcast(Self::Scalar::ZERO)
138	}
139
140	#[inline]
141	fn one() -> Self {
142		Self::broadcast(Self::Scalar::ONE)
143	}
144
145	#[inline(always)]
147	fn set_single(scalar: Self::Scalar) -> Self {
148		let mut result = Self::default();
149		result.set(0, scalar);
150
151		result
152	}
153
154	fn broadcast(scalar: Self::Scalar) -> Self;
155
156	fn from_fn(f: impl FnMut(usize) -> Self::Scalar) -> Self;
158
159	fn try_from_fn<E>(
161            mut f: impl FnMut(usize) -> Result<Self::Scalar, E>,
162        ) -> Result<Self, E> {
163            let mut result = Self::default();
164            for i in 0..Self::WIDTH {
165                let scalar = f(i)?;
166                unsafe {
167                    result.set_unchecked(i, scalar);
168                };
169            }
170            Ok(result)
171        }
172
173	#[inline]
179	fn from_scalars(values: impl IntoIterator<Item=Self::Scalar>) -> Self {
180		let mut result = Self::default();
181		for (i, val) in values.into_iter().take(Self::WIDTH).enumerate() {
182			result.set(i, val);
183		}
184		result
185	}
186
187	fn square(self) -> Self;
189
190	fn pow(self, exp: u64) -> Self {
192		let mut res = Self::one();
193		for i in (0..64).rev() {
194			res = res.square();
195			if ((exp >> i) & 1) == 1 {
196				res.mul_assign(self)
197			}
198		}
199		res
200	}
201
202	fn invert_or_zero(self) -> Self;
204
205	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
221
222	fn unzip(self, other: Self, log_block_len: usize) -> (Self, Self);
235
236	#[inline]
268	fn spread(self, log_block_len: usize, block_idx: usize) -> Self {
269		assert!(log_block_len <= Self::LOG_WIDTH);
270		assert!(block_idx < 1 << (Self::LOG_WIDTH - log_block_len));
271
272		unsafe {
274			self.spread_unchecked(log_block_len, block_idx)
275		}
276	}
277
278	#[inline]
283	unsafe fn spread_unchecked(self, log_block_len: usize, block_idx: usize) -> Self {
284		let block_len = 1 << log_block_len;
285		let repeat = 1 << (Self::LOG_WIDTH - log_block_len);
286
287		Self::from_scalars(
288			self.iter().skip(block_idx * block_len).take(block_len).flat_map(|elem| iter::repeat_n(elem, repeat))
289		)
290	}
291}
292
293#[inline]
298pub fn iter_packed_slice_with_offset<P: PackedField>(
299	packed: &[P],
300	offset: usize,
301) -> impl Iterator<Item = P::Scalar> + '_ + Send {
302	let (packed, offset): (&[P], usize) = if offset < packed.len() * P::WIDTH {
303		(&packed[(offset / P::WIDTH)..], offset % P::WIDTH)
304	} else {
305		(&[], 0)
306	};
307
308	P::iter_slice(packed).skip(offset)
309}
310
311#[inline(always)]
312pub fn get_packed_slice<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
313	assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
314
315	unsafe { get_packed_slice_unchecked(packed, i) }
316}
317
318#[inline(always)]
322pub unsafe fn get_packed_slice_unchecked<P: PackedField>(packed: &[P], i: usize) -> P::Scalar {
323	if is_packed_field_indexable::<P>() {
324		unsafe { *(packed.as_ptr() as *const P::Scalar).add(i) }
328	} else {
329		unsafe {
334			packed
335				.get_unchecked(i >> P::LOG_WIDTH)
336				.get_unchecked(i % P::WIDTH)
337		}
338	}
339}
340
341#[inline]
342pub fn get_packed_slice_checked<P: PackedField>(
343	packed: &[P],
344	i: usize,
345) -> Result<P::Scalar, Error> {
346	if i >> P::LOG_WIDTH < packed.len() {
347		Ok(unsafe { get_packed_slice_unchecked(packed, i) })
349	} else {
350		Err(Error::IndexOutOfRange {
351			index: i,
352			max: len_packed_slice(packed),
353		})
354	}
355}
356
357#[inline]
361pub unsafe fn set_packed_slice_unchecked<P: PackedField>(
362	packed: &mut [P],
363	i: usize,
364	scalar: P::Scalar,
365) {
366	if is_packed_field_indexable::<P>() {
367		unsafe {
371			*(packed.as_mut_ptr() as *mut P::Scalar).add(i) = scalar;
372		}
373	} else {
374		unsafe {
378			packed
379				.get_unchecked_mut(i >> P::LOG_WIDTH)
380				.set_unchecked(i % P::WIDTH, scalar)
381		}
382	}
383}
384
385#[inline]
386pub fn set_packed_slice<P: PackedField>(packed: &mut [P], i: usize, scalar: P::Scalar) {
387	assert!(i >> P::LOG_WIDTH < packed.len(), "index out of bounds");
388
389	unsafe { set_packed_slice_unchecked(packed, i, scalar) }
390}
391
392#[inline]
393pub fn set_packed_slice_checked<P: PackedField>(
394	packed: &mut [P],
395	i: usize,
396	scalar: P::Scalar,
397) -> Result<(), Error> {
398	if i >> P::LOG_WIDTH < packed.len() {
399		unsafe { set_packed_slice_unchecked(packed, i, scalar) };
401		Ok(())
402	} else {
403		Err(Error::IndexOutOfRange {
404			index: i,
405			max: len_packed_slice(packed),
406		})
407	}
408}
409
410#[inline(always)]
411pub const fn len_packed_slice<P: PackedField>(packed: &[P]) -> usize {
412	packed.len() << P::LOG_WIDTH
413}
414
415#[inline]
419pub fn packed_from_fn_with_offset<P: PackedField>(
420	offset: usize,
421	mut f: impl FnMut(usize) -> P::Scalar,
422) -> P {
423	P::from_fn(|i| f(i + offset * P::WIDTH))
424}
425
426pub fn mul_by_subfield_scalar<P: PackedExtension<FS>, FS: Field>(val: P, multiplier: FS) -> P {
428	use crate::underlier::UnderlierType;
429
430	let subfield_bits = FS::Underlier::BITS;
433	let extension_bits = <<P as PackedField>::Scalar as WithUnderlier>::Underlier::BITS;
434
435	if (subfield_bits == 1 && extension_bits > 8) || extension_bits >= 32 {
436		P::from_fn(|i| unsafe { val.get_unchecked(i) } * multiplier)
437	} else {
438		P::cast_ext(P::cast_base(val) * P::PackedSubfield::broadcast(multiplier))
439	}
440}
441
442pub fn pack_slice<P: PackedField>(scalars: &[P::Scalar]) -> Vec<P> {
444	scalars
445		.chunks(P::WIDTH)
446		.map(|chunk| P::from_scalars(chunk.iter().copied()))
447		.collect()
448}
449
/// A read-only view of a packed slice as a sequence of scalars.
///
/// The view may expose fewer scalars than the underlying slice can hold (see
/// `new_with_len`).
#[derive(Clone)]
pub struct PackedSlice<'a, P: PackedField> {
	// Underlying packed storage.
	slice: &'a [P],
	// Number of scalars exposed by the view; never exceeds the scalar capacity of `slice`
	// when built through the provided constructors.
	len: usize,
}
456
457impl<'a, P: PackedField> PackedSlice<'a, P> {
458	#[inline(always)]
459	pub fn new(slice: &'a [P]) -> Self {
460		Self {
461			slice,
462			len: len_packed_slice(slice),
463		}
464	}
465
466	#[inline(always)]
467	pub fn new_with_len(slice: &'a [P], len: usize) -> Self {
468		assert!(len <= len_packed_slice(slice));
469
470		Self { slice, len }
471	}
472}
473
/// Scalar-wise random access over a [`PackedSlice`].
impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSlice<'_, P> {
	/// Number of scalars exposed by the view.
	#[inline(always)]
	fn len(&self) -> usize {
		self.len
	}

	/// Reads the scalar at `index` by delegating to `get_packed_slice_unchecked`; no bounds
	/// check is performed here.
	#[inline(always)]
	unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
		// SAFETY: the caller is responsible for passing an in-range `index`.
		unsafe { get_packed_slice_unchecked(self.slice, index) }
	}
}
485
/// A mutable view of a packed slice as a sequence of scalars.
///
/// The view may expose fewer scalars than the underlying slice can hold (see
/// `new_with_len`).
pub struct PackedSliceMut<'a, P: PackedField> {
	// Underlying packed storage, borrowed mutably for the lifetime of the view.
	slice: &'a mut [P],
	// Number of scalars exposed by the view.
	len: usize,
}
491
492impl<'a, P: PackedField> PackedSliceMut<'a, P> {
493	#[inline(always)]
494	pub fn new(slice: &'a mut [P]) -> Self {
495		let len = len_packed_slice(slice);
496		Self { slice, len }
497	}
498
499	#[inline(always)]
500	pub fn new_with_len(slice: &'a mut [P], len: usize) -> Self {
501		assert!(len <= len_packed_slice(slice));
502
503		Self { slice, len }
504	}
505}
506
/// Scalar-wise read access over a [`PackedSliceMut`].
impl<P: PackedField> RandomAccessSequence<P::Scalar> for PackedSliceMut<'_, P> {
	/// Number of scalars exposed by the view.
	#[inline(always)]
	fn len(&self) -> usize {
		self.len
	}

	/// Reads the scalar at `index` by delegating to `get_packed_slice_unchecked`; no bounds
	/// check is performed here.
	#[inline(always)]
	unsafe fn get_unchecked(&self, index: usize) -> P::Scalar {
		// SAFETY: the caller is responsible for passing an in-range `index`.
		unsafe { get_packed_slice_unchecked(self.slice, index) }
	}
}
/// Scalar-wise write access over a [`PackedSliceMut`].
impl<P: PackedField> RandomAccessSequenceMut<P::Scalar> for PackedSliceMut<'_, P> {
	/// Writes `value` at `index` by delegating to `set_packed_slice_unchecked`; no bounds
	/// check is performed here.
	#[inline(always)]
	unsafe fn set_unchecked(&mut self, index: usize, value: P::Scalar) {
		// SAFETY: the caller is responsible for passing an in-range `index`.
		unsafe { set_packed_slice_unchecked(self.slice, index, value) }
	}
}
524
/// A single field element is its own broadcast: there is only one position to fill.
impl<F: Field> Broadcast<F> for F {
	/// Returns `scalar` unchanged.
	#[inline]
	fn broadcast(scalar: F) -> Self {
		scalar
	}
}
531
/// Blanket `MulAlpha` implementation for every type providing `TowerFieldArithmetic`.
impl<T: TowerFieldArithmetic> MulAlpha for T {
	/// Forwards to `TowerFieldArithmetic::multiply_alpha`.
	#[inline]
	fn mul_alpha(self) -> Self {
		<Self as TowerFieldArithmetic>::multiply_alpha(self)
	}
}
538
539impl<F: Field> PackedField for F {
540	type Scalar = F;
541
542	const LOG_WIDTH: usize = 0;
543
544	#[inline]
545	unsafe fn get_unchecked(&self, _i: usize) -> Self::Scalar {
546		*self
547	}
548
549	#[inline]
550	unsafe fn set_unchecked(&mut self, _i: usize, scalar: Self::Scalar) {
551		*self = scalar;
552	}
553
554	#[inline]
555	fn iter(&self) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
556		iter::once(*self)
557	}
558
559	#[inline]
560	fn into_iter(self) -> impl Iterator<Item = Self::Scalar> + Send + Clone {
561		iter::once(self)
562	}
563
564	#[inline]
565	fn iter_slice(slice: &[Self]) -> impl Iterator<Item = Self::Scalar> + Send + Clone + '_ {
566		slice.iter().copied()
567	}
568
569	fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
570		panic!("cannot interleave when WIDTH = 1");
571	}
572
573	fn unzip(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
574		panic!("cannot transpose when WIDTH = 1");
575	}
576
577	#[inline]
578	fn broadcast(scalar: Self::Scalar) -> Self {
579		scalar
580	}
581
582	#[inline]
583	fn zero() -> Self {
584		Self::ZERO
585	}
586
587	#[inline]
588	fn one() -> Self {
589		Self::ONE
590	}
591
592	#[inline]
593	fn square(self) -> Self {
594		<Self as Square>::square(self)
595	}
596
597	#[inline]
598	fn invert_or_zero(self) -> Self {
599		<Self as InvertOrZero>::invert_or_zero(self)
600	}
601
602	#[inline]
603	fn from_fn(mut f: impl FnMut(usize) -> Self::Scalar) -> Self {
604		f(0)
605	}
606
607	#[inline]
608	unsafe fn spread_unchecked(self, _log_block_len: usize, _block_idx: usize) -> Self {
609		self
610	}
611}
612
/// Marker trait for packed fields whose scalar type is a [`BinaryField`].
pub trait PackedBinaryField: PackedField<Scalar: BinaryField> {}

// Blanket implementation: every packed field over a binary field qualifies automatically.
impl<PT> PackedBinaryField for PT where PT: PackedField<Scalar: BinaryField> {}
617
#[cfg(test)]
mod tests {
	use itertools::Itertools;
	use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::{
		AESTowerField8b, BinaryField1b, BinaryField128bGhash, PackedBinaryGhash1x128b,
		PackedBinaryGhash2x128b, PackedBinaryGhash4x128b, PackedField,
		arch::{
			packed_1::*, packed_2::*, packed_4::*, packed_8::*, packed_16::*, packed_32::*,
			packed_64::*, packed_128::*, packed_256::*, packed_512::*, packed_aes_8::*,
			packed_aes_16::*, packed_aes_32::*, packed_aes_64::*, packed_aes_128::*,
			packed_aes_256::*, packed_aes_512::*,
		},
	};

	/// A test that can be instantiated for any packed field type.
	trait PackedFieldTest {
		fn run<P: PackedField>(&self);
	}

	/// Runs `test` against the scalar and packed types listed below.
	fn run_for_all_packed_fields(test: &impl PackedFieldTest) {
		// 1-bit binary field: the scalar itself, then packed widths.
		test.run::<BinaryField1b>();
		test.run::<PackedBinaryField1x1b>();
		test.run::<PackedBinaryField2x1b>();
		test.run::<PackedBinaryField4x1b>();
		test.run::<PackedBinaryField8x1b>();
		test.run::<PackedBinaryField16x1b>();
		test.run::<PackedBinaryField32x1b>();
		test.run::<PackedBinaryField64x1b>();
		test.run::<PackedBinaryField128x1b>();
		test.run::<PackedBinaryField256x1b>();
		test.run::<PackedBinaryField512x1b>();

		// 8-bit AES tower field: the scalar itself, then packed widths.
		test.run::<AESTowerField8b>();
		test.run::<PackedAESBinaryField1x8b>();
		test.run::<PackedAESBinaryField2x8b>();
		test.run::<PackedAESBinaryField4x8b>();
		test.run::<PackedAESBinaryField8x8b>();
		test.run::<PackedAESBinaryField16x8b>();
		test.run::<PackedAESBinaryField32x8b>();
		test.run::<PackedAESBinaryField64x8b>();

		// 128-bit GHASH field: the scalar itself, then packed widths.
		test.run::<BinaryField128bGhash>();
		test.run::<PackedBinaryGhash1x128b>();
		test.run::<PackedBinaryGhash2x128b>();
		test.run::<PackedBinaryGhash4x128b>();
	}

	/// By-value iteration (`into_iter`) must yield exactly `WIDTH` scalars matching `get`.
	fn check_value_iteration<P: PackedField>(mut rng: impl RngCore) {
		let packed = P::random(&mut rng);
		let mut iter = packed.into_iter();
		for i in 0..P::WIDTH {
			assert_eq!(packed.get(i), iter.next().unwrap());
		}
		assert!(iter.next().is_none());
	}

	/// By-reference iteration (`iter`) must yield exactly `WIDTH` scalars matching `get`.
	fn check_ref_iteration<P: PackedField>(mut rng: impl RngCore) {
		let packed = P::random(&mut rng);
		let mut iter = packed.iter();
		for i in 0..P::WIDTH {
			assert_eq!(packed.get(i), iter.next().unwrap());
		}
		assert!(iter.next().is_none());
	}

	/// `iter_packed_slice_with_offset` must agree with indexed access for several slice
	/// lengths and offsets, including the empty slice and an offset at the very end.
	fn check_slice_iteration<P: PackedField>(mut rng: impl RngCore) {
		for len in [0, 1, 5] {
			let packed = std::iter::repeat_with(|| P::random(&mut rng))
				.take(len)
				.collect::<Vec<_>>();

			let elements_count = len * P::WIDTH;
			for offset in [
				0,
				1,
				// `max(1)` keeps the range non-empty when the slice is empty.
				rng.random_range(0..elements_count.max(1)),
				elements_count.saturating_sub(1),
				elements_count,
			] {
				let actual = iter_packed_slice_with_offset(&packed, offset).collect::<Vec<_>>();
				let expected = (offset..elements_count)
					.map(|i| get_packed_slice(&packed, i))
					.collect::<Vec<_>>();

				assert_eq!(actual, expected);
			}
		}
	}

	struct PackedFieldIterationTest;

	impl PackedFieldTest for PackedFieldIterationTest {
		fn run<P: PackedField>(&self) {
			// Fixed seed keeps the test deterministic.
			let mut rng = StdRng::seed_from_u64(0);

			check_value_iteration::<P>(&mut rng);
			check_ref_iteration::<P>(&mut rng);
			check_slice_iteration::<P>(&mut rng);
		}
	}

	#[test]
	fn test_iteration() {
		run_for_all_packed_fields(&PackedFieldIterationTest);
	}

	/// Asserts that `collection` has the same length and contents as `expected`, through both
	/// the checked and unchecked getters.
	fn check_collection<F: Field>(collection: &impl RandomAccessSequence<F>, expected: &[F]) {
		assert_eq!(collection.len(), expected.len());

		for (i, v) in expected.iter().enumerate() {
			assert_eq!(&collection.get(i), v);
			assert_eq!(&unsafe { collection.get_unchecked(i) }, v);
		}
	}

	/// Writes a fresh random value at every index and reads it back through both getters.
	fn check_collection_get_set<F: Field>(
		collection: &mut impl RandomAccessSequenceMut<F>,
		random: &mut impl FnMut() -> F,
	) {
		for i in 0..collection.len() {
			let value = random();
			collection.set(i, value);
			assert_eq!(collection.get(i), value);
			assert_eq!(unsafe { collection.get_unchecked(i) }, value);
		}
	}

	#[test]
	fn check_packed_slice() {
		// Empty slice, via both constructors.
		let slice: &[PackedAESBinaryField16x8b] = &[];
		let packed_slice = PackedSlice::new(slice);
		check_collection(&packed_slice, &[]);
		let packed_slice = PackedSlice::new_with_len(slice, 0);
		check_collection(&packed_slice, &[]);

		let mut rng = StdRng::seed_from_u64(0);
		let slice: &[PackedAESBinaryField16x8b] = &[
			PackedAESBinaryField16x8b::random(&mut rng),
			PackedAESBinaryField16x8b::random(&mut rng),
		];
		// Full view.
		let packed_slice = PackedSlice::new(slice);
		check_collection(&packed_slice, &PackedField::iter_slice(slice).collect_vec());

		// Truncated view exposing only the first three scalars.
		let packed_slice = PackedSlice::new_with_len(slice, 3);
		check_collection(&packed_slice, &PackedField::iter_slice(slice).take(3).collect_vec());
	}

	#[test]
	fn check_packed_slice_mut() {
		let mut rng = StdRng::seed_from_u64(0);
		let mut random = || AESTowerField8b::random(&mut rng);

		// Empty slice, via both constructors.
		let slice: &mut [PackedAESBinaryField16x8b] = &mut [];
		let packed_slice = PackedSliceMut::new(slice);
		check_collection(&packed_slice, &[]);
		let packed_slice = PackedSliceMut::new_with_len(slice, 0);
		check_collection(&packed_slice, &[]);

		let mut rng = StdRng::seed_from_u64(0);
		let slice: &mut [PackedAESBinaryField16x8b] = &mut [
			PackedAESBinaryField16x8b::random(&mut rng),
			PackedAESBinaryField16x8b::random(&mut rng),
		];
		// Full view: contents match, then every position is overwritten and read back.
		let values = PackedField::iter_slice(slice).collect_vec();
		let mut packed_slice = PackedSliceMut::new(slice);
		check_collection(&packed_slice, &values);
		check_collection_get_set(&mut packed_slice, &mut random);

		// Truncated view exposing only the first three scalars.
		let values = PackedField::iter_slice(slice).collect_vec();
		let mut packed_slice = PackedSliceMut::new_with_len(slice, 3);
		check_collection(&packed_slice, &values[..3]);
		check_collection_get_set(&mut packed_slice, &mut random);
	}
}