binius_field/underlier/
small_uint.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	fmt::{Debug, Display, LowerHex},
5	hash::{Hash, Hasher},
6	ops::{Not, Shl, Shr},
7};
8
9use binius_utils::{
10	bytes::{Buf, BufMut},
11	checked_arithmetics::checked_log_2,
12	serialization::DeserializeBytes,
13	SerializationError, SerializationMode, SerializeBytes,
14};
15use bytemuck::{NoUninit, Zeroable};
16use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};
17use rand::{
18	distributions::{Distribution, Uniform},
19	RngCore,
20};
21use subtle::{ConditionallySelectable, ConstantTimeEq};
22
23use super::{underlier_with_bit_ops::UnderlierWithBitOps, Random, UnderlierType};
24
25/// Unsigned type with a size strictly less than 8 bits.
26#[derive(
27	Default,
28	Zeroable,
29	Clone,
30	Copy,
31	PartialEq,
32	Eq,
33	PartialOrd,
34	Ord,
35	BitAnd,
36	BitAndAssign,
37	BitOr,
38	BitOrAssign,
39	BitXor,
40	BitXorAssign,
41)]
42#[repr(transparent)]
43pub struct SmallU<const N: usize>(u8);
44
45impl<const N: usize> SmallU<N> {
46	const _CHECK_SIZE: () = {
47		assert!(N < 8);
48	};
49
50	#[inline(always)]
51	pub const fn new(val: u8) -> Self {
52		Self(val & Self::ONES.0)
53	}
54
55	#[inline(always)]
56	pub const fn new_unchecked(val: u8) -> Self {
57		Self(val)
58	}
59
60	#[inline(always)]
61	pub const fn val(&self) -> u8 {
62		self.0
63	}
64
65	pub fn checked_add(self, rhs: Self) -> Option<Self> {
66		self.val()
67			.checked_add(rhs.val())
68			.and_then(|value| (value < Self::ONES.0).then_some(Self(value)))
69	}
70
71	pub fn checked_sub(self, rhs: Self) -> Option<Self> {
72		let a = self.val();
73		let b = rhs.val();
74		(b > a).then_some(Self(b - a))
75	}
76}
77
78impl<const N: usize> Debug for SmallU<N> {
79	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80		Debug::fmt(&self.val(), f)
81	}
82}
83
84impl<const N: usize> Display for SmallU<N> {
85	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86		Display::fmt(&self.val(), f)
87	}
88}
89
90impl<const N: usize> LowerHex for SmallU<N> {
91	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92		LowerHex::fmt(&self.0, f)
93	}
94}
95impl<const N: usize> Hash for SmallU<N> {
96	#[inline]
97	fn hash<H: Hasher>(&self, state: &mut H) {
98		self.val().hash(state);
99	}
100}
101
102impl<const N: usize> ConstantTimeEq for SmallU<N> {
103	fn ct_eq(&self, other: &Self) -> subtle::Choice {
104		self.val().ct_eq(&other.val())
105	}
106}
107
108impl<const N: usize> ConditionallySelectable for SmallU<N> {
109	fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self {
110		Self(u8::conditional_select(&a.0, &b.0, choice))
111	}
112}
113
114impl<const N: usize> Random for SmallU<N> {
115	fn random(mut rng: impl RngCore) -> Self {
116		let distr = Uniform::from(0u8..1u8 << N);
117
118		Self(distr.sample(&mut rng))
119	}
120}
121
122impl<const N: usize> Shr<usize> for SmallU<N> {
123	type Output = Self;
124
125	#[inline(always)]
126	fn shr(self, rhs: usize) -> Self::Output {
127		Self(self.val() >> rhs)
128	}
129}
130
131impl<const N: usize> Shl<usize> for SmallU<N> {
132	type Output = Self;
133
134	#[inline(always)]
135	fn shl(self, rhs: usize) -> Self::Output {
136		Self(self.val() << rhs) & Self::ONES
137	}
138}
139
140impl<const N: usize> Not for SmallU<N> {
141	type Output = Self;
142
143	fn not(self) -> Self::Output {
144		self ^ Self::ONES
145	}
146}
147
148unsafe impl<const N: usize> NoUninit for SmallU<N> {}
149
150impl<const N: usize> UnderlierType for SmallU<N> {
151	const LOG_BITS: usize = checked_log_2(N);
152}
153
154impl<const N: usize> UnderlierWithBitOps for SmallU<N> {
155	const ZERO: Self = Self(0);
156	const ONE: Self = Self(1);
157	const ONES: Self = Self((1u8 << N) - 1);
158
159	fn fill_with_bit(val: u8) -> Self {
160		Self(u8::fill_with_bit(val)) & Self::ONES
161	}
162
163	fn shl_128b_lanes(self, rhs: usize) -> Self {
164		self << rhs
165	}
166
167	fn shr_128b_lanes(self, rhs: usize) -> Self {
168		self >> rhs
169	}
170}
171
172impl<const N: usize> From<SmallU<N>> for u8 {
173	#[inline(always)]
174	fn from(value: SmallU<N>) -> Self {
175		value.val()
176	}
177}
178
179impl<const N: usize> From<SmallU<N>> for u16 {
180	#[inline(always)]
181	fn from(value: SmallU<N>) -> Self {
182		u8::from(value) as _
183	}
184}
185
186impl<const N: usize> From<SmallU<N>> for u32 {
187	#[inline(always)]
188	fn from(value: SmallU<N>) -> Self {
189		u8::from(value) as _
190	}
191}
192
193impl<const N: usize> From<SmallU<N>> for u64 {
194	#[inline(always)]
195	fn from(value: SmallU<N>) -> Self {
196		u8::from(value) as _
197	}
198}
199
200impl<const N: usize> From<SmallU<N>> for usize {
201	#[inline(always)]
202	fn from(value: SmallU<N>) -> Self {
203		u8::from(value) as _
204	}
205}
206
207impl<const N: usize> From<SmallU<N>> for u128 {
208	#[inline(always)]
209	fn from(value: SmallU<N>) -> Self {
210		u8::from(value) as _
211	}
212}
213
214impl From<SmallU<1>> for SmallU<2> {
215	#[inline(always)]
216	fn from(value: SmallU<1>) -> Self {
217		Self(value.val())
218	}
219}
220
221impl From<SmallU<1>> for SmallU<4> {
222	#[inline(always)]
223	fn from(value: SmallU<1>) -> Self {
224		Self(value.val())
225	}
226}
227
228impl From<SmallU<2>> for SmallU<4> {
229	#[inline(always)]
230	fn from(value: SmallU<2>) -> Self {
231		Self(value.val())
232	}
233}
234
235pub type U1 = SmallU<1>;
236pub type U2 = SmallU<2>;
237pub type U4 = SmallU<4>;
238
239impl<const N: usize> SerializeBytes for SmallU<N> {
240	fn serialize(
241		&self,
242		write_buf: impl BufMut,
243		mode: SerializationMode,
244	) -> Result<(), SerializationError> {
245		self.val().serialize(write_buf, mode)
246	}
247}
248
249impl<const N: usize> DeserializeBytes for SmallU<N> {
250	fn deserialize(read_buf: impl Buf, mode: SerializationMode) -> Result<Self, SerializationError>
251	where
252		Self: Sized,
253	{
254		Ok(Self::new(DeserializeBytes::deserialize(read_buf, mode)?))
255	}
256}