1use std::{
4 arch::x86_64::*,
5 array,
6 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializeBytes,
11 bytes::{Buf, BufMut},
12 serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform},
18};
19use seq_macro::seq;
20
21use crate::{
22 BinaryField,
23 arch::{
24 binary_utils::make_func_to_i8,
25 portable::{
26 packed::{PackedPrimitiveType, impl_pack_scalar},
27 packed_arithmetic::{
28 UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
29 },
30 },
31 },
32 arithmetic_traits::Broadcast,
33 underlier::{
34 Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, UnderlierWithBitOps,
35 WithUnderlier, impl_divisible_bitmask, mapget, spread_fallback,
36 },
37};
38
/// 128-bit value wrapping the SSE register type `__m128i`.
///
/// `#[repr(transparent)]` gives the wrapper the same layout as the inner `__m128i`.
/// The inner field is visible to the parent module only.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
43
44impl M128 {
45 #[inline(always)]
46 pub const fn from_u128(val: u128) -> Self {
47 let mut result = Self::ZERO;
48 unsafe {
49 result.0 = std::mem::transmute_copy(&val);
50 }
51
52 result
53 }
54}
55
56impl From<__m128i> for M128 {
57 #[inline(always)]
58 fn from(value: __m128i) -> Self {
59 Self(value)
60 }
61}
62
63impl From<u128> for M128 {
64 fn from(value: u128) -> Self {
65 Self(unsafe {
66 assert_eq!(align_of::<u128>(), 16);
68 _mm_load_si128(&raw const value as *const __m128i)
69 })
70 }
71}
72
73impl From<u64> for M128 {
74 fn from(value: u64) -> Self {
75 Self::from(value as u128)
76 }
77}
78
79impl From<u32> for M128 {
80 fn from(value: u32) -> Self {
81 Self::from(value as u128)
82 }
83}
84
85impl From<u16> for M128 {
86 fn from(value: u16) -> Self {
87 Self::from(value as u128)
88 }
89}
90
91impl From<u8> for M128 {
92 fn from(value: u8) -> Self {
93 Self::from(value as u128)
94 }
95}
96
97impl<const N: usize> From<SmallU<N>> for M128 {
98 fn from(value: SmallU<N>) -> Self {
99 Self::from(value.val() as u128)
100 }
101}
102
103impl From<M128> for u128 {
104 fn from(value: M128) -> Self {
105 let mut result = 0u128;
106 unsafe {
107 assert_eq!(align_of::<u128>(), 16);
109 _mm_store_si128(&raw mut result as *mut __m128i, value.0)
110 };
111 result
112 }
113}
114
115impl From<M128> for __m128i {
116 #[inline(always)]
117 fn from(value: M128) -> Self {
118 value.0
119 }
120}
121
122impl SerializeBytes for M128 {
123 fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
124 assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
125
126 let raw_value: u128 = (*self).into();
127
128 write_buf.put_u128_le(raw_value);
129 Ok(())
130 }
131}
132
133impl DeserializeBytes for M128 {
134 fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
135 where
136 Self: Sized,
137 {
138 assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
139
140 let raw_value = read_buf.get_u128_le();
141
142 Ok(Self::from(raw_value))
143 }
144}
145
// NOTE(review): presumably generates `Divisible` impls of `M128` over the sub-byte
// element widths 1, 2 and 4 bits — confirm against the macro definition.
impl_divisible_bitmask!(M128, 1, 2, 4);
// NOTE(review): presumably wires `M128` into the scalar-packing machinery of the
// portable backend — confirm against the macro definition.
impl_pack_scalar!(M128);
148
149impl<U: NumCast<u128>> NumCast<M128> for U {
150 #[inline(always)]
151 fn num_cast_from(val: M128) -> Self {
152 Self::num_cast_from(u128::from(val))
153 }
154}
155
156impl Default for M128 {
157 #[inline(always)]
158 fn default() -> Self {
159 Self(unsafe { _mm_setzero_si128() })
160 }
161}
162
163impl BitAnd for M128 {
164 type Output = Self;
165
166 #[inline(always)]
167 fn bitand(self, rhs: Self) -> Self::Output {
168 Self(unsafe { _mm_and_si128(self.0, rhs.0) })
169 }
170}
171
172impl BitAndAssign for M128 {
173 #[inline(always)]
174 fn bitand_assign(&mut self, rhs: Self) {
175 *self = *self & rhs
176 }
177}
178
179impl BitOr for M128 {
180 type Output = Self;
181
182 #[inline(always)]
183 fn bitor(self, rhs: Self) -> Self::Output {
184 Self(unsafe { _mm_or_si128(self.0, rhs.0) })
185 }
186}
187
188impl BitOrAssign for M128 {
189 #[inline(always)]
190 fn bitor_assign(&mut self, rhs: Self) {
191 *self = *self | rhs
192 }
193}
194
195impl BitXor for M128 {
196 type Output = Self;
197
198 #[inline(always)]
199 fn bitxor(self, rhs: Self) -> Self::Output {
200 Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
201 }
202}
203
204impl BitXorAssign for M128 {
205 #[inline(always)]
206 fn bitxor_assign(&mut self, rhs: Self) {
207 *self = *self ^ rhs;
208 }
209}
210
211impl Not for M128 {
212 type Output = Self;
213
214 fn not(self) -> Self::Output {
215 const ONES: __m128i = m128_from_u128!(u128::MAX);
216
217 self ^ Self(ONES)
218 }
219}
220
/// Returns the larger of `left` and `right`.
///
/// Written as a `const fn` so it can be evaluated inside constant expressions.
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    if right >= left {
        right
    } else {
        left
    }
}
225
/// Expands to a full 128-bit shift of `$val` by the runtime amount `$shift`, built
/// from per-64-bit-lane shift intrinsics.
///
/// Arguments:
/// * `$val` - the `__m128i` to shift.
/// * `$shift` - identifier holding the shift amount.
/// * `$byte_shift` - whole-register byte shift in the direction of the shift.
/// * `$bit_shift_64` - per-64-bit-lane bit shift in the direction of the shift.
/// * `$bit_shift_64_opposite` - per-64-bit-lane bit shift in the opposite direction.
/// * `$or` - bitwise OR of two registers.
///
/// The expansion contains `return` statements, so it must be the body of a function
/// whose return type can be produced via `.into()` and `Default::default()`.
/// Shift amounts of 128 or more fall through to `Default::default()`.
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            // Move one 64-bit lane across the lane boundary (8 bytes).
            let carry = $byte_shift($val, 8);
            // One branch per possible amount, so that every intrinsic receives a
            // compile-time constant shift count.
            seq!(N in 64..128 {
                if $shift == N {
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            seq!(N in 0..64 {
                if $shift == N {
                    // Bits that cross the lane boundary.
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            return Default::default()
        }
    };
}

pub(crate) use bitshift_128b;
260
impl Shr<usize> for M128 {
    type Output = Self;

    /// Logical right shift of the whole 128-bit value by `rhs` bits.
    ///
    /// The macro body returns `Default::default()` for `rhs >= 128`.
    #[inline(always)]
    fn shr(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
    }
}
271
272impl Shl<usize> for M128 {
273 type Output = Self;
274
275 #[inline(always)]
276 fn shl(self, rhs: usize) -> Self::Output {
277 bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
280 }
281}
282
283impl PartialEq for M128 {
284 fn eq(&self, other: &Self) -> bool {
285 unsafe {
286 let neq = _mm_xor_si128(self.0, other.0);
287 _mm_test_all_zeros(neq, neq) == 1
288 }
289 }
290}
291
/// Marks the bitwise equality defined by `PartialEq` as a full equivalence relation.
impl Eq for M128 {}
293
294impl PartialOrd for M128 {
295 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
296 Some(self.cmp(other))
297 }
298}
299
300impl Ord for M128 {
301 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
302 u128::from(*self).cmp(&u128::from(*other))
303 }
304}
305
306impl Distribution<M128> for StandardUniform {
307 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
308 M128(rng.random())
309 }
310}
311
312impl std::fmt::Display for M128 {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 let data: u128 = (*self).into();
315 write!(f, "{data:02X?}")
316 }
317}
318
319impl std::fmt::Debug for M128 {
320 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321 write!(f, "M128({self})")
322 }
323}
324
/// A single `u128` forced to 16-byte alignment.
///
/// Used by `m128_from_u128!` as the source of an `__m128i` read.
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
327
/// Evaluates to an `__m128i` holding the bit pattern of the `u128` expression `$val`.
///
/// The value is first placed in a 16-byte-aligned `AlignedData` and then read back
/// through a pointer cast to `__m128i`. This file uses the macro in `const`
/// initialisers.
macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}

