binius_field/underlier/
scaled.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	array, mem,
5	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
6};
7
8use binius_utils::{
9	DeserializeBytes, SerializationError, SerializeBytes,
10	bytes::{Buf, BufMut},
11	checked_arithmetics::checked_log_2,
12};
13use bytemuck::{Pod, Zeroable};
14use rand::{
15	Rng,
16	distr::{Distribution, StandardUniform},
17};
18
19use super::{Divisible, NumCast, UnderlierType, UnderlierWithBitOps, mapget};
20use crate::Random;
21
/// `N` underliers of type `U` stored side by side and treated as one wider underlier.
///
/// The portable implementation uses this as the underlier for 256-bit and 512-bit packed
/// fields. The wrapper is `#[repr(transparent)]`, so its layout is exactly that of `[U; N]`.
///
/// Element `0` holds the least significant bits: shifting right moves data from
/// higher-indexed elements into lower-indexed ones. The derived ordering compares elements
/// starting from index `0`, so it is an array-lexicographic order rather than a numeric one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
27
28impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
29	fn default() -> Self {
30		Self(array::from_fn(|_| U::default()))
31	}
32}
33
34impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
35	fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
36		ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
37	}
38}
39
40impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
41	fn from(val: ScaledUnderlier<U, N>) -> Self {
42		val.0
43	}
44}
45
46impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
47	fn from(value: [T; N]) -> Self {
48		Self(value.map(U::from))
49	}
50}
51
52impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
53	fn from(value: [T; 4]) -> Self {
54		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
55	}
56}
57
58unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
59unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
60
61impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
62	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
63}
64
65impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
66	type Output = Self;
67
68	fn bitand(self, rhs: Self) -> Self::Output {
69		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
70	}
71}
72
73impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
74	fn bitand_assign(&mut self, rhs: Self) {
75		for i in 0..N {
76			self.0[i] &= rhs.0[i];
77		}
78	}
79}
80
81impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
82	type Output = Self;
83
84	fn bitor(self, rhs: Self) -> Self::Output {
85		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
86	}
87}
88
89impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
90	fn bitor_assign(&mut self, rhs: Self) {
91		for i in 0..N {
92			self.0[i] |= rhs.0[i];
93		}
94	}
95}
96
97impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
98	type Output = Self;
99
100	fn bitxor(self, rhs: Self) -> Self::Output {
101		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
102	}
103}
104
105impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
106	fn bitxor_assign(&mut self, rhs: Self) {
107		for i in 0..N {
108			self.0[i] ^= rhs.0[i];
109		}
110	}
111}
112
113impl<U: UnderlierWithBitOps, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
114	type Output = Self;
115
116	fn shr(self, rhs: usize) -> Self::Output {
117		let mut result = Self::default();
118
119		let shift_in_items = rhs / U::BITS;
120		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
121			if i + shift_in_items < N {
122				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
123			}
124			if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
125				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
126			}
127		}
128
129		result
130	}
131}
132
133impl<U: UnderlierWithBitOps, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
134	type Output = Self;
135
136	fn shl(self, rhs: usize) -> Self::Output {
137		let mut result = Self::default();
138
139		let shift_in_items = rhs / U::BITS;
140		for i in shift_in_items.saturating_sub(1)..N {
141			if i >= shift_in_items {
142				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
143			}
144			if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
145				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
146			}
147		}
148
149		result
150	}
151}
152
153impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
154	type Output = Self;
155
156	fn not(self) -> Self::Output {
157		Self(self.0.map(U::not))
158	}
159}
160
161impl<U: UnderlierWithBitOps + Pod, const N: usize> UnderlierWithBitOps for ScaledUnderlier<U, N> {
162	const ZERO: Self = Self([U::ZERO; N]);
163	const ONE: Self = {
164		let mut arr = [U::ZERO; N];
165		arr[0] = U::ONE;
166		Self(arr)
167	};
168	const ONES: Self = Self([U::ONES; N]);
169
170	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
171		if log_block_len < U::LOG_BITS {
172			// Case 1: Delegate to element-wise interleave
173			let pairs: [(U, U); N] =
174				array::from_fn(|i| self.0[i].interleave(other.0[i], log_block_len));
175			(Self(array::from_fn(|i| pairs[i].0)), Self(array::from_fn(|i| pairs[i].1)))
176		} else {
177			// Case 2: Interleave at element level by swapping array elements
178			// Each super-block of 2*block_len elements gets transposed as a 2x2 matrix of blocks
179			let block_len = 1 << (log_block_len - U::LOG_BITS);
180
181			let mut a = self.0;
182			let mut b = other.0;
183			for super_block in 0..(N / (2 * block_len)) {
184				let base = super_block * 2 * block_len;
185				for offset in 0..block_len {
186					mem::swap(&mut a[base + block_len + offset], &mut b[base + offset]);
187				}
188			}
189
190			(Self(a), Self(b))
191		}
192	}
193}
194
195impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
196where
197	Self: NumCast<U>,
198{
199	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
200		Self::num_cast_from(val.0[0])
201	}
202}
203
204impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
205where
206	U: From<u8>,
207{
208	fn from(val: u8) -> Self {
209		Self(array::from_fn(|_| U::from(val)))
210	}
211}
212
213impl<U, T, const N: usize> Divisible<T> for ScaledUnderlier<U, N>
214where
215	U: Divisible<T> + Pod + Send + Sync,
216	T: Send + 'static,
217{
218	const LOG_N: usize = <U as Divisible<T>>::LOG_N + checked_log_2(N);
219
220	#[inline]
221	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = T> + Send + Clone {
222		mapget::value_iter(value)
223	}
224
225	#[inline]
226	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
227		mapget::value_iter(*value)
228	}
229
230	#[inline]
231	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
232		mapget::slice_iter(slice)
233	}
234
235	#[inline]
236	fn get(self, index: usize) -> T {
237		let u_index = index >> <U as Divisible<T>>::LOG_N;
238		let sub_index = index & (<U as Divisible<T>>::N - 1);
239		Divisible::<T>::get(self.0[u_index], sub_index)
240	}
241
242	#[inline]
243	fn set(self, index: usize, val: T) -> Self {
244		let u_index = index >> <U as Divisible<T>>::LOG_N;
245		let sub_index = index & (<U as Divisible<T>>::N - 1);
246		let mut arr = self.0;
247		arr[u_index] = Divisible::<T>::set(arr[u_index], sub_index, val);
248		Self(arr)
249	}
250
251	#[inline]
252	fn broadcast(val: T) -> Self {
253		Self([Divisible::<T>::broadcast(val); N])
254	}
255
256	#[inline]
257	fn from_iter(mut iter: impl Iterator<Item = T>) -> Self {
258		Self(array::from_fn(|_| Divisible::<T>::from_iter(&mut iter)))
259	}
260}
261
262impl<U: SerializeBytes, const N: usize> SerializeBytes for ScaledUnderlier<U, N> {
263	fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
264		self.0.serialize(write_buf)
265	}
266}
267
268impl<U: DeserializeBytes, const N: usize> DeserializeBytes for ScaledUnderlier<U, N> {
269	fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError> {
270		<[U; N]>::deserialize(read_buf).map(Self)
271	}
272}
273
#[cfg(test)]
mod tests {
	use super::*;

