Skip to main content

binius_field/underlier/
small_uint.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	fmt::{Debug, Display, LowerHex},
5	hash::{Hash, Hasher},
6	ops::{Not, Shl, Shr},
7};
8
9use binius_utils::{
10	SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	checked_arithmetics::checked_log_2,
13	serialization::DeserializeBytes,
14};
15use bytemuck::{NoUninit, Zeroable};
16use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};
17use rand::{
18	distr::{Distribution, StandardUniform},
19	prelude::*,
20};
21
22use super::UnderlierType;
23use crate::arch::{interleave_mask_even, interleave_with_mask};
24
25/// Unsigned type with a size strictly less than 8 bits.
26#[derive(
27	Default,
28	Zeroable,
29	Clone,
30	Copy,
31	PartialEq,
32	Eq,
33	PartialOrd,
34	Ord,
35	BitAnd,
36	BitAndAssign,
37	BitOr,
38	BitOrAssign,
39	BitXor,
40	BitXorAssign,
41)]
42#[repr(transparent)]
43pub struct SmallU<const N: usize>(u8);
44
45impl<const N: usize> SmallU<N> {
46	const _CHECK_SIZE: () = {
47		assert!(N < 8);
48	};
49
50	/// All bits set to one.
51	pub const ONES: Self = Self((1u8 << N) - 1);
52
53	#[inline(always)]
54	pub const fn new(val: u8) -> Self {
55		Self(val & Self::ONES.0)
56	}
57
58	#[inline(always)]
59	pub const fn new_unchecked(val: u8) -> Self {
60		Self(val)
61	}
62
63	#[inline(always)]
64	pub const fn val(&self) -> u8 {
65		self.0
66	}
67
68	pub fn checked_add(self, rhs: Self) -> Option<Self> {
69		self.val()
70			.checked_add(rhs.val())
71			.and_then(|value| (value < Self::ONES.0).then_some(Self(value)))
72	}
73
74	pub fn checked_sub(self, rhs: Self) -> Option<Self> {
75		let a = self.val();
76		let b = rhs.val();
77		(b > a).then_some(Self(b - a))
78	}
79}
80
81impl<const N: usize> Debug for SmallU<N> {
82	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83		Debug::fmt(&self.val(), f)
84	}
85}
86
87impl<const N: usize> Display for SmallU<N> {
88	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89		Display::fmt(&self.val(), f)
90	}
91}
92
93impl<const N: usize> LowerHex for SmallU<N> {
94	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95		LowerHex::fmt(&self.0, f)
96	}
97}
98impl<const N: usize> Hash for SmallU<N> {
99	#[inline]
100	fn hash<H: Hasher>(&self, state: &mut H) {
101		self.val().hash(state);
102	}
103}
104
105impl<const N: usize> Distribution<SmallU<N>> for StandardUniform {
106	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SmallU<N> {
107		SmallU(rng.random_range(0..1u8 << N))
108	}
109}
110
111impl<const N: usize> Shr<usize> for SmallU<N> {
112	type Output = Self;
113
114	#[inline(always)]
115	fn shr(self, rhs: usize) -> Self::Output {
116		Self(self.val() >> rhs)
117	}
118}
119
120impl<const N: usize> Shl<usize> for SmallU<N> {
121	type Output = Self;
122
123	#[inline(always)]
124	fn shl(self, rhs: usize) -> Self::Output {
125		Self(self.val() << rhs) & Self::ONES
126	}
127}
128
129impl<const N: usize> Not for SmallU<N> {
130	type Output = Self;
131
132	fn not(self) -> Self::Output {
133		self ^ Self::ONES
134	}
135}
136
// SAFETY: `SmallU<N>` is `#[repr(transparent)]` over a single `u8`, so its representation is
// exactly one fully-initialized byte with no padding. It is `Copy` (derived) and has no lifetime
// parameters, so it is `'static`.
unsafe impl<const N: usize> NoUninit for SmallU<N> {}
138
139impl UnderlierType for U1 {
140	const LOG_BITS: usize = checked_log_2(1);
141
142	const ZERO: Self = Self(0);
143	const ONE: Self = Self(1);
144	const ONES: Self = Self(1);
145
146	fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
147		panic!("interleave not supported for U1")
148	}
149}
150
151impl UnderlierType for U2 {
152	const LOG_BITS: usize = checked_log_2(2);
153
154	const ZERO: Self = Self(0);
155	const ONE: Self = Self(1);
156	const ONES: Self = Self(0b11);
157
158	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
159		const MASKS: &[U2] = &[U2::new(interleave_mask_even!(u8, 0))];
160		interleave_with_mask(self, other, log_block_len, MASKS)
161	}
162}
163
164impl UnderlierType for U4 {
165	const LOG_BITS: usize = checked_log_2(4);
166
167	const ZERO: Self = Self(0);
168	const ONE: Self = Self(1);
169	const ONES: Self = Self(0b1111);
170
171	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
172		const MASKS: &[U4] = &[
173			U4::new(interleave_mask_even!(u8, 0)),
174			U4::new(interleave_mask_even!(u8, 1)),
175		];
176		interleave_with_mask(self, other, log_block_len, MASKS)
177	}
178}
179
180impl<const N: usize> From<SmallU<N>> for u8 {
181	#[inline(always)]
182	fn from(value: SmallU<N>) -> Self {
183		value.val()
184	}
185}
186
187impl<const N: usize> From<SmallU<N>> for u16 {
188	#[inline(always)]
189	fn from(value: SmallU<N>) -> Self {
190		u8::from(value) as _
191	}
192}
193
194impl<const N: usize> From<SmallU<N>> for u32 {
195	#[inline(always)]
196	fn from(value: SmallU<N>) -> Self {
197		u8::from(value) as _
198	}
199}
200
201impl<const N: usize> From<SmallU<N>> for u64 {
202	#[inline(always)]
203	fn from(value: SmallU<N>) -> Self {
204		u8::from(value) as _
205	}
206}
207
208impl<const N: usize> From<SmallU<N>> for usize {
209	#[inline(always)]
210	fn from(value: SmallU<N>) -> Self {
211		u8::from(value) as _
212	}
213}
214
215impl<const N: usize> From<SmallU<N>> for u128 {
216	#[inline(always)]
217	fn from(value: SmallU<N>) -> Self {
218		u8::from(value) as _
219	}
220}
221
222impl From<SmallU<1>> for SmallU<2> {
223	#[inline(always)]
224	fn from(value: SmallU<1>) -> Self {
225		Self(value.val())
226	}
227}
228
229impl From<SmallU<1>> for SmallU<4> {
230	#[inline(always)]
231	fn from(value: SmallU<1>) -> Self {
232		Self(value.val())
233	}
234}
235
236impl From<SmallU<2>> for SmallU<4> {
237	#[inline(always)]
238	fn from(value: SmallU<2>) -> Self {
239		Self(value.val())
240	}
241}
242
/// A 1-bit unsigned integer.
pub type U1 = SmallU<1>;
/// A 2-bit unsigned integer.
pub type U2 = SmallU<2>;
/// A 4-bit unsigned integer.
pub type U4 = SmallU<4>;
246
247impl From<bool> for U1 {
248	fn from(value: bool) -> Self {
249		Self::new_unchecked(value as u8)
250	}
251}
252
253impl From<U1> for bool {
254	fn from(value: U1) -> Self {
255		value == U1::ONE
256	}
257}
258
259impl<const N: usize> SerializeBytes for SmallU<N> {
260	fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
261		self.val().serialize(write_buf)
262	}
263}
264
265impl<const N: usize> DeserializeBytes for SmallU<N> {
266	fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
267	where
268		Self: Sized,
269	{
270		Ok(Self::new(DeserializeBytes::deserialize(read_buf)?))
271	}
272}
273
#[cfg(test)]
impl<const N: usize> proptest::arbitrary::Arbitrary for SmallU<N> {
	type Parameters = ();
	type Strategy = proptest::strategy::BoxedStrategy<Self>;

	/// Generates every in-range value: raw bytes drawn from `0..2^N`, wrapped without masking.
	fn arbitrary_with(_params: Self::Parameters) -> Self::Strategy {
		use proptest::strategy::Strategy;

		let exclusive_upper = 1u8 << N;
		(0..exclusive_upper).prop_map(Self::new_unchecked).boxed()
	}
}