pub(super) use m128_from_u128;
336
impl UnderlierType for M128 {
    // The type is 2^7 = 128 bits wide.
    const LOG_BITS: usize = 7;
}
340
341impl UnderlierWithBitOps for M128 {
    /// All 128 bits cleared.
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    /// Only the least-significant bit set.
    const ONE: Self = { Self(m128_from_u128!(1)) };
    /// All 128 bits set.
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };
345
346 #[inline(always)]
347 fn fill_with_bit(val: u8) -> Self {
348 assert!(val == 0 || val == 1);
349 Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
350 }
351
    /// Builds a value from elements of type `T` produced by `f`, dispatching on
    /// `T::BITS`.
    ///
    /// For 8-, 16-, 32- and 64-bit `T`, element `i` is placed in lane `i`, lane 0 being
    /// the least significant. `f` is invoked with descending indices because the
    /// `_mm_set_*` intrinsics take their arguments from the highest lane down.
    /// For 128-bit `T` only `f(0)` is used.
    ///
    /// # Panics
    ///
    /// Panics with "unsupported bit count" for any other `T::BITS`.
    #[inline(always)]
    fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                // NOTE(review): `make_func_to_i8` presumably packs several sub-byte
                // elements into each returned byte — confirm against its definition.
                let mut f = make_func_to_i8::<T, Self>(f);

                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            8 => {
                let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            16 => {
                let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
                unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
            }
            32 => {
                let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
                unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
            }
            64 => {
                let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
                unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
            }
            128 => Self::from(f(0)),
            _ => panic!("unsupported bit count"),
        }
    }
