binius_field/underlier/
small_uint.rs1use std::{
4 fmt::{Debug, Display, LowerHex},
5 hash::{Hash, Hasher},
6 ops::{Not, Shl, Shr},
7};
8
9use binius_utils::{
10 SerializationError, SerializeBytes,
11 bytes::{Buf, BufMut},
12 checked_arithmetics::checked_log_2,
13 serialization::DeserializeBytes,
14};
15use bytemuck::{NoUninit, Zeroable};
16use derive_more::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign};
17use rand::{
18 Rng,
19 distr::{Distribution, StandardUniform},
20};
21
22use super::{UnderlierType, underlier_with_bit_ops::UnderlierWithBitOps};
23use crate::arch::{interleave_mask_even, interleave_with_mask};
24
25#[derive(
27 Default,
28 Zeroable,
29 Clone,
30 Copy,
31 PartialEq,
32 Eq,
33 PartialOrd,
34 Ord,
35 BitAnd,
36 BitAndAssign,
37 BitOr,
38 BitOrAssign,
39 BitXor,
40 BitXorAssign,
41)]
42#[repr(transparent)]
43pub struct SmallU<const N: usize>(u8);
44
45impl<const N: usize> SmallU<N> {
46 const _CHECK_SIZE: () = {
47 assert!(N < 8);
48 };
49
50 pub const ONES: Self = Self((1u8 << N) - 1);
52
53 #[inline(always)]
54 pub const fn new(val: u8) -> Self {
55 Self(val & Self::ONES.0)
56 }
57
58 #[inline(always)]
59 pub const fn new_unchecked(val: u8) -> Self {
60 Self(val)
61 }
62
63 #[inline(always)]
64 pub const fn val(&self) -> u8 {
65 self.0
66 }
67
68 pub fn checked_add(self, rhs: Self) -> Option<Self> {
69 self.val()
70 .checked_add(rhs.val())
71 .and_then(|value| (value < Self::ONES.0).then_some(Self(value)))
72 }
73
74 pub fn checked_sub(self, rhs: Self) -> Option<Self> {
75 let a = self.val();
76 let b = rhs.val();
77 (b > a).then_some(Self(b - a))
78 }
79}
80
81impl<const N: usize> Debug for SmallU<N> {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 Debug::fmt(&self.val(), f)
84 }
85}
86
87impl<const N: usize> Display for SmallU<N> {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 Display::fmt(&self.val(), f)
90 }
91}
92
93impl<const N: usize> LowerHex for SmallU<N> {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 LowerHex::fmt(&self.0, f)
96 }
97}
98impl<const N: usize> Hash for SmallU<N> {
99 #[inline]
100 fn hash<H: Hasher>(&self, state: &mut H) {
101 self.val().hash(state);
102 }
103}
104
105impl<const N: usize> Distribution<SmallU<N>> for StandardUniform {
106 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SmallU<N> {
107 SmallU(rng.random_range(0..1u8 << N))
108 }
109}
110
111impl<const N: usize> Shr<usize> for SmallU<N> {
112 type Output = Self;
113
114 #[inline(always)]
115 fn shr(self, rhs: usize) -> Self::Output {
116 Self(self.val() >> rhs)
117 }
118}
119
120impl<const N: usize> Shl<usize> for SmallU<N> {
121 type Output = Self;
122
123 #[inline(always)]
124 fn shl(self, rhs: usize) -> Self::Output {
125 Self(self.val() << rhs) & Self::ONES
126 }
127}
128
129impl<const N: usize> Not for SmallU<N> {
130 type Output = Self;
131
132 fn not(self) -> Self::Output {
133 self ^ Self::ONES
134 }
135}
136
// SAFETY: `SmallU<N>` is `#[repr(transparent)]` over a single `u8` field, so for
// every `N` it has exactly the layout of `u8`: there are no padding bytes and every
// byte of a value is initialized. The struct is also `Copy` (derived on its
// definition) and holds no borrowed data.
unsafe impl<const N: usize> NoUninit for SmallU<N> {}
138
139impl<const N: usize> UnderlierType for SmallU<N> {
140 const LOG_BITS: usize = checked_log_2(N);
141}
142
143impl UnderlierWithBitOps for U1 {
144 const ZERO: Self = Self(0);
145 const ONE: Self = Self(1);
146 const ONES: Self = Self(1);
147
148 fn interleave(self, _other: Self, _log_block_len: usize) -> (Self, Self) {
149 panic!("interleave not supported for U1")
150 }
151}
152
153impl UnderlierWithBitOps for U2 {
154 const ZERO: Self = Self(0);
155 const ONE: Self = Self(1);
156 const ONES: Self = Self(0b11);
157
158 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
159 const MASKS: &[U2] = &[U2::new(interleave_mask_even!(u8, 0))];
160 interleave_with_mask(self, other, log_block_len, MASKS)
161 }
162}
163
164impl UnderlierWithBitOps for U4 {
165 const ZERO: Self = Self(0);
166 const ONE: Self = Self(1);
167 const ONES: Self = Self(0b1111);
168
169 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
170 const MASKS: &[U4] = &[
171 U4::new(interleave_mask_even!(u8, 0)),
172 U4::new(interleave_mask_even!(u8, 1)),
173 ];
174 interleave_with_mask(self, other, log_block_len, MASKS)
175 }
176}
177
178impl<const N: usize> From<SmallU<N>> for u8 {
179 #[inline(always)]
180 fn from(value: SmallU<N>) -> Self {
181 value.val()
182 }
183}
184
185impl<const N: usize> From<SmallU<N>> for u16 {
186 #[inline(always)]
187 fn from(value: SmallU<N>) -> Self {
188 u8::from(value) as _
189 }
190}
191
192impl<const N: usize> From<SmallU<N>> for u32 {
193 #[inline(always)]
194 fn from(value: SmallU<N>) -> Self {
195 u8::from(value) as _
196 }
197}
198
199impl<const N: usize> From<SmallU<N>> for u64 {
200 #[inline(always)]
201 fn from(value: SmallU<N>) -> Self {
202 u8::from(value) as _
203 }
204}
205
206impl<const N: usize> From<SmallU<N>> for usize {
207 #[inline(always)]
208 fn from(value: SmallU<N>) -> Self {
209 u8::from(value) as _
210 }
211}
212
213impl<const N: usize> From<SmallU<N>> for u128 {
214 #[inline(always)]
215 fn from(value: SmallU<N>) -> Self {
216 u8::from(value) as _
217 }
218}
219
220impl From<SmallU<1>> for SmallU<2> {
221 #[inline(always)]
222 fn from(value: SmallU<1>) -> Self {
223 Self(value.val())
224 }
225}
226
227impl From<SmallU<1>> for SmallU<4> {
228 #[inline(always)]
229 fn from(value: SmallU<1>) -> Self {
230 Self(value.val())
231 }
232}
233
234impl From<SmallU<2>> for SmallU<4> {
235 #[inline(always)]
236 fn from(value: SmallU<2>) -> Self {
237 Self(value.val())
238 }
239}
240
/// 1-bit unsigned integer.
pub type U1 = SmallU<1>;
/// 2-bit unsigned integer.
pub type U2 = SmallU<2>;
/// 4-bit unsigned integer.
pub type U4 = SmallU<4>;
244
245impl From<bool> for U1 {
246 fn from(value: bool) -> Self {
247 Self::new_unchecked(value as u8)
248 }
249}
250
251impl From<U1> for bool {
252 fn from(value: U1) -> Self {
253 value == U1::ONE
254 }
255}
256
257impl<const N: usize> SerializeBytes for SmallU<N> {
258 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
259 self.val().serialize(write_buf)
260 }
261}
262
263impl<const N: usize> DeserializeBytes for SmallU<N> {
264 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
265 where
266 Self: Sized,
267 {
268 Ok(Self::new(DeserializeBytes::deserialize(read_buf)?))
269 }
270}
271
#[cfg(test)]
impl<const N: usize> proptest::arbitrary::Arbitrary for SmallU<N> {
    type Parameters = ();
    type Strategy = proptest::strategy::BoxedStrategy<Self>;

    /// Draws raw values from the half-open range `0..2^N`, so every generated
    /// value already satisfies the `N`-bit invariant.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::strategy::Strategy;

        let end = 1u8 << N;
        let raw_values = 0u8..end;
        raw_values.prop_map(Self::new_unchecked).boxed()
    }
}