binius_field/underlier/
small_uint.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	fmt::{Debug, Display, LowerHex},
5	hash::{Hash, Hasher},
6	ops::{Not, Shl, Shr},
7};
8
9use binius_utils::{
10	SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	checked_arithmetics::checked_log_2,
13	serialization::DeserializeBytes,
14};
15use bytemuck::{NoUninit, Zeroable};
16use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};
17use rand::{
18	Rng,
19	distr::{Distribution, StandardUniform},
20};
21
22use super::{UnderlierType, underlier_with_bit_ops::UnderlierWithBitOps};
23use crate::arch::{interleave_mask_even, interleave_with_mask};
24
25/// Unsigned type with a size strictly less than 8 bits.
26#[derive(
27	Default,
28	Zeroable,
29	Clone,
30	Copy,
31	PartialEq,
32	Eq,
33	PartialOrd,
34	Ord,
35	BitAnd,
36	BitAndAssign,
37	BitOr,
38	BitOrAssign,
39	BitXor,
40	BitXorAssign,
41)]
42#[repr(transparent)]
43pub struct SmallU<const N: usize>(u8);
44
45impl<const N: usize> SmallU<N> {
46	const _CHECK_SIZE: () = {
47		assert!(N < 8);
48	};
49
50	/// All bits set to one.
51	pub const ONES: Self = Self((1u8 << N) - 1);
52
53	#[inline(always)]
54	pub const fn new(val: u8) -> Self {
55		Self(val & Self::ONES.0)
56	}
57
58	#[inline(always)]
59	pub const fn new_unchecked(val: u8) -> Self {
60		Self(val)
61	}
62
63	#[inline(always)]
64	pub const fn val(&self) -> u8 {
65		self.0
66	}
67
68	pub fn checked_add(self, rhs: Self) -> Option<Self> {
69		self.val()
70			.checked_add(rhs.val())
71			.and_then(|value| (value < Self::ONES.0).then_some(Self(value)))
72	}
73
74	pub fn checked_sub(self, rhs: Self) -> Option<Self> {
75		let a = self.val();
76		let b = rhs.val();
77		(b > a).then_some(Self(b - a))
78	}
79}
80
81impl<const N: usize> Debug for SmallU<N> {
82	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83		Debug::fmt(&self.val(), f)
84	}
85}
86
87impl<const N: usize> Display for SmallU<N> {
88	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89		Display::fmt(&self.val(), f)
90	}
91}
92
93impl<const N: usize> LowerHex for SmallU<N> {
94	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95		LowerHex::fmt(&self.0, f)
96	}
97}
98impl<const N: usize> Hash for SmallU<N> {
99	#[inline]
100	fn hash<H: Hasher>(&self, state: &mut H) {
101		self.val().hash(state);
102	}
103}
104
105impl<const N: usize> Distribution<SmallU<N>> for StandardUniform {
106	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SmallU<N> {
107		SmallU(rng.random_range(0..1u8 << N))
108	}
109}
110
111impl<const N: usize> Shr<usize> for SmallU<N> {
112	type Output = Self;
113
114	#[inline(always)]
115	fn shr(self, rhs: usize) -> Self::Output {
116		Self(self.val() >> rhs)
117	}
118}
119
120impl<const N: usize> Shl<usize> for SmallU<N> {
121	type Output = Self;
122
123	#[inline(always)]
124	fn shl(self, rhs: usize) -> Self::Output {
125		Self(self.val() << rhs) & Self::ONES
126	}
127}
128
129impl<const N: usize> Not for SmallU<N> {
130	type Output = Self;
131
132	fn not(self) -> Self::Output {
133		self ^ Self::ONES
134	}
135}
136
137unsafe impl<const N: usize> NoUninit for SmallU<N> {}
138
139impl<const N: usize> UnderlierType for SmallU<N> {
140	const LOG_BITS: usize = checked_log_2(N);
141}
142
143impl UnderlierWithBitOps for U1 {
144	const ZERO: Self = Self(0);
145	const ONE: Self = Self(1);
146	const ONES: Self = Self(1);
147
148	fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
149		panic!("interleave not supported for U1")
150	}
151}
152
153impl UnderlierWithBitOps for U2 {
154	const ZERO: Self = Self(0);
155	const ONE: Self = Self(1);
156	const ONES: Self = Self(0b11);
157
158	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
159		const MASKS: &[U2] = &[U2::new(interleave_mask_even!(u8, 0))];
160		interleave_with_mask(self, other, log_block_len, MASKS)
161	}
162}
163
164impl UnderlierWithBitOps for U4 {
165	const ZERO: Self = Self(0);
166	const ONE: Self = Self(1);
167	const ONES: Self = Self(0b1111);
168
169	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
170		const MASKS: &[U4] = &[
171			U4::new(interleave_mask_even!(u8, 0)),
172			U4::new(interleave_mask_even!(u8, 1)),
173		];
174		interleave_with_mask(self, other, log_block_len, MASKS)
175	}
176}
177
178impl<const N: usize> From<SmallU<N>> for u8 {
179	#[inline(always)]
180	fn from(value: SmallU<N>) -> Self {
181		value.val()
182	}
183}
184
185impl<const N: usize> From<SmallU<N>> for u16 {
186	#[inline(always)]
187	fn from(value: SmallU<N>) -> Self {
188		u8::from(value) as _
189	}
190}
191
192impl<const N: usize> From<SmallU<N>> for u32 {
193	#[inline(always)]
194	fn from(value: SmallU<N>) -> Self {
195		u8::from(value) as _
196	}
197}
198
199impl<const N: usize> From<SmallU<N>> for u64 {
200	#[inline(always)]
201	fn from(value: SmallU<N>) -> Self {
202		u8::from(value) as _
203	}
204}
205
206impl<const N: usize> From<SmallU<N>> for usize {
207	#[inline(always)]
208	fn from(value: SmallU<N>) -> Self {
209		u8::from(value) as _
210	}
211}
212
213impl<const N: usize> From<SmallU<N>> for u128 {
214	#[inline(always)]
215	fn from(value: SmallU<N>) -> Self {
216		u8::from(value) as _
217	}
218}
219
220impl From<SmallU<1>> for SmallU<2> {
221	#[inline(always)]
222	fn from(value: SmallU<1>) -> Self {
223		Self(value.val())
224	}
225}
226
227impl From<SmallU<1>> for SmallU<4> {
228	#[inline(always)]
229	fn from(value: SmallU<1>) -> Self {
230		Self(value.val())
231	}
232}
233
234impl From<SmallU<2>> for SmallU<4> {
235	#[inline(always)]
236	fn from(value: SmallU<2>) -> Self {
237		Self(value.val())
238	}
239}
240
241pub type U1 = SmallU<1>;
242pub type U2 = SmallU<2>;
243pub type U4 = SmallU<4>;
244
245impl From<bool> for U1 {
246	fn from(value: bool) -> Self {
247		Self::new_unchecked(value as u8)
248	}
249}
250
251impl From<U1> for bool {
252	fn from(value: U1) -> Self {
253		value == U1::ONE
254	}
255}
256
257impl<const N: usize> SerializeBytes for SmallU<N> {
258	fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
259		self.val().serialize(write_buf)
260	}
261}
262
263impl<const N: usize> DeserializeBytes for SmallU<N> {
264	fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
265	where
266		Self: Sized,
267	{
268		Ok(Self::new(DeserializeBytes::deserialize(read_buf)?))
269	}
270}
271
/// Draws raw values from `0..2^N`, so every generated value is within the `N`-bit range.
#[cfg(test)]
impl<const N: usize> proptest::arbitrary::Arbitrary for SmallU<N> {
	type Parameters = ();
	type Strategy = proptest::strategy::BoxedStrategy<Self>;

	fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
		use proptest::strategy::Strategy;

		let exclusive_upper = 1u8 << N;
		(0u8..exclusive_upper)
			.prop_map(Self::new_unchecked)
			.boxed()
	}
}