424
    /// Takes block number `block_idx` of `2^log_block_len` consecutive `T` elements
    /// from `self` and repeats each element of that block so that the block fills all
    /// 128 bits.
    ///
    /// The outer `match` dispatches on the element width `T::LOG_BITS`, the inner one
    /// on `log_block_len`. When the block already covers the whole value, `self` is
    /// returned unchanged.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is set by the trait declaration, which is not
    /// visible here; presumably `block_idx` must address a block that lies inside the
    /// 128-bit value — confirm against the trait.
    ///
    /// # Panics
    ///
    /// Panics with "unsupported block length" or "unsupported bit length" for
    /// combinations not handled below, and on out-of-range `block_idx` in the 8-bit
    /// table lookups.
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps + NumCast<Self>,
        Self: Divisible<T> + From<T>,
    {
        match T::LOG_BITS {
            // 1-bit elements: each selected bit is expanded to a full lane of ones or zeros.
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => unsafe {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                },
                2 => unsafe {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                },
                3 => unsafe {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                },
                4 => unsafe {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                },
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 2-bit elements: each element is first expanded to a byte, then the bytes
            // are repeated to fill the lane belonging to that element.
            1 => match log_block_len {
                0 => unsafe {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                },
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 4-bit elements: same scheme as the 2-bit case.
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    unsafe { _mm_set1_epi8(value as i8).into() }
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 8-bit elements: a single byte shuffle with a precomputed index table.
            3 => match log_block_len {
                0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
                1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
                2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
                3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
                4 => self,
                _ => panic!("unsupported block length"),
            },
            // 16-bit elements.
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    unsafe { _mm_set1_epi16(value as i16).into() }
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            // 32-bit elements.
            5 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                },
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            // 64-bit elements.
            6 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                },
                1 => self,
                _ => panic!("unsupported block length"),
            },
            // A single 128-bit element already fills the value.
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }
634}
635
// `M128` is a `#[repr(transparent)]` wrapper around `__m128i`.
// NOTE(review): these impls assume every bit pattern, including all zeros, is a valid
// `__m128i` and that the type has no padding — confirm against the `bytemuck` contracts.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// NOTE(review): the value is plain data with no interior pointers or shared state, so
// sharing it across threads is presumably sound — confirm.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
643
// Byte-shuffle index tables used by `spread` for 8-bit elements. `LOG_B8_k` holds one
// mask per block index for blocks of `2^k` bytes, hence `16 >> k` entries each.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
648
649const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
650 log_block_len: usize,
651 t_log_bits: usize,
652) -> [M128; BLOCK_IDX_AMOUNT] {
653 let element_log_width = t_log_bits - 3;
654
655 let element_width = 1 << element_log_width;
656
657 let block_size = 1 << (log_block_len + element_log_width);
658 let repeat = 1 << (4 - element_log_width - log_block_len);
659 let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];
660
661 let mut block_idx = 0;
662
663 while block_idx < BLOCK_IDX_AMOUNT {
664 let base = block_idx * block_size;
665 let mut j = 0;
666 while j < 16 {
667 masks[block_idx][j] =
668 (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
669 j += 1;
670 }
671 block_idx += 1;
672 }
673 let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];
674
675 let mut block_idx = 0;
676
677 while block_idx < BLOCK_IDX_AMOUNT {
678 m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
679 block_idx += 1;
680 }
681
682 m128_masks
683}
684
685impl UnderlierWithBitConstants for M128 {
686 const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
687 Self::from_u128(interleave_mask_even!(u128, 0)),
688 Self::from_u128(interleave_mask_even!(u128, 1)),
689 Self::from_u128(interleave_mask_even!(u128, 2)),
690 Self::from_u128(interleave_mask_even!(u128, 3)),
691 Self::from_u128(interleave_mask_even!(u128, 4)),
692 Self::from_u128(interleave_mask_even!(u128, 5)),
693 Self::from_u128(interleave_mask_even!(u128, 6)),
694 ];
695
696 const INTERLEAVE_ODD_MASK: &'static [Self] = &[
697 Self::from_u128(interleave_mask_odd!(u128, 0)),
698 Self::from_u128(interleave_mask_odd!(u128, 1)),
699 Self::from_u128(interleave_mask_odd!(u128, 2)),
700 Self::from_u128(interleave_mask_odd!(u128, 3)),
701 Self::from_u128(interleave_mask_odd!(u128, 4)),
702 Self::from_u128(interleave_mask_odd!(u128, 5)),
703 Self::from_u128(interleave_mask_odd!(u128, 6)),
704 ];
705
706 #[inline(always)]
707 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
708 unsafe {
709 let (c, d) = interleave_bits(
710 Into::<Self>::into(self).into(),
711 Into::<Self>::into(other).into(),
712 log_block_len,
713 );
714 (Self::from(c), Self::from(d))
715 }
716 }
717}
718
719impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
720 fn from(value: __m128i) -> Self {
721 M128::from(value).into()
722 }
723}
724
725impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
726 fn from(value: u128) -> Self {
727 M128::from(value).into()
728 }
729}
730
731impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
732 fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
733 value.to_underlier().into()
734 }
735}
736
737impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
738where
739 u128: From<Scalar::Underlier>,
740{
741 #[inline(always)]
742 fn broadcast(scalar: Scalar) -> Self {
743 let tower_level = Scalar::N_BITS.ilog2() as usize;
744 match tower_level {
745 0..=3 => {
746 let mut value = u128::from(scalar.to_underlier()) as u8;
747 for n in tower_level..3 {
748 value |= value << (1 << n);
749 }
750
751 unsafe { _mm_set1_epi8(value as i8) }.into()
752 }
753 4 => {
754 let value = u128::from(scalar.to_underlier()) as u16;
755 unsafe { _mm_set1_epi16(value as i16) }.into()
756 }
757 5 => {
758 let value = u128::from(scalar.to_underlier()) as u32;
759 unsafe { _mm_set1_epi32(value as i32) }.into()
760 }
761 6 => {
762 let value = u128::from(scalar.to_underlier()) as u64;
763 unsafe { _mm_set1_epi64x(value as i64) }.into()
764 }
765 7 => {
766 let value = u128::from(scalar.to_underlier());
767 value.into()
768 }
769 _ => {
770 unreachable!("invalid tower level")
771 }
772 }
773 }
774}
775
/// Interleaves blocks of `2^log_block_len` bits taken from `a` and `b`.
///
/// Numbering blocks from the least-significant end, the first result holds the
/// even-numbered blocks of `a` at even positions and the even-numbered blocks of `b`
/// at odd positions. The second result holds the odd-numbered blocks of `a` and `b`
/// in the same arrangement.
///
/// # Safety
///
/// NOTE(review): the body only calls SSE2 and SSSE3 (`_mm_shuffle_epi8`) intrinsics, so
/// the caller presumably must ensure those target features are available — confirm.
///
/// # Panics
///
/// Panics with "unsupported block length" when `log_block_len > 6`.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        // Sub-byte blocks: swap bit groups with the mask-and-shift helper.
        0 => unsafe {
            let mask = _mm_set1_epi8(0x55i8);
            interleave_bits_imm::<1>(a, b, mask)
        },
        1 => unsafe {
            let mask = _mm_set1_epi8(0x33i8);
            interleave_bits_imm::<2>(a, b, mask)
        },
        2 => unsafe {
            let mask = _mm_set1_epi8(0x0fi8);
            interleave_bits_imm::<4>(a, b, mask)
        },
        // Byte blocks and wider: gather the even-numbered blocks into the low half and
        // the odd-numbered blocks into the high half, then unpack the halves.
        3 => unsafe {
            let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi8(a, b);
            let b_prime = _mm_unpackhi_epi8(a, b);
            (a_prime, b_prime)
        },
        4 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi16(a, b);
            let b_prime = _mm_unpackhi_epi16(a, b);
            (a_prime, b_prime)
        },
        5 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi32(a, b);
            let b_prime = _mm_unpackhi_epi32(a, b);
            (a_prime, b_prime)
        },
        // 64-bit blocks: the halves are already in place, so no shuffle is needed.
        6 => unsafe {
            let a_prime = _mm_unpacklo_epi64(a, b);
            let b_prime = _mm_unpackhi_epi64(a, b);
            (a_prime, b_prime)
        },
        _ => panic!("unsupported block length"),
    }
}
823
/// Interleaves `BLOCK_LEN`-bit groups of `a` and `b` using a mask-and-shift swap.
///
/// `mask` selects the even-numbered groups. The odd-numbered groups of `a` are
/// exchanged with the even-numbered groups of `b`.
///
/// # Safety
///
/// NOTE(review): only SSE2 intrinsics are used, so the caller presumably must ensure
/// SSE2 is available — confirm.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    unsafe {
        // Bits that differ between the odd groups of `a` and the even groups of `b`.
        let a_shifted_down = _mm_srli_epi64::<BLOCK_LEN>(a);
        let swap_bits = _mm_and_si128(_mm_xor_si128(a_shifted_down, b), mask);

        // XOR-ing the difference into both operands swaps the selected groups.
        let swap_bits_up = _mm_slli_epi64::<BLOCK_LEN>(swap_bits);
        (_mm_xor_si128(a, swap_bits_up), _mm_xor_si128(b, swap_bits))
    }
}
837
838impl Divisible<u128> for M128 {
841 const LOG_N: usize = 0;
842
843 #[inline]
844 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
845 std::iter::once(u128::from(value))
846 }
847
848 #[inline]
849 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
850 std::iter::once(u128::from(*value))
851 }
852
853 #[inline]
854 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
855 slice.iter().map(|&v| u128::from(v))
856 }
857
858 #[inline]
859 fn get(self, index: usize) -> u128 {
860 assert!(index == 0, "index out of bounds");
861 u128::from(self)
862 }
863
864 #[inline]
865 fn set(self, index: usize, val: u128) -> Self {
866 assert!(index == 0, "index out of bounds");
867 Self::from(val)
868 }
869}
870
impl Divisible<u64> for M128 {
    // Two `u64` lanes (2^1).
    const LOG_N: usize = 1;

