binius_field/underlier/
small_uint.rs1use std::{
4 fmt::{Debug, Display, LowerHex},
5 hash::{Hash, Hasher},
6 ops::{Not, Shl, Shr},
7};
8
9use binius_utils::{
10 SerializationError, SerializeBytes,
11 bytes::{Buf, BufMut},
12 checked_arithmetics::checked_log_2,
13 serialization::DeserializeBytes,
14};
15use bytemuck::{NoUninit, Zeroable};
16use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};
17use rand::{
18 distr::{Distribution, StandardUniform},
19 prelude::*,
20};
21
22use super::UnderlierType;
23use crate::arch::{interleave_mask_even, interleave_with_mask};
24
25#[derive(
27 Default,
28 Zeroable,
29 Clone,
30 Copy,
31 PartialEq,
32 Eq,
33 PartialOrd,
34 Ord,
35 BitAnd,
36 BitAndAssign,
37 BitOr,
38 BitOrAssign,
39 BitXor,
40 BitXorAssign,
41)]
42#[repr(transparent)]
43pub struct SmallU<const N: usize>(u8);
44
45impl<const N: usize> SmallU<N> {
46 const _CHECK_SIZE: () = {
47 assert!(N < 8);
48 };
49
50 pub const ONES: Self = Self((1u8 << N) - 1);
52
53 #[inline(always)]
54 pub const fn new(val: u8) -> Self {
55 Self(val & Self::ONES.0)
56 }
57
58 #[inline(always)]
59 pub const fn new_unchecked(val: u8) -> Self {
60 Self(val)
61 }
62
63 #[inline(always)]
64 pub const fn val(&self) -> u8 {
65 self.0
66 }
67
68 pub fn checked_add(self, rhs: Self) -> Option<Self> {
69 self.val()
70 .checked_add(rhs.val())
71 .and_then(|value| (value < Self::ONES.0).then_some(Self(value)))
72 }
73
74 pub fn checked_sub(self, rhs: Self) -> Option<Self> {
75 let a = self.val();
76 let b = rhs.val();
77 (b > a).then_some(Self(b - a))
78 }
79}
80
81impl<const N: usize> Debug for SmallU<N> {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 Debug::fmt(&self.val(), f)
84 }
85}
86
87impl<const N: usize> Display for SmallU<N> {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 Display::fmt(&self.val(), f)
90 }
91}
92
93impl<const N: usize> LowerHex for SmallU<N> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 LowerHex::fmt(&self.0, f)
96 }
97}
98impl<const N: usize> Hash for SmallU<N> {
99 #[inline]
100 fn hash<H: Hasher>(&self, state: &mut H) {
101 self.val().hash(state);
102 }
103}
104
105impl<const N: usize> Distribution<SmallU<N>> for StandardUniform {
106 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SmallU<N> {
107 SmallU(rng.random_range(0..1u8 << N))
108 }
109}
110
111impl<const N: usize> Shr<usize> for SmallU<N> {
112 type Output = Self;
113
114 #[inline(always)]
115 fn shr(self, rhs: usize) -> Self::Output {
116 Self(self.val() >> rhs)
117 }
118}
119
120impl<const N: usize> Shl<usize> for SmallU<N> {
121 type Output = Self;
122
123 #[inline(always)]
124 fn shl(self, rhs: usize) -> Self::Output {
125 Self(self.val() << rhs) & Self::ONES
126 }
127}
128
129impl<const N: usize> Not for SmallU<N> {
130 type Output = Self;
131
132 fn not(self) -> Self::Output {
133 self ^ Self::ONES
134 }
135}
136
// SAFETY: `SmallU<N>` is `#[repr(transparent)]` over a single `u8`, so it has no
// padding and no uninitialized bytes. It derives `Copy` and holds no references,
// so it is also `'static`.
unsafe impl<const N: usize> NoUninit for SmallU<N> {}
138
139impl UnderlierType for U1 {
140 const LOG_BITS: usize = checked_log_2(1);
141
142 const ZERO: Self = Self(0);
143 const ONE: Self = Self(1);
144 const ONES: Self = Self(1);
145
146 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
147 panic!("interleave not supported for U1")
148 }
149}
150
151impl UnderlierType for U2 {
152 const LOG_BITS: usize = checked_log_2(2);
153
154 const ZERO: Self = Self(0);
155 const ONE: Self = Self(1);
156 const ONES: Self = Self(0b11);
157
158 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
159 const MASKS: &[U2] = &[U2::new(interleave_mask_even!(u8, 0))];
160 interleave_with_mask(self, other, log_block_len, MASKS)
161 }
162}
163
164impl UnderlierType for U4 {
165 const LOG_BITS: usize = checked_log_2(4);
166
167 const ZERO: Self = Self(0);
168 const ONE: Self = Self(1);
169 const ONES: Self = Self(0b1111);
170
171 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
172 const MASKS: &[U4] = &[
173 U4::new(interleave_mask_even!(u8, 0)),
174 U4::new(interleave_mask_even!(u8, 1)),
175 ];
176 interleave_with_mask(self, other, log_block_len, MASKS)
177 }
178}
179
180impl<const N: usize> From<SmallU<N>> for u8 {
181 #[inline(always)]
182 fn from(value: SmallU<N>) -> Self {
183 value.val()
184 }
185}
186
187impl<const N: usize> From<SmallU<N>> for u16 {
188 #[inline(always)]
189 fn from(value: SmallU<N>) -> Self {
190 u8::from(value) as _
191 }
192}
193
194impl<const N: usize> From<SmallU<N>> for u32 {
195 #[inline(always)]
196 fn from(value: SmallU<N>) -> Self {
197 u8::from(value) as _
198 }
199}
200
201impl<const N: usize> From<SmallU<N>> for u64 {
202 #[inline(always)]
203 fn from(value: SmallU<N>) -> Self {
204 u8::from(value) as _
205 }
206}
207
208impl<const N: usize> From<SmallU<N>> for usize {
209 #[inline(always)]
210 fn from(value: SmallU<N>) -> Self {
211 u8::from(value) as _
212 }
213}
214
215impl<const N: usize> From<SmallU<N>> for u128 {
216 #[inline(always)]
217 fn from(value: SmallU<N>) -> Self {
218 u8::from(value) as _
219 }
220}
221
222impl From<SmallU<1>> for SmallU<2> {
223 #[inline(always)]
224 fn from(value: SmallU<1>) -> Self {
225 Self(value.val())
226 }
227}
228
229impl From<SmallU<1>> for SmallU<4> {
230 #[inline(always)]
231 fn from(value: SmallU<1>) -> Self {
232 Self(value.val())
233 }
234}
235
236impl From<SmallU<2>> for SmallU<4> {
237 #[inline(always)]
238 fn from(value: SmallU<2>) -> Self {
239 Self(value.val())
240 }
241}
242
243pub type U1 = SmallU<1>;
244pub type U2 = SmallU<2>;
245pub type U4 = SmallU<4>;
246
247impl From<bool> for U1 {
248 fn from(value: bool) -> Self {
249 Self::new_unchecked(value as u8)
250 }
251}
252
253impl From<U1> for bool {
254 fn from(value: U1) -> Self {
255 value == U1::ONE
256 }
257}
258
259impl<const N: usize> SerializeBytes for SmallU<N> {
260 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
261 self.val().serialize(write_buf)
262 }
263}
264
265impl<const N: usize> DeserializeBytes for SmallU<N> {
266 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
267 where
268 Self: Sized,
269 {
270 Ok(Self::new(DeserializeBytes::deserialize(read_buf)?))
271 }
272}
273
#[cfg(test)]
impl<const N: usize> proptest::arbitrary::Arbitrary for SmallU<N> {
    type Parameters = ();
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Generates raw values from the half-open range `0..2^N`.
    fn arbitrary_with((): Self::Parameters) -> Self::Strategy {
        use proptest::strategy::Strategy as _;

        let upper_exclusive = 1u8 << N;
        (0u8..upper_exclusive)
            .prop_map(SmallU::new_unchecked)
            .boxed()
    }
}