binius_field/byte_iteration.rs

// Copyright 2023-2025 Irreducible Inc.

use binius_utils::random_access_sequence::RandomAccessSequence;
use bytemuck::{Pod, zeroed_vec};

use crate::{
	AESTowerField8b, BinaryField128bGhash, PackedBinaryGhash1x128b, PackedBinaryGhash2x128b,
	PackedBinaryGhash4x128b, PackedField,
	arch::{
		packed_8::*, packed_16::*, packed_32::*, packed_64::*, packed_128::*, packed_256::*,
		packed_512::*, packed_aes_8::*, packed_aes_16::*, packed_aes_32::*, packed_aes_64::*,
		packed_aes_128::*, packed_aes_256::*, packed_aes_512::*,
	},
};

/// A marker trait indicating that a slice of packed values can be iterated over as a sequence
/// of bytes. The iteration order of the `BinaryField1b` subfield elements must match the bit
/// order within the iterated bytes.
///
/// # Safety
/// The implementor must ensure that casting the slice of packed values to a slice of bytes
/// is safe and preserves the order of the 1-bit elements.
#[allow(unused)]
unsafe trait SequentialBytes: Pod {}

unsafe impl SequentialBytes for PackedBinaryField8x1b {}
unsafe impl SequentialBytes for PackedBinaryField16x1b {}
unsafe impl SequentialBytes for PackedBinaryField32x1b {}
unsafe impl SequentialBytes for PackedBinaryField64x1b {}
unsafe impl SequentialBytes for PackedBinaryField128x1b {}
unsafe impl SequentialBytes for PackedBinaryField256x1b {}
unsafe impl SequentialBytes for PackedBinaryField512x1b {}

unsafe impl SequentialBytes for AESTowerField8b {}

unsafe impl SequentialBytes for PackedAESBinaryField1x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField2x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField4x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField8x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField16x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField32x8b {}
unsafe impl SequentialBytes for PackedAESBinaryField64x8b {}

unsafe impl SequentialBytes for BinaryField128bGhash {}

unsafe impl SequentialBytes for PackedBinaryGhash1x128b {}
unsafe impl SequentialBytes for PackedBinaryGhash2x128b {}
unsafe impl SequentialBytes for PackedBinaryGhash4x128b {}

/// Returns true if `T` implements the `SequentialBytes` trait.
/// This uses a hack that exploits the fact that array cloning is specialized to a bitwise copy
/// for `Copy` types. Unfortunately, there is no more proper way to perform this check at runtime
/// in Rust.
#[inline(always)]
#[allow(clippy::redundant_clone)] // this is intentional in this method
pub fn is_sequential_bytes<T>() -> bool {
	struct X<U>(bool, std::marker::PhantomData<U>);

	impl<U> Clone for X<U> {
		fn clone(&self) -> Self {
			Self(false, std::marker::PhantomData)
		}
	}

	impl<U: SequentialBytes> Copy for X<U> {}

	let value = [X::<T>(true, std::marker::PhantomData)];
	let cloned = value.clone();

	cloned[0].0
}

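// Illustration of the trick above (a sketch, not part of the original file; the function name is
// hypothetical): when `U: SequentialBytes`, `X<U>` is `Copy`, the standard library's array clone
// specializes to a bitwise copy, and the `true` flag survives; otherwise the manual `Clone` impl
// runs and replaces it with `false`.
#[allow(dead_code)]
fn is_sequential_bytes_demo() {
	assert!(is_sequential_bytes::<AESTowerField8b>()); // has a `SequentialBytes` impl above
	assert!(!is_sequential_bytes::<u64>()); // `u64` has no `SequentialBytes` impl
}
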
/// Returns whether we can iterate over bytes, each representing 8 1-bit values.
#[inline(always)]
pub fn can_iterate_bytes<P: PackedField>() -> bool {
	// Packed fields with sequential byte order
	is_sequential_bytes::<P>()
}

/// Callback for byte iteration.
/// We can't return different iterator types from `iterate_bytes`, and the `Fn` traits don't
/// support associated types, which is why we use a callback trait with a generic method.
pub trait ByteIteratorCallback {
	fn call(&mut self, iter: impl Iterator<Item = u8>);
}

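// Example implementation (a sketch, not part of the original file; the name `CollectBytes` is
// hypothetical): a callback that gathers every iterated byte into a buffer.
#[allow(dead_code)]
struct CollectBytes {
	buffer: Vec<u8>,
}

impl ByteIteratorCallback for CollectBytes {
	fn call(&mut self, iter: impl Iterator<Item = u8>) {
		self.buffer.extend(iter);
	}
}
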
/// Iterates over the bytes of a slice of packed values.
/// The method panics if the packed field doesn't support byte iteration, so use
/// `can_iterate_bytes` to check first.
#[inline(always)]
pub fn iterate_bytes<P: PackedField>(data: &[P], callback: &mut impl ByteIteratorCallback) {
	if is_sequential_bytes::<P>() {
		// Safety: `P` implements the `SequentialBytes` trait, so the following cast is safe
		// and preserves the order.
		let bytes = unsafe {
			std::slice::from_raw_parts(data.as_ptr() as *const u8, std::mem::size_of_val(data))
		};
		callback.call(bytes.iter().copied());
	} else {
		unreachable!("packed field doesn't support byte iteration")
	}
}

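// Usage sketch (illustrative, not part of the original file; `count_ones_guarded` is a
// hypothetical helper): guard the call with `can_iterate_bytes` so that unsupported packed
// fields return `None` instead of panicking.
#[allow(dead_code)]
fn count_ones_guarded<P: PackedField>(data: &[P]) -> Option<u32> {
	if !can_iterate_bytes::<P>() {
		return None;
	}

	struct PopCount(u32);

	impl ByteIteratorCallback for PopCount {
		fn call(&mut self, iter: impl Iterator<Item = u8>) {
			self.0 += iter.map(|byte| byte.count_ones()).sum::<u32>();
		}
	}

	let mut counter = PopCount(0);
	iterate_bytes(data, &mut counter);
	Some(counter.0)
}
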
/// Creates a lookup table of partial sums of 8 consecutive elements, with coefficients
/// corresponding to the bits of a byte. The lookup table has the following structure:
/// [
///     partial_sum_chunk_0_7_byte_0, partial_sum_chunk_0_7_byte_1, ..., partial_sum_chunk_0_7_byte_255,
///     partial_sum_chunk_8_15_byte_0, partial_sum_chunk_8_15_byte_1, ..., partial_sum_chunk_8_15_byte_255,
///     ...
/// ]
pub fn create_partial_sums_lookup_tables<P: PackedField>(
	values: impl RandomAccessSequence<P>,
) -> Vec<P> {
	let len = values.len();
	assert!(len.is_multiple_of(8));

	// Each chunk of 8 elements yields 256 partial sums, i.e. `len / 8 * 256 = len * 32` entries.
	let mut result = zeroed_vec(len * 32);

	for (chunk_idx, chunk_start) in (0..len).step_by(8).enumerate() {
		let sums = &mut result[chunk_idx * 256..(chunk_idx + 1) * 256];

		// Add the `j`-th element of the chunk to every partial sum whose byte index has bit `j`
		// set.
		for j in 0..8 {
			let value = values.get(chunk_start + j);
			let mask = 1 << j;
			for i in (mask..256).step_by(mask * 2) {
				for k in 0..mask {
					sums[i + k] += value;
				}
			}
		}
	}

	result
}

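// Usage sketch (illustrative, not part of the original file; `lookup_partial_sum` is a
// hypothetical helper): a byte whose bits select up to 8 consecutive elements is resolved with a
// single table read instead of 8 conditional additions. `chunk_idx` selects which group of 8
// elements the byte refers to.
#[allow(dead_code)]
fn lookup_partial_sum<P: PackedField>(lookup_table: &[P], chunk_idx: usize, byte: u8) -> P {
	lookup_table[chunk_idx * 256 + byte as usize]
}
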
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{PackedBinaryField1x1b, PackedBinaryField2x1b, PackedBinaryField4x1b};

	#[test]
	fn test_sequential_bits() {
		assert!(is_sequential_bytes::<PackedBinaryField8x1b>());
		assert!(is_sequential_bytes::<PackedBinaryField16x1b>());
		assert!(is_sequential_bytes::<PackedBinaryField32x1b>());
		assert!(is_sequential_bytes::<PackedBinaryField64x1b>());
		assert!(is_sequential_bytes::<PackedBinaryField128x1b>());
		assert!(is_sequential_bytes::<PackedBinaryField256x1b>());
		assert!(is_sequential_bytes::<PackedBinaryField512x1b>());

		assert!(is_sequential_bytes::<AESTowerField8b>());
		assert!(is_sequential_bytes::<PackedAESBinaryField1x8b>());
		assert!(is_sequential_bytes::<PackedAESBinaryField2x8b>());
		assert!(is_sequential_bytes::<PackedAESBinaryField4x8b>());
		assert!(is_sequential_bytes::<PackedAESBinaryField8x8b>());
		assert!(is_sequential_bytes::<PackedAESBinaryField16x8b>());
		assert!(is_sequential_bytes::<PackedAESBinaryField32x8b>());
		assert!(is_sequential_bytes::<PackedAESBinaryField64x8b>());

		assert!(is_sequential_bytes::<BinaryField128bGhash>());
		assert!(is_sequential_bytes::<PackedBinaryGhash1x128b>());
		assert!(is_sequential_bytes::<PackedBinaryGhash2x128b>());
		assert!(is_sequential_bytes::<PackedBinaryGhash4x128b>());

		assert!(!is_sequential_bytes::<PackedBinaryField1x1b>());
		assert!(!is_sequential_bytes::<PackedBinaryField2x1b>());
		assert!(!is_sequential_bytes::<PackedBinaryField4x1b>());
	}

	#[test]
	fn test_partial_sums_basic() {
		let v1 = AESTowerField8b::from(1);
		let v2 = AESTowerField8b::from(2);
		let v3 = AESTowerField8b::from(3);
		let v4 = AESTowerField8b::from(4);
		let v5 = AESTowerField8b::from(5);
		let v6 = AESTowerField8b::from(6);
		let v7 = AESTowerField8b::from(7);
		let v8 = AESTowerField8b::from(8);

		let values = vec![v1, v2, v3, v4, v5, v6, v7, v8];

		let lookup_table = create_partial_sums_lookup_tables(values.as_slice());

		assert_eq!(lookup_table.len(), 256);

		// Check specific precomputed sums
		assert_eq!(lookup_table[0b0000_0000], AESTowerField8b::from(0));
		assert_eq!(lookup_table[0b0000_0001], v1);
		assert_eq!(lookup_table[0b0000_0011], v1 + v2);
		assert_eq!(lookup_table[0b0000_0111], v1 + v2 + v3);
		assert_eq!(lookup_table[0b0000_1111], v1 + v2 + v3 + v4);
		assert_eq!(lookup_table[0b0001_1111], v1 + v2 + v3 + v4 + v5);
		assert_eq!(lookup_table[0b0011_1111], v1 + v2 + v3 + v4 + v5 + v6);
		assert_eq!(lookup_table[0b0111_1111], v1 + v2 + v3 + v4 + v5 + v6 + v7);
		assert_eq!(lookup_table[0b1111_1111], v1 + v2 + v3 + v4 + v5 + v6 + v7 + v8);
	}

	#[test]
	fn test_partial_sums_all_zeros() {
		let values = vec![AESTowerField8b::from(0); 8];
		let lookup_table = create_partial_sums_lookup_tables(values.as_slice());

		assert_eq!(lookup_table.len(), 256);

		for &l in lookup_table.iter().take(256) {
			assert_eq!(l, AESTowerField8b::from(0));
		}
	}

	#[test]
	fn test_partial_sums_single_element() {
		let mut values = vec![AESTowerField8b::from(0); 8];
		// Set only the fourth element (index 3)
		values[3] = AESTowerField8b::from(10);

		let lookup_table = create_partial_sums_lookup_tables(values.as_slice());

		assert_eq!(lookup_table.len(), 256);

		// Only indices where bit 3 is set should have non-zero sums
		assert_eq!(lookup_table[0b0000_0000], AESTowerField8b::from(0));
		assert_eq!(lookup_table[0b0000_1000], AESTowerField8b::from(10));
		assert_eq!(lookup_table[0b0000_1100], AESTowerField8b::from(10));
		assert_eq!(lookup_table[0b0001_1000], AESTowerField8b::from(10));
		assert_eq!(lookup_table[0b1111_1111], AESTowerField8b::from(10));
	}

	#[test]
	fn test_partial_sums_alternating_values() {
		let v1 = AESTowerField8b::from(10);
		let v2 = AESTowerField8b::from(20);
		let v3 = AESTowerField8b::from(30);
		let v4 = AESTowerField8b::from(40);

		let zero = AESTowerField8b::from(0);

		let values = vec![v1, zero, v2, zero, v3, zero, v4, zero];

		let lookup_table = create_partial_sums_lookup_tables(values.as_slice());

		assert_eq!(lookup_table.len(), 256);

		// Expect only the even-indexed elements to contribute to the sum
		assert_eq!(lookup_table[0b0000_0000], zero);
		assert_eq!(lookup_table[0b0000_0001], v1);
		assert_eq!(lookup_table[0b0000_0101], v1 + v2);
		assert_eq!(lookup_table[0b0000_1111], v1 + v2);
		assert_eq!(lookup_table[0b1111_1111], v1 + v2 + v3 + v4);
	}
}