binius_field/underlier/
small_uint.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	fmt::{Debug, Display, LowerHex},
5	hash::{Hash, Hasher},
6	ops::{Not, Shl, Shr},
7};
8
9use binius_utils::{
10	SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	checked_arithmetics::checked_log_2,
13	serialization::DeserializeBytes,
14};
15use bytemuck::{NoUninit, Zeroable};
16use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};
17use rand::{
18	Rng,
19	distr::{Distribution, StandardUniform},
20};
21
22use super::{UnderlierType, underlier_with_bit_ops::UnderlierWithBitOps};
23
24/// Unsigned type with a size strictly less than 8 bits.
25#[derive(
26	Default,
27	Zeroable,
28	Clone,
29	Copy,
30	PartialEq,
31	Eq,
32	PartialOrd,
33	Ord,
34	BitAnd,
35	BitAndAssign,
36	BitOr,
37	BitOrAssign,
38	BitXor,
39	BitXorAssign,
40)]
41#[repr(transparent)]
42pub struct SmallU<const N: usize>(u8);
43
44impl<const N: usize> SmallU<N> {
45	const _CHECK_SIZE: () = {
46		assert!(N < 8);
47	};
48
49	#[inline(always)]
50	pub const fn new(val: u8) -> Self {
51		Self(val & Self::ONES.0)
52	}
53
54	#[inline(always)]
55	pub const fn new_unchecked(val: u8) -> Self {
56		Self(val)
57	}
58
59	#[inline(always)]
60	pub const fn val(&self) -> u8 {
61		self.0
62	}
63
64	pub fn checked_add(self, rhs: Self) -> Option<Self> {
65		self.val()
66			.checked_add(rhs.val())
67			.and_then(|value| (value < Self::ONES.0).then_some(Self(value)))
68	}
69
70	pub fn checked_sub(self, rhs: Self) -> Option<Self> {
71		let a = self.val();
72		let b = rhs.val();
73		(b > a).then_some(Self(b - a))
74	}
75}
76
77impl<const N: usize> Debug for SmallU<N> {
78	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79		Debug::fmt(&self.val(), f)
80	}
81}
82
83impl<const N: usize> Display for SmallU<N> {
84	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85		Display::fmt(&self.val(), f)
86	}
87}
88
89impl<const N: usize> LowerHex for SmallU<N> {
90	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91		LowerHex::fmt(&self.0, f)
92	}
93}
94impl<const N: usize> Hash for SmallU<N> {
95	#[inline]
96	fn hash<H: Hasher>(&self, state: &mut H) {
97		self.val().hash(state);
98	}
99}
100
101impl<const N: usize> Distribution<SmallU<N>> for StandardUniform {
102	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SmallU<N> {
103		SmallU(rng.random_range(0..1u8 << N))
104	}
105}
106
107impl<const N: usize> Shr<usize> for SmallU<N> {
108	type Output = Self;
109
110	#[inline(always)]
111	fn shr(self, rhs: usize) -> Self::Output {
112		Self(self.val() >> rhs)
113	}
114}
115
116impl<const N: usize> Shl<usize> for SmallU<N> {
117	type Output = Self;
118
119	#[inline(always)]
120	fn shl(self, rhs: usize) -> Self::Output {
121		Self(self.val() << rhs) & Self::ONES
122	}
123}
124
125impl<const N: usize> Not for SmallU<N> {
126	type Output = Self;
127
128	fn not(self) -> Self::Output {
129		self ^ Self::ONES
130	}
131}
132
133unsafe impl<const N: usize> NoUninit for SmallU<N> {}
134
135impl<const N: usize> UnderlierType for SmallU<N> {
136	const LOG_BITS: usize = checked_log_2(N);
137}
138
139impl<const N: usize> UnderlierWithBitOps for SmallU<N> {
140	const ZERO: Self = Self(0);
141	const ONE: Self = Self(1);
142	const ONES: Self = Self((1u8 << N) - 1);
143
144	fn fill_with_bit(val: u8) -> Self {
145		Self(u8::fill_with_bit(val)) & Self::ONES
146	}
147
148	fn shl_128b_lanes(self, rhs: usize) -> Self {
149		self << rhs
150	}
151
152	fn shr_128b_lanes(self, rhs: usize) -> Self {
153		self >> rhs
154	}
155}
156
157impl<const N: usize> From<SmallU<N>> for u8 {
158	#[inline(always)]
159	fn from(value: SmallU<N>) -> Self {
160		value.val()
161	}
162}
163
164impl<const N: usize> From<SmallU<N>> for u16 {
165	#[inline(always)]
166	fn from(value: SmallU<N>) -> Self {
167		u8::from(value) as _
168	}
169}
170
171impl<const N: usize> From<SmallU<N>> for u32 {
172	#[inline(always)]
173	fn from(value: SmallU<N>) -> Self {
174		u8::from(value) as _
175	}
176}
177
178impl<const N: usize> From<SmallU<N>> for u64 {
179	#[inline(always)]
180	fn from(value: SmallU<N>) -> Self {
181		u8::from(value) as _
182	}
183}
184
185impl<const N: usize> From<SmallU<N>> for usize {
186	#[inline(always)]
187	fn from(value: SmallU<N>) -> Self {
188		u8::from(value) as _
189	}
190}
191
192impl<const N: usize> From<SmallU<N>> for u128 {
193	#[inline(always)]
194	fn from(value: SmallU<N>) -> Self {
195		u8::from(value) as _
196	}
197}
198
199impl From<SmallU<1>> for SmallU<2> {
200	#[inline(always)]
201	fn from(value: SmallU<1>) -> Self {
202		Self(value.val())
203	}
204}
205
206impl From<SmallU<1>> for SmallU<4> {
207	#[inline(always)]
208	fn from(value: SmallU<1>) -> Self {
209		Self(value.val())
210	}
211}
212
213impl From<SmallU<2>> for SmallU<4> {
214	#[inline(always)]
215	fn from(value: SmallU<2>) -> Self {
216		Self(value.val())
217	}
218}
219
220pub type U1 = SmallU<1>;
221pub type U2 = SmallU<2>;
222pub type U4 = SmallU<4>;
223
224impl From<bool> for U1 {
225	fn from(value: bool) -> Self {
226		Self::new_unchecked(value as u8)
227	}
228}
229
230impl From<U1> for bool {
231	fn from(value: U1) -> Self {
232		value == U1::ONE
233	}
234}
235
236impl<const N: usize> SerializeBytes for SmallU<N> {
237	fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
238		self.val().serialize(write_buf)
239	}
240}
241
242impl<const N: usize> DeserializeBytes for SmallU<N> {
243	fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
244	where
245		Self: Sized,
246	{
247		Ok(Self::new(DeserializeBytes::deserialize(read_buf)?))
248	}
249}
250
#[cfg(test)]
impl<const N: usize> proptest::arbitrary::Arbitrary for SmallU<N> {
	type Parameters = ();
	type Strategy = proptest::strategy::BoxedStrategy<Self>;

	/// Draws values from the half-open range `0..2^N`, so every generated value is in range.
	fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
		use proptest::strategy::Strategy;

		let values = 0u8..(1u8 << N);
		values.prop_map(Self::new_unchecked).boxed()
	}
}