Skip to main content

binius_field/underlier/
scaled.rs

1// Copyright 2024-2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4use std::{
5	array,
6	fmt::{self, LowerHex},
7	mem,
8	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
9};
10
11use binius_utils::{
12	DeserializeBytes, SerializationError, SerializeBytes,
13	bytes::{Buf, BufMut},
14	checked_arithmetics::checked_log_2,
15};
16use bytemuck::{Pod, Zeroable};
17use rand::{
18	Rng,
19	distr::{Distribution, StandardUniform},
20};
21
22use super::{Divisible, NumCast, U1, UnderlierType, mapget};
23use crate::Random;
24
/// A type that represents N elements of the same underlier type.
/// Used as an underlier for 256-bit and 512-bit packed fields in the portable implementation.
///
/// The limbs are stored in a plain array and the struct is `#[repr(transparent)]`, so its layout
/// is that of `[U; N]`. Limb 0 is treated as the least-significant limb: `ONE` sets only limb 0
/// and the shift operators carry bits between neighbouring limbs accordingly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub struct ScaledUnderlier<U, const N: usize>(pub [U; N]);
30
31impl<U: Default, const N: usize> Default for ScaledUnderlier<U, N> {
32	fn default() -> Self {
33		Self(array::from_fn(|_| U::default()))
34	}
35}
36
37impl<U: Random, const N: usize> Distribution<ScaledUnderlier<U, N>> for StandardUniform {
38	fn sample<R: Rng + ?Sized>(&self, mut rng: &mut R) -> ScaledUnderlier<U, N> {
39		ScaledUnderlier(array::from_fn(|_| U::random(&mut rng)))
40	}
41}
42
43impl<U, const N: usize> From<ScaledUnderlier<U, N>> for [U; N] {
44	fn from(val: ScaledUnderlier<U, N>) -> Self {
45		val.0
46	}
47}
48
49impl<T, U: From<T>, const N: usize> From<[T; N]> for ScaledUnderlier<U, N> {
50	fn from(value: [T; N]) -> Self {
51		Self(value.map(U::from))
52	}
53}
54
55impl<T: Copy, U: From<[T; 2]>> From<[T; 4]> for ScaledUnderlier<U, 2> {
56	fn from(value: [T; 4]) -> Self {
57		Self([[value[0], value[1]], [value[2], value[3]]].map(Into::into))
58	}
59}
60
// SAFETY: `ScaledUnderlier` is `#[repr(transparent)]` over its single field `[U; N]`, so the
// all-zero bit pattern is valid for it whenever it is valid for the array, i.e. for `U: Zeroable`.
unsafe impl<U: Zeroable, const N: usize> Zeroable for ScaledUnderlier<U, N> {}
// SAFETY: same layout argument as above; the wrapper adds no bytes of its own to `[U; N]`, the
// struct derives `Copy`, and an array of `Pod` elements is itself `Pod`.
unsafe impl<U: Pod, const N: usize> Pod for ScaledUnderlier<U, N> {}
63
64impl<U: BitAnd<Output = U> + Copy, const N: usize> BitAnd for ScaledUnderlier<U, N> {
65	type Output = Self;
66
67	fn bitand(self, rhs: Self) -> Self::Output {
68		Self(array::from_fn(|i| self.0[i] & rhs.0[i]))
69	}
70}
71
72impl<U: BitAndAssign + Copy, const N: usize> BitAndAssign for ScaledUnderlier<U, N> {
73	fn bitand_assign(&mut self, rhs: Self) {
74		for i in 0..N {
75			self.0[i] &= rhs.0[i];
76		}
77	}
78}
79
80impl<U: BitOr<Output = U> + Copy, const N: usize> BitOr for ScaledUnderlier<U, N> {
81	type Output = Self;
82
83	fn bitor(self, rhs: Self) -> Self::Output {
84		Self(array::from_fn(|i| self.0[i] | rhs.0[i]))
85	}
86}
87
88impl<U: BitOrAssign + Copy, const N: usize> BitOrAssign for ScaledUnderlier<U, N> {
89	fn bitor_assign(&mut self, rhs: Self) {
90		for i in 0..N {
91			self.0[i] |= rhs.0[i];
92		}
93	}
94}
95
96impl<U: BitXor<Output = U> + Copy, const N: usize> BitXor for ScaledUnderlier<U, N> {
97	type Output = Self;
98
99	fn bitxor(self, rhs: Self) -> Self::Output {
100		Self(array::from_fn(|i| self.0[i] ^ rhs.0[i]))
101	}
102}
103
104impl<U: BitXorAssign + Copy, const N: usize> BitXorAssign for ScaledUnderlier<U, N> {
105	fn bitxor_assign(&mut self, rhs: Self) {
106		for i in 0..N {
107			self.0[i] ^= rhs.0[i];
108		}
109	}
110}
111
112impl<U: UnderlierType, const N: usize> Shr<usize> for ScaledUnderlier<U, N> {
113	type Output = Self;
114
115	fn shr(self, rhs: usize) -> Self::Output {
116		let mut result = Self::default();
117
118		let shift_in_items = rhs / U::BITS;
119		for i in 0..N.saturating_sub(shift_in_items.saturating_sub(1)) {
120			if i + shift_in_items < N {
121				result.0[i] |= self.0[i + shift_in_items] >> (rhs % U::BITS);
122			}
123			if i + shift_in_items + 1 < N && !rhs.is_multiple_of(U::BITS) {
124				result.0[i] |= self.0[i + shift_in_items + 1] << (U::BITS - (rhs % U::BITS));
125			}
126		}
127
128		result
129	}
130}
131
132impl<U: UnderlierType, const N: usize> Shl<usize> for ScaledUnderlier<U, N> {
133	type Output = Self;
134
135	fn shl(self, rhs: usize) -> Self::Output {
136		let mut result = Self::default();
137
138		let shift_in_items = rhs / U::BITS;
139		for i in shift_in_items.saturating_sub(1)..N {
140			if i >= shift_in_items {
141				result.0[i] |= self.0[i - shift_in_items] << (rhs % U::BITS);
142			}
143			if i > shift_in_items && !rhs.is_multiple_of(U::BITS) {
144				result.0[i] |= self.0[i - shift_in_items - 1] >> (U::BITS - (rhs % U::BITS));
145			}
146		}
147
148		result
149	}
150}
151
152impl<U: Not<Output = U>, const N: usize> Not for ScaledUnderlier<U, N> {
153	type Output = Self;
154
155	fn not(self) -> Self::Output {
156		Self(self.0.map(U::not))
157	}
158}
159
160impl<U: UnderlierType + Pod, const N: usize> UnderlierType for ScaledUnderlier<U, N> {
161	const LOG_BITS: usize = U::LOG_BITS + checked_log_2(N);
162
163	const ZERO: Self = Self([U::ZERO; N]);
164	const ONE: Self = {
165		let mut arr = [U::ZERO; N];
166		arr[0] = U::ONE;
167		Self(arr)
168	};
169	const ONES: Self = Self([U::ONES; N]);
170
171	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
172		if log_block_len < U::LOG_BITS {
173			// Case 1: Delegate to element-wise interleave
174			let pairs: [(U, U); N] =
175				array::from_fn(|i| self.0[i].interleave(other.0[i], log_block_len));
176			(Self(array::from_fn(|i| pairs[i].0)), Self(array::from_fn(|i| pairs[i].1)))
177		} else {
178			// Case 2: Interleave at element level by swapping array elements
179			// Each super-block of 2*block_len elements gets transposed as a 2x2 matrix of blocks
180			let block_len = 1 << (log_block_len - U::LOG_BITS);
181
182			let mut a = self.0;
183			let mut b = other.0;
184			for super_block in 0..(N / (2 * block_len)) {
185				let base = super_block * 2 * block_len;
186				for offset in 0..block_len {
187					mem::swap(&mut a[base + block_len + offset], &mut b[base + offset]);
188				}
189			}
190
191			(Self(a), Self(b))
192		}
193	}
194}
195
196impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for u8
197where
198	Self: NumCast<U>,
199{
200	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
201		Self::num_cast_from(val.0[0])
202	}
203}
204
205impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for U1
206where
207	Self: NumCast<U>,
208{
209	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
210		Self::num_cast_from(val.0[0])
211	}
212}
213
214// `M128` is the `BinaryField128bGhash` subfield underlier; this extracts it from the low limb of a
215// `ScaledUnderlier<M128, _>`-backed extension field (`GhashSq256b` off the AVX2 path).
216impl<U: UnderlierType, const N: usize> NumCast<ScaledUnderlier<U, N>> for crate::arch::M128
217where
218	Self: NumCast<U>,
219{
220	fn num_cast_from(val: ScaledUnderlier<U, N>) -> Self {
221		Self::num_cast_from(val.0[0])
222	}
223}
224
225impl<U, const N: usize> From<u8> for ScaledUnderlier<U, N>
226where
227	U: From<u8>,
228{
229	fn from(val: u8) -> Self {
230		Self(array::from_fn(|_| U::from(val)))
231	}
232}
233
234/// Zero-extends an `M128` into the least-significant limb, leaving the rest zero.
235///
236/// This is the embedding of a base-field underlier into a `ScaledUnderlier<M128, _>`-backed
237/// extension field, as used by `impl_field_extension!`'s `from_bases_sparse` (`GhashSq256b` off
238/// the AVX2 path).
239impl<const N: usize> From<crate::arch::M128> for ScaledUnderlier<crate::arch::M128, N> {
240	fn from(val: crate::arch::M128) -> Self {
241		let mut limbs = [<crate::arch::M128 as UnderlierType>::ZERO; N];
242		limbs[0] = val;
243		Self(limbs)
244	}
245}
246
247/// Zero-extends a single bit into the least-significant `M128` limb, leaving the rest zero.
248impl<const N: usize> From<U1> for ScaledUnderlier<crate::arch::M128, N> {
249	fn from(val: U1) -> Self {
250		Self::from(crate::arch::M128::from(val))
251	}
252}
253
254impl<U: UnderlierType + LowerHex, const N: usize> LowerHex for ScaledUnderlier<U, N> {
255	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256		// Most-significant limb first. Print from the highest non-zero limb so there are no
257		// spurious leading zeros, then zero-pad each remaining limb to its full bit width.
258		let width = U::BITS / 4;
259		let top = self
260			.0
261			.iter()
262			.rposition(|limb| *limb != U::ZERO)
263			.unwrap_or(0);
264		write!(f, "{:x}", self.0[top])?;
265		for limb in self.0[..top].iter().rev() {
266			write!(f, "{limb:0width$x}")?;
267		}
268		Ok(())
269	}
270}
271
272impl<U, T, const N: usize> Divisible<T> for ScaledUnderlier<U, N>
273where
274	U: Divisible<T> + Pod + Send + Sync,
275	T: Send + 'static,
276{
277	const LOG_N: usize = <U as Divisible<T>>::LOG_N + checked_log_2(N);
278
279	#[inline]
280	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = T> + Send + Clone {
281		mapget::value_iter(value)
282	}
283
284	#[inline]
285	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
286		mapget::value_iter(*value)
287	}
288
289	#[inline]
290	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = T> + Send + Clone + '_ {
291		mapget::slice_iter(slice)
292	}
293
294	#[inline]
295	unsafe fn get_unchecked(&self, index: usize) -> T {
296		let u_index = index >> <U as Divisible<T>>::LOG_N;
297		let sub_index = index & (<U as Divisible<T>>::N - 1);
298		// Safety: `index < Self::N` by the caller's contract, so `sub_index < <U as
299		// Divisible<T>>::N` and `u_index < N`.
300		unsafe { Divisible::<T>::get_unchecked(self.0.get_unchecked(u_index), sub_index) }
301	}
302
303	#[inline]
304	unsafe fn set_unchecked(&mut self, index: usize, val: T) {
305		let u_index = index >> <U as Divisible<T>>::LOG_N;
306		let sub_index = index & (<U as Divisible<T>>::N - 1);
307		// Safety: `index < Self::N` by the caller's contract, so `sub_index < <U as
308		// Divisible<T>>::N` and `u_index < N`.
309		unsafe { Divisible::<T>::set_unchecked(self.0.get_unchecked_mut(u_index), sub_index, val) };
310	}
311
312	#[inline]
313	fn broadcast(val: T) -> Self {
314		Self([Divisible::<T>::broadcast(val); N])
315	}
316
317	#[inline]
318	fn from_iter(mut iter: impl Iterator<Item = T>) -> Self {
319		Self(array::from_fn(|_| Divisible::<T>::from_iter(&mut iter)))
320	}
321}
322
323impl<U: SerializeBytes, const N: usize> SerializeBytes for ScaledUnderlier<U, N> {
324	fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
325		self.0.serialize(write_buf)
326	}
327}
328
329impl<U: DeserializeBytes, const N: usize> DeserializeBytes for ScaledUnderlier<U, N> {
330	fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError> {
331		<[U; N]>::deserialize(read_buf).map(Self)
332	}
333}
334
#[cfg(test)]
mod tests {
	use super::*;