	/// Right shifts move bits from higher-indexed elements into lower-indexed ones.
	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		let cases: [(usize, [u8; 4]); 4] = [
			(1, [0b10000000, 0b00000000, 0b10000001, 0b00000001]),
			(2, [0b01000000, 0b10000000, 0b11000000, 0b00000000]),
			(8, [0b00000001, 0b00000010, 0b00000011, 0b00000000]),
			(9, [0b00000000, 0b10000001, 0b00000001, 0b00000000]),
		];

		for (shift, expected) in cases {
			assert_eq!(val >> shift, ScaledUnderlier::<u8, 4>(expected), "val >> {shift}");
		}
	}

	/// Left shifts move bits from lower-indexed elements into higher-indexed ones.
	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		let cases: [(usize, [u8; 4]); 4] = [
			(1, [0, 2, 4, 6]),
			(2, [0, 4, 8, 12]),
			(8, [0, 0, 1, 2]),
			(9, [0, 0, 2, 4]),
		];

		for (shift, expected) in cases {
			assert_eq!(val << shift, ScaledUnderlier::<u8, 4>(expected), "val << {shift}");
		}
	}

	/// Blocks smaller than one element (`log_block_len < U::LOG_BITS`, where `u8` has
	/// `LOG_BITS = 3`) must be interleaved element by element.
	#[test]
	fn test_interleave_within_element() {
		let a = ScaledUnderlier::<u8, 4>([0b01010101, 0b11110000, 0b00001111, 0b10101010]);
		let b = ScaledUnderlier::<u8, 4>([0b10101010, 0b00001111, 0b11110000, 0b01010101]);

		// 1-bit blocks: the result must match `u8::interleave` applied to each element pair.
		let (c, d) = a.interleave(b, 0);

		for (i, (&a_elem, &b_elem)) in a.0.iter().zip(&b.0).enumerate() {
			let (expected_c, expected_d) = a_elem.interleave(b_elem, 0);
			assert_eq!(c.0[i], expected_c);
			assert_eq!(d.0[i], expected_d);
		}
	}

	/// Blocks of one element or more (`log_block_len >= U::LOG_BITS`) exchange whole
	/// elements between the two values.
	#[test]
	fn test_interleave_across_elements() {
		let a = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		let b = ScaledUnderlier::<u8, 4>([4, 5, 6, 7]);

		// 8-bit blocks (one element each): individual elements are exchanged.
		let (c, d) = a.interleave(b, 3);
		assert_eq!(c.0, [0, 4, 2, 6]);
		assert_eq!(d.0, [1, 5, 3, 7]);

		// 16-bit blocks (two elements each): pairs of elements are exchanged.
		let (c, d) = a.interleave(b, 4);
		assert_eq!(c.0, [0, 1, 4, 5]);
		assert_eq!(d.0, [2, 3, 6, 7]);
	}
}
343}