    // NOTE(review): the iterator methods delegate to `mapget`, which presumably
    // iterates lanes through `get` — confirm against its definition.
    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    /// Returns 64-bit lane `index`; panics with "index out of bounds" if `index > 1`.
    #[inline]
    fn get(self, index: usize) -> u64 {
        unsafe {
            // The intrinsic needs a constant lane index, hence one arm per lane.
            match index {
                0 => _mm_extract_epi64(self.0, 0) as u64,
                1 => _mm_extract_epi64(self.0, 1) as u64,
                _ => panic!("index out of bounds"),
            }
        }
    }

    /// Returns a copy with 64-bit lane `index` replaced by `val`; panics with
    /// "index out of bounds" if `index > 1`.
    #[inline]
    fn set(self, index: usize, val: u64) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
                1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
                _ => panic!("index out of bounds"),
            }
        }
    }
}
911
impl Divisible<u32> for M128 {
    // Four `u32` lanes (2^2).
    const LOG_N: usize = 2;

    // NOTE(review): the iterator methods delegate to `mapget`, which presumably
    // iterates lanes through `get` — confirm against its definition.
    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    /// Returns 32-bit lane `index`; panics with "index out of bounds" if `index > 3`.
    #[inline]
    fn get(self, index: usize) -> u32 {
        unsafe {
            // The intrinsic needs a constant lane index, hence one arm per lane.
            match index {
                0 => _mm_extract_epi32(self.0, 0) as u32,
                1 => _mm_extract_epi32(self.0, 1) as u32,
                2 => _mm_extract_epi32(self.0, 2) as u32,
                3 => _mm_extract_epi32(self.0, 3) as u32,
                _ => panic!("index out of bounds"),
            }
        }
    }