	// Limb 0 is the least-significant limb, so `[0, 1, 2, 3]` over `u8` is the 32-bit value
	// 0x03020100.

	#[test]
	fn test_shr() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(
			val >> 1,
			ScaledUnderlier::<u8, 4>([0b10000000, 0b00000000, 0b10000001, 0b00000001])
		);
		assert_eq!(
			val >> 2,
			ScaledUnderlier::<u8, 4>([0b01000000, 0b10000000, 0b11000000, 0b00000000])
		);
		assert_eq!(
			val >> 8,
			ScaledUnderlier::<u8, 4>([0b00000001, 0b00000010, 0b00000011, 0b00000000])
		);
		assert_eq!(
			val >> 9,
			ScaledUnderlier::<u8, 4>([0b00000000, 0b10000001, 0b00000001, 0b00000000])
		);
	}

	#[test]
	fn test_shl() {
		let val = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		assert_eq!(val << 1, ScaledUnderlier::<u8, 4>([0, 2, 4, 6]));
		assert_eq!(val << 2, ScaledUnderlier::<u8, 4>([0, 4, 8, 12]));
		assert_eq!(val << 8, ScaledUnderlier::<u8, 4>([0, 0, 1, 2]));
		assert_eq!(val << 9, ScaledUnderlier::<u8, 4>([0, 0, 2, 4]));
	}

	#[test]
	fn test_shift_by_zero_is_identity() {
		// A zero shift must not touch the cross-limb carry path, which would otherwise shift a
		// limb by its full width.
		let val = ScaledUnderlier::<u8, 4>([0x81, 0x42, 0x24, 0x18]);
		assert_eq!(val >> 0, val);
		assert_eq!(val << 0, val);
	}

	#[test]
	fn test_shift_by_total_width_or_more_is_zero() {
		let val = ScaledUnderlier::<u8, 4>([0xff, 0xff, 0xff, 0xff]);
		let zero = ScaledUnderlier::<u8, 4>([0, 0, 0, 0]);

		// One bit short of the full width keeps exactly one bit.
		assert_eq!(val >> 31, ScaledUnderlier::<u8, 4>([1, 0, 0, 0]));
		assert_eq!(val << 31, ScaledUnderlier::<u8, 4>([0, 0, 0, 0b10000000]));

		// Exactly the full width, and well beyond it, clears everything.
		assert_eq!(val >> 32, zero);
		assert_eq!(val << 32, zero);
		assert_eq!(val >> 100, zero);
		assert_eq!(val << 100, zero);
	}

	#[test]
	fn test_interleave_within_element() {
		// Test case 1: log_block_len < U::LOG_BITS
		// ScaledUnderlier<u8, 4> has LOG_BITS = 5 (32 bits total)
		// u8 has LOG_BITS = 3
		let a = ScaledUnderlier::<u8, 4>([0b01010101, 0b11110000, 0b00001111, 0b10101010]);
		let b = ScaledUnderlier::<u8, 4>([0b10101010, 0b00001111, 0b11110000, 0b01010101]);

		// At log_block_len = 0 (1-bit blocks), should delegate to u8::interleave
		let (c, d) = a.interleave(b, 0);

		// Verify element-wise interleave occurred
		for i in 0..4 {
			let (expected_c, expected_d) = a.0[i].interleave(b.0[i], 0);
			assert_eq!(c.0[i], expected_c);
			assert_eq!(d.0[i], expected_d);
		}
	}

	#[test]
	fn test_interleave_across_elements() {
		// Test case 2: log_block_len >= U::LOG_BITS
		let a = ScaledUnderlier::<u8, 4>([0, 1, 2, 3]);
		let b = ScaledUnderlier::<u8, 4>([4, 5, 6, 7]);

		// At log_block_len = 3 (8-bit blocks = 1 element), swap individual elements
		let (c, d) = a.interleave(b, 3);
		assert_eq!(c.0, [0, 4, 2, 6]);
		assert_eq!(d.0, [1, 5, 3, 7]);

		// At log_block_len = 4 (16-bit blocks = 2 elements), swap pairs
		let (c, d) = a.interleave(b, 4);
		assert_eq!(c.0, [0, 1, 4, 5]);
		assert_eq!(d.0, [2, 3, 6, 7]);
	}
}