1use std::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr};
4
5use binius_utils::checked_arithmetics::checked_int_div;
6
7use super::{
8 U1, U2, U4,
9 underlier_type::{NumCast, UnderlierType},
10};
11use crate::Divisible;
12
13pub trait UnderlierWithBitOps:
15 UnderlierType
16 + BitAnd<Self, Output = Self>
17 + BitAndAssign<Self>
18 + BitOr<Self, Output = Self>
19 + BitOrAssign<Self>
20 + BitXor<Self, Output = Self>
21 + BitXorAssign<Self>
22 + Shr<usize, Output = Self>
23 + Shl<usize, Output = Self>
24 + Not<Output = Self>
25{
26 const ZERO: Self;
27 const ONE: Self;
28 const ONES: Self;
29
30 fn fill_with_bit(val: u8) -> Self;
33
34 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self);
36
37 fn transpose(mut self, mut other: Self, log_block_len: usize) -> (Self, Self) {
39 assert!(log_block_len < Self::LOG_BITS);
40
41 for log_block_len in (log_block_len..Self::LOG_BITS).rev() {
42 (self, other) = self.interleave(other, log_block_len);
43 }
44
45 (self, other)
46 }
47
48 #[inline]
49 fn from_fn<T>(f: impl FnMut(usize) -> T) -> Self
50 where
51 T: UnderlierType,
52 Self: Divisible<T>,
53 {
54 Self::from_iter((0..Self::N).map(f))
55 }
56
57 #[inline]
60 fn broadcast_subvalue<T>(value: T) -> Self
61 where
62 T: UnderlierType,
63 Self: Divisible<T>,
64 {
65 Divisible::<T>::broadcast(value)
66 }
67
68 #[inline]
74 unsafe fn get_subvalue<T>(&self, i: usize) -> T
75 where
76 T: UnderlierType,
77 Self: Divisible<T>,
78 {
79 debug_assert!(
80 i < checked_int_div(Self::BITS, T::BITS),
81 "i: {} Self::BITS: {}, T::BITS: {}",
82 i,
83 Self::BITS,
84 T::BITS
85 );
86 Divisible::<T>::get(*self, i)
87 }
88
89 #[inline]
95 unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
96 where
97 T: UnderlierWithBitOps,
98 Self: Divisible<T>,
99 {
100 debug_assert!(i < checked_int_div(Self::BITS, T::BITS));
101 *self = (*self).set(i, val);
102 }
103
104 #[inline]
111 unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
112 where
113 T: UnderlierWithBitOps,
114 Self: Divisible<T>,
115 {
116 unsafe { spread_fallback(self, log_block_len, block_idx) }
117 }
118}
119
120pub(crate) unsafe fn spread_fallback<U, T>(value: U, log_block_len: usize, block_idx: usize) -> U
126where
127 U: UnderlierWithBitOps + Divisible<T>,
128 T: UnderlierWithBitOps,
129{
130 debug_assert!(
131 log_block_len + T::LOG_BITS <= U::LOG_BITS,
132 "log_block_len: {}, U::BITS: {}, T::BITS: {}",
133 log_block_len,
134 U::BITS,
135 T::BITS
136 );
137 debug_assert!(
138 block_idx < 1 << (U::LOG_BITS - log_block_len),
139 "block_idx: {}, U::BITS: {}, log_block_len: {}",
140 block_idx,
141 U::BITS,
142 log_block_len
143 );
144
145 let mut result = U::ZERO;
146 let block_offset = block_idx << log_block_len;
147 let log_repeat = U::LOG_BITS - T::LOG_BITS - log_block_len;
148 for i in 0..1 << log_block_len {
149 unsafe {
150 result.set_subvalue(i << log_repeat, value.get_subvalue(block_offset + i));
151 }
152 }
153
154 for i in 0..log_repeat {
155 result |= result << (1 << (T::LOG_BITS + i));
156 }
157
158 result
159}
160
/// Test helper that builds a mask of `bits_count` set bits in a `T`.
///
/// When `bits_count == T::BITS`, the all-ones value `!T::ZERO` is returned. Otherwise the
/// mask starts at `T::ONE` and is doubled `checked_log_2(bits_count)` times, OR-ing in a
/// copy of itself shifted by its current width.
/// NOTE(review): `checked_log_2` presumably requires `bits_count` to be a power of two.
#[cfg(test)]
// Not every test configuration uses this helper.
#[allow(unused)]
pub(crate) fn single_element_mask_bits<T: UnderlierWithBitOps>(bits_count: usize) -> T {
	use binius_utils::checked_arithmetics::checked_log_2;

	if bits_count != T::BITS {
		return (0..checked_log_2(bits_count))
			.fold(T::ONE, |mask, height| mask | (mask << (1 << height)));
	}

	!T::ZERO
}
177
/// Expands a small unsigned value into a full byte by repeating its bit pattern.
pub(crate) trait SpreadToByte {
	/// Returns the byte formed by repeating `self`'s bits until all 8 positions are filled.
	fn spread_to_byte(self) -> u8;
}
182
183impl SpreadToByte for U1 {
184 #[inline(always)]
185 fn spread_to_byte(self) -> u8 {
186 u8::fill_with_bit(self.val())
187 }
188}
189
190impl SpreadToByte for U2 {
191 #[inline(always)]
192 fn spread_to_byte(self) -> u8 {
193 let mut result = self.val();
194 result |= result << 2;
195 result |= result << 4;
196
197 result
198 }
199}
200
201impl SpreadToByte for U4 {
202 #[inline(always)]
203 fn spread_to_byte(self) -> u8 {
204 let mut result = self.val();
205 result |= result << 4;
206
207 result
208 }
209}
210
211#[allow(unused)]
216#[inline(always)]
217pub(crate) unsafe fn get_block_values<U, T, const BLOCK_LEN: usize>(
218 value: U,
219 block_idx: usize,
220) -> [T; BLOCK_LEN]
221where
222 U: UnderlierWithBitOps + From<T> + Divisible<T>,
223 T: UnderlierType + NumCast<U>,
224{
225 std::array::from_fn(|i| unsafe { value.get_subvalue::<T>(block_idx * BLOCK_LEN + i) })
226}
227
228#[allow(unused)]
233#[inline(always)]
234pub(crate) unsafe fn get_spread_bytes<U, T, const BLOCK_LEN: usize>(
235 value: U,
236 block_idx: usize,
237) -> [u8; BLOCK_LEN]
238where
239 U: UnderlierWithBitOps + From<T> + Divisible<T>,
240 T: UnderlierType + SpreadToByte + NumCast<U>,
241{
242 unsafe { get_block_values::<U, T, BLOCK_LEN>(value, block_idx) }
243 .map(SpreadToByte::spread_to_byte)
244}
245
#[cfg(test)]
mod tests {
	// Unit and property tests for the sub-value helpers on `u32` with 1-, 2-, 4- and
	// 8-bit sub-values.
	use proptest::{arbitrary::any, bits, proptest};

	use super::{
		super::small_uint::{U1, U2, U4},
		*,
	};

	#[test]
	fn test_from_fn() {
		assert_eq!(u32::from_fn(|_| U1::new(0)), 0);
		assert_eq!(u32::from_fn(|i| U1::new((i % 2) as u8)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U1::new(1)), u32::MAX);

		assert_eq!(u32::from_fn(|_| U2::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U2::new(1)), 0x55555555);
		assert_eq!(u32::from_fn(|_| U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::from_fn(|_| U2::new(3)), u32::MAX);
		assert_eq!(u32::from_fn(|i| U2::new((i % 4) as u8)), 0xe4e4e4e4);

		assert_eq!(u32::from_fn(|_| U4::new(0)), 0);
		assert_eq!(u32::from_fn(|_| U4::new(1)), 0x11111111);
		assert_eq!(u32::from_fn(|_| U4::new(8)), 0x88888888);
		assert_eq!(u32::from_fn(|_| U4::new(31)), 0xffffffff);
		assert_eq!(u32::from_fn(|i| U4::new(i as u8)), 0x76543210);

		assert_eq!(u32::from_fn(|_| 0u8), 0);
		assert_eq!(u32::from_fn(|_| 0xabu8), 0xabababab);
		assert_eq!(u32::from_fn(|_| 255u8), 0xffffffff);
		assert_eq!(u32::from_fn(|i| i as u8), 0x03020100);
	}

	#[test]
	fn test_broadcast_subvalue() {
		assert_eq!(u32::broadcast_subvalue(U1::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U1::new(1)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U2::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U2::new(1)), 0x55555555);
		assert_eq!(u32::broadcast_subvalue(U2::new(2)), 0xaaaaaaaa);
		assert_eq!(u32::broadcast_subvalue(U2::new(3)), u32::MAX);

		assert_eq!(u32::broadcast_subvalue(U4::new(0)), 0);
		assert_eq!(u32::broadcast_subvalue(U4::new(1)), 0x11111111);
		assert_eq!(u32::broadcast_subvalue(U4::new(8)), 0x88888888);
		assert_eq!(u32::broadcast_subvalue(U4::new(31)), 0xffffffff);

		assert_eq!(u32::broadcast_subvalue(0u8), 0);
		assert_eq!(u32::broadcast_subvalue(0xabu8), 0xabababab);
		assert_eq!(u32::broadcast_subvalue(255u8), 0xffffffff);
	}

	#[test]
	fn test_get_subvalue() {
		let value = 0xab12cd34u32;

		unsafe {
			assert_eq!(value.get_subvalue::<U1>(0), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(1), U1::new(0));
			assert_eq!(value.get_subvalue::<U1>(2), U1::new(1));
			assert_eq!(value.get_subvalue::<U1>(31), U1::new(1));

			assert_eq!(value.get_subvalue::<U2>(0), U2::new(0));
			assert_eq!(value.get_subvalue::<U2>(1), U2::new(1));
			assert_eq!(value.get_subvalue::<U2>(2), U2::new(3));
			assert_eq!(value.get_subvalue::<U2>(15), U2::new(2));

			assert_eq!(value.get_subvalue::<U4>(0), U4::new(4));
			assert_eq!(value.get_subvalue::<U4>(1), U4::new(3));
			assert_eq!(value.get_subvalue::<U4>(2), U4::new(13));
			assert_eq!(value.get_subvalue::<U4>(7), U4::new(10));

			assert_eq!(value.get_subvalue::<u8>(0), 0x34u8);
			assert_eq!(value.get_subvalue::<u8>(1), 0xcdu8);
			assert_eq!(value.get_subvalue::<u8>(2), 0x12u8);
			assert_eq!(value.get_subvalue::<u8>(3), 0xabu8);
		}
	}

	#[test]
	fn test_spread_fallback() {
		// Nibbles from least significant: 4, 3, d, c, 2, 1, b, a.
		let value = 0xab12cd34u32;

		unsafe {
			// Single-element blocks are broadcast over the whole word.
			assert_eq!(spread_fallback::<u32, U1>(value, 0, 0), 0);
			assert_eq!(spread_fallback::<u32, U1>(value, 0, 2), u32::MAX);
			assert_eq!(spread_fallback::<u32, U4>(value, 0, 0), 0x44444444);
			assert_eq!(spread_fallback::<u32, U4>(value, 0, 7), 0xaaaaaaaa);

			// Two-element blocks: each element fills half of the word, in order.
			assert_eq!(spread_fallback::<u32, U4>(value, 1, 0), 0x33334444);
			assert_eq!(spread_fallback::<u32, U4>(value, 1, 1), 0xccccdddd);
			assert_eq!(spread_fallback::<u32, u8>(value, 1, 1), 0xabab1212);

			// A block covering the whole word is returned unchanged.
			assert_eq!(spread_fallback::<u32, u8>(value, 2, 0), value);
		}
	}

	// Index ranges cover every sub-value slot of a `u32` (32 / 16 / 8 / 4 slots), and the
	// value strategies cover every representable value of the sub-value type.
	proptest! {
		#[test]
		fn test_set_subvalue_1b(mut init_val in any::<u32>(), i in 0usize..32, val in bits::u8::masked(0x01)) {
			unsafe {
				init_val.set_subvalue(i, U1::new(val));
				assert_eq!(init_val.get_subvalue::<U1>(i), U1::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_2b(mut init_val in any::<u32>(), i in 0usize..16, val in bits::u8::masked(0x03)) {
			unsafe {
				init_val.set_subvalue(i, U2::new(val));
				assert_eq!(init_val.get_subvalue::<U2>(i), U2::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_4b(mut init_val in any::<u32>(), i in 0usize..8, val in bits::u8::masked(0x0f)) {
			unsafe {
				init_val.set_subvalue(i, U4::new(val));
				assert_eq!(init_val.get_subvalue::<U4>(i), U4::new(val));
			}
		}

		#[test]
		fn test_set_subvalue_8b(mut init_val in any::<u32>(), i in 0usize..4, val in any::<u8>()) {
			unsafe {
				init_val.set_subvalue(i, val);
				assert_eq!(init_val.get_subvalue::<u8>(i), val);
			}
		}
	}
}