    /// Returns a copy with 32-bit lane `index` replaced by `val`; panics with
    /// "index out of bounds" if `index > 3`.
    #[inline]
    fn set(self, index: usize, val: u32) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
                _ => panic!("index out of bounds"),
            }
        }
    }
}
956
impl Divisible<u16> for M128 {
    // Eight `u16` lanes (2^3).
    const LOG_N: usize = 3;

    // NOTE(review): the iterator methods delegate to `mapget`, which presumably
    // iterates lanes through `get` — confirm against its definition.
    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    /// Returns 16-bit lane `index`; panics with "index out of bounds" if `index > 7`.
    #[inline]
    fn get(self, index: usize) -> u16 {
        unsafe {
            // The intrinsic needs a constant lane index, hence one arm per lane.
            match index {
                0 => _mm_extract_epi16(self.0, 0) as u16,
                1 => _mm_extract_epi16(self.0, 1) as u16,
                2 => _mm_extract_epi16(self.0, 2) as u16,
                3 => _mm_extract_epi16(self.0, 3) as u16,
                4 => _mm_extract_epi16(self.0, 4) as u16,
                5 => _mm_extract_epi16(self.0, 5) as u16,
                6 => _mm_extract_epi16(self.0, 6) as u16,
                7 => _mm_extract_epi16(self.0, 7) as u16,
                _ => panic!("index out of bounds"),
            }
        }
    }

    /// Returns a copy with 16-bit lane `index` replaced by `val`; panics with
    /// "index out of bounds" if `index > 7`.
    #[inline]
    fn set(self, index: usize, val: u16) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
                4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
                5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
                6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
                7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
                _ => panic!("index out of bounds"),
            }
        }
    }
}
1009
impl Divisible<u8> for M128 {
    // Sixteen `u8` lanes (2^4).
    const LOG_N: usize = 4;

    // NOTE(review): the iterator methods delegate to `mapget`, which presumably
    // iterates lanes through `get` — confirm against its definition.
    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    /// Returns byte lane `index`; panics with "index out of bounds" if `index > 15`.
    #[inline]
    fn get(self, index: usize) -> u8 {
        unsafe {
            // The intrinsic needs a constant lane index, hence one arm per lane.
            match index {
                0 => _mm_extract_epi8(self.0, 0) as u8,
                1 => _mm_extract_epi8(self.0, 1) as u8,
                2 => _mm_extract_epi8(self.0, 2) as u8,
                3 => _mm_extract_epi8(self.0, 3) as u8,
                4 => _mm_extract_epi8(self.0, 4) as u8,
                5 => _mm_extract_epi8(self.0, 5) as u8,
                6 => _mm_extract_epi8(self.0, 6) as u8,
                7 => _mm_extract_epi8(self.0, 7) as u8,
                8 => _mm_extract_epi8(self.0, 8) as u8,
                9 => _mm_extract_epi8(self.0, 9) as u8,
                10 => _mm_extract_epi8(self.0, 10) as u8,
                11 => _mm_extract_epi8(self.0, 11) as u8,
                12 => _mm_extract_epi8(self.0, 12) as u8,
                13 => _mm_extract_epi8(self.0, 13) as u8,
                14 => _mm_extract_epi8(self.0, 14) as u8,
                15 => _mm_extract_epi8(self.0, 15) as u8,
                _ => panic!("index out of bounds"),
            }
        }
    }

    /// Returns a copy with byte lane `index` replaced by `val`; panics with
    /// "index out of bounds" if `index > 15`.
    #[inline]
    fn set(self, index: usize, val: u8) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
                4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
                5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
                6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
                7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
                8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
                9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
                10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
                11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
                12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
                13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
                14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
                15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
                _ => panic!("index out of bounds"),
            }
        }
    }
}
1078
#[cfg(test)]
mod tests {
    use binius_utils::bytes::BytesMut;
    use proptest::{arbitrary::any, proptest};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Asserts that converting `val` to `T` and back yields the same `M128`.
    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    /// Extracts block number `index` of `2^log_block_len` bits from `value`, moved
    /// down to the least-significant position.
    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        // The mask must cover every bit of the block, not just one of them; otherwise
        // the interleave test would compare a single bit per block.
        let block_bits = 1usize << log_block_len;
        let block_mask = u128::MAX >> (128 - block_bits);

        (value >> (index << log_block_len)) & M128::from(block_mask)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            // Even blocks of `a` and `b` land in `c`, odd blocks in `d`.
            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }

    #[test]
    fn test_serialize_and_deserialize_m128() {
        // Fixed seed keeps the test deterministic.
        let mut rng = StdRng::from_seed([0; 32]);

        let original_value = M128::from(rng.random::<u128>());

        let mut buf = BytesMut::new();
        original_value.serialize(&mut buf).unwrap();

        let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

        assert_eq!(original_value, deserialized_value);
    }
}