use std::{
    arch::x86_64::*,
    array,
    ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::{
    DeserializeBytes, SerializationError, SerializationMode, SerializeBytes,
    bytes::{Buf, BufMut},
    serialization::{assert_enough_data_for, assert_enough_space_for},
};
use bytemuck::{Pod, Zeroable};
use rand::{Rng, RngCore};
use seq_macro::seq;
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};

use crate::{
    BinaryField,
    arch::{
        binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
        portable::{
            packed::{PackedPrimitiveType, impl_pack_scalar},
            packed_arithmetic::{
                UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
            },
        },
    },
    arithmetic_traits::Broadcast,
    tower_levels::TowerLevel,
    underlier::{
        NumCast, Random, SmallU, SpreadToByte, U1, U2, U4, UnderlierType, UnderlierWithBitOps,
        WithUnderlier, impl_divisible, impl_iteration, spread_fallback, transpose_128b_values,
        unpack_hi_128b_fallback, unpack_lo_128b_fallback,
    },
};

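/// A 128-bit value stored in a single SSE register, wrapping the platform `__m128i` type.
///
/// Serves as the underlier for the 128-bit packed binary field types on x86_64.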
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);

impl M128 {
    #[inline(always)]
    pub const fn from_u128(val: u128) -> Self {
        let mut result = Self::ZERO;
        unsafe {
            result.0 = std::mem::transmute_copy(&val);
        }

        result
    }
}

impl From<__m128i> for M128 {
    #[inline(always)]
    fn from(value: __m128i) -> Self {
        Self(value)
    }
}

impl From<u128> for M128 {
    fn from(value: u128) -> Self {
        Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
    }
}

impl From<u64> for M128 {
    fn from(value: u64) -> Self {
        Self::from(value as u128)
    }
}

impl From<u32> for M128 {
    fn from(value: u32) -> Self {
        Self::from(value as u128)
    }
}

impl From<u16> for M128 {
    fn from(value: u16) -> Self {
        Self::from(value as u128)
    }
}

impl From<u8> for M128 {
    fn from(value: u8) -> Self {
        Self::from(value as u128)
    }
}

impl<const N: usize> From<SmallU<N>> for M128 {
    fn from(value: SmallU<N>) -> Self {
        Self::from(value.val() as u128)
    }
}

impl From<M128> for u128 {
    fn from(value: M128) -> Self {
        let mut result = 0u128;
        unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };

        result
    }
}

impl From<M128> for __m128i {
    #[inline(always)]
    fn from(value: M128) -> Self {
        value.0
    }
}

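// Serialization uses the little-endian `u128` representation of the value.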
impl SerializeBytes for M128 {
    fn serialize(
        &self,
        mut write_buf: impl BufMut,
        _mode: SerializationMode,
    ) -> Result<(), SerializationError> {
        assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;

        let raw_value: u128 = (*self).into();

        write_buf.put_u128_le(raw_value);
        Ok(())
    }
}

impl DeserializeBytes for M128 {
    fn deserialize(
        mut read_buf: impl Buf,
        _mode: SerializationMode,
    ) -> Result<Self, SerializationError>
    where
        Self: Sized,
    {
        assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;

        let raw_value = read_buf.get_u128_le();

        Ok(Self::from(raw_value))
    }
}

impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);

impl<U: NumCast<u128>> NumCast<M128> for U {
    #[inline(always)]
    fn num_cast_from(val: M128) -> Self {
        Self::num_cast_from(u128::from(val))
    }
}

impl Default for M128 {
    #[inline(always)]
    fn default() -> Self {
        Self(unsafe { _mm_setzero_si128() })
    }
}

impl BitAnd for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitand(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_and_si128(self.0, rhs.0) })
    }
}

impl BitAndAssign for M128 {
    #[inline(always)]
    fn bitand_assign(&mut self, rhs: Self) {
        *self = *self & rhs
    }
}

impl BitOr for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_or_si128(self.0, rhs.0) })
    }
}

impl BitOrAssign for M128 {
    #[inline(always)]
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs
    }
}

impl BitXor for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
    }
}

impl BitXorAssign for M128 {
    #[inline(always)]
    fn bitxor_assign(&mut self, rhs: Self) {
        *self = *self ^ rhs;
    }
}

impl Not for M128 {
    type Output = Self;

    fn not(self) -> Self::Output {
        const ONES: __m128i = m128_from_u128!(u128::MAX);

        self ^ Self(ONES)
    }
}

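/// `const` maximum of two `i32`s, used by `bitshift_128b!` to keep the shift counts it generates
/// at macro expansion time non-negative.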
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    if left > right { left } else { right }
}

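/// Expands to a full 128-bit bit shift built from 64-bit lane shifts, since SSE2 has no single
/// instruction that bit-shifts the whole 128-bit register.
///
/// `$byte_shift` moves the register by 8 bytes to produce the carry lane, and `seq!` generates a
/// branch per possible shift amount so that the 64-bit shift intrinsics receive compile-time
/// constant counts. Shift amounts of 128 or more fall through to `Default::default()` (zero).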
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            let carry = $byte_shift($val, 8);
            seq!(N in 64..128 {
                if $shift == N {
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            seq!(N in 0..64 {
                if $shift == N {
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            return Default::default()
        }
    };
}

pub(crate) use bitshift_128b;

impl Shr<usize> for M128 {
    type Output = Self;

    #[inline(always)]
    fn shr(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
    }
}

impl Shl<usize> for M128 {
    type Output = Self;

    #[inline(always)]
    fn shl(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
    }
}

impl PartialEq for M128 {
    fn eq(&self, other: &Self) -> bool {
        unsafe {
            let neq = _mm_xor_si128(self.0, other.0);
            _mm_test_all_zeros(neq, neq) == 1
        }
    }
}

impl Eq for M128 {}

impl PartialOrd for M128 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for M128 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        u128::from(*self).cmp(&u128::from(*other))
    }
}

impl ConstantTimeEq for M128 {
    fn ct_eq(&self, other: &Self) -> Choice {
        unsafe {
            let neq = _mm_xor_si128(self.0, other.0);
            Choice::from(_mm_test_all_zeros(neq, neq) as u8)
        }
    }
}

impl ConditionallySelectable for M128 {
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
    }
}

impl Random for M128 {
    fn random(mut rng: impl RngCore) -> Self {
        let val: u128 = rng.random();
        val.into()
    }
}

impl std::fmt::Display for M128 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data: u128 = (*self).into();
        write!(f, "{data:02X?}")
    }
}

impl std::fmt::Debug for M128 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "M128({self})")
    }
}

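/// 16-byte-aligned storage used by `m128_from_u128!` to reinterpret a `u128` as a `__m128i` in
/// `const` contexts, where the SSE load intrinsics cannot be called.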
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}

pub(super) use m128_from_u128;

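// 2^7 = 128 bits.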
impl UnderlierType for M128 {
    const LOG_BITS: usize = 7;
}

impl UnderlierWithBitOps for M128 {
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    const ONE: Self = { Self(m128_from_u128!(1)) };
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

    #[inline(always)]
    fn fill_with_bit(val: u8) -> Self {
        assert!(val == 0 || val == 1);
        Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
    }

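    // Builds the vector element by element from `f`. Sub-byte scalars (1, 2, or 4 bits) are
    // packed into bytes via `make_func_to_i8`; wider scalars map directly onto 8/16/32/64-bit
    // lanes.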
    #[inline(always)]
    fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                let mut f = make_func_to_i8::<T, Self>(f);

                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            8 => {
                let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            16 => {
                let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
                unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
            }
            32 => {
                let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
                unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
            }
            64 => {
                let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
                unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
            }
            128 => Self::from(f(0)),
            _ => panic!("unsupported bit count"),
        }
    }

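    // Safety: the caller must ensure `i` is less than the number of `T`-sized elements in 128
    // bits; the unchecked indexing below is undefined behavior otherwise.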
    #[inline(always)]
    unsafe fn get_subvalue<T>(&self, i: usize) -> T
    where
        T: UnderlierType + NumCast<Self>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                let elements_in_8 = 8 / T::BITS;
                let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
                    *arr.get_unchecked(i / elements_in_8)
                });

                let shift = (i % elements_in_8) * T::BITS;
                value_u8 >>= shift;

                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            8 => {
                let value_u8 =
                    as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            16 => {
                let value_u16 =
                    as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u16)))
            }
            32 => {
                let value_u32 =
                    as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u32)))
            }
            64 => {
                let value_u64 =
                    as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u64)))
            }
            128 => T::from_underlier(T::num_cast_from(*self)),
            _ => panic!("unsupported bit count"),
        }
    }

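    // Safety: as with `get_subvalue`, `i` must be a valid element index for `T`.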
    #[inline(always)]
    unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
    where
        T: UnderlierWithBitOps,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                let elements_in_8 = 8 / T::BITS;
                let mask = (1u8 << T::BITS) - 1;
                let shift = (i % elements_in_8) * T::BITS;
                let val = u8::num_cast_from(Self::from(val)) << shift;
                let mask = mask << shift;

                as_array_mut::<_, u8, 16>(self, |array| unsafe {
                    let element = array.get_unchecked_mut(i / elements_in_8);
                    *element &= !mask;
                    *element |= val;
                });
            }
            8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
            }),
            16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
            }),
            32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
            }),
            64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
            }),
            128 => {
                *self = Self::from(val);
            }
            _ => panic!("unsupported bit count"),
        }
    }

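    // Spreads the block of `1 << log_block_len` consecutive `T`-elements starting at block
    // `block_idx` across the whole register, repeating each element so that order is preserved.
    // For example, with `T = u8`, `log_block_len = 0` and `block_idx = 3`, byte 3 of `self` is
    // broadcast to all 16 byte lanes (via the precomputed `LOG_B8_0` shuffle mask).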
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps + NumCast<Self>,
        Self: From<T>,
    {
        match T::LOG_BITS {
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => unsafe {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                },
                2 => unsafe {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                },
                3 => unsafe {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                },
                4 => unsafe {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                },
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            1 => match log_block_len {
                0 => unsafe {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                },
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    unsafe { _mm_set1_epi8(value as i8).into() }
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            3 => match log_block_len {
                0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
                1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
                2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
                3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
                4 => self,
                _ => panic!("unsupported block length"),
            },
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    unsafe { _mm_set1_epi16(value as i16).into() }
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            5 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                },
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            6 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                },
                1 => self,
                _ => panic!("unsupported block length"),
            },
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }

    #[inline]
    fn shl_128b_lanes(self, shift: usize) -> Self {
        self << shift
    }

    #[inline]
    fn shr_128b_lanes(self, shift: usize) -> Self {
        self >> shift
    }

    #[inline]
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

    #[inline]
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

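    // Byte-sliced transposition. Both directions rely on `transpose_128b_values`; converting to
    // byte-sliced form additionally regroups the bytes inside each value first, using a
    // `_mm_shuffle_epi8` mask chosen by `TL::LOG_WIDTH`.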
    #[inline]
    fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        transpose_128b_values::<Self, TL>(values, 0);
    }

    #[inline]
    fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        if TL::LOG_WIDTH == 0 {
            return;
        }

        match TL::LOG_WIDTH {
            1 => unsafe {
                let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            2 => unsafe {
                let shuffle = _mm_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            3 => unsafe {
                let shuffle = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            4 => {}
            _ => unreachable!("Log width must be less than 5"),
        }

        transpose_128b_values::<_, TL>(values, 4 - TL::LOG_WIDTH);
    }
}

unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}

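// Precomputed `_mm_shuffle_epi8` masks for `spread` with byte-sized elements
// (`T::LOG_BITS == 3`); `LOG_B8_N` serves `log_block_len == N`, indexed by block index.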
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);

const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
    log_block_len: usize,
    t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
    let element_log_width = t_log_bits - 3;

    let element_width = 1 << element_log_width;

    let block_size = 1 << (log_block_len + element_log_width);
    let repeat = 1 << (4 - element_log_width - log_block_len);
    let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        let base = block_idx * block_size;
        let mut j = 0;
        while j < 16 {
            masks[block_idx][j] =
                (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
            j += 1;
        }
        block_idx += 1;
    }
    let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
        block_idx += 1;
    }

    m128_masks
}

impl UnderlierWithBitConstants for M128 {
    const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_even!(u128, 0)),
        Self::from_u128(interleave_mask_even!(u128, 1)),
        Self::from_u128(interleave_mask_even!(u128, 2)),
        Self::from_u128(interleave_mask_even!(u128, 3)),
        Self::from_u128(interleave_mask_even!(u128, 4)),
        Self::from_u128(interleave_mask_even!(u128, 5)),
        Self::from_u128(interleave_mask_even!(u128, 6)),
    ];

    const INTERLEAVE_ODD_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_odd!(u128, 0)),
        Self::from_u128(interleave_mask_odd!(u128, 1)),
        Self::from_u128(interleave_mask_odd!(u128, 2)),
        Self::from_u128(interleave_mask_odd!(u128, 3)),
        Self::from_u128(interleave_mask_odd!(u128, 4)),
        Self::from_u128(interleave_mask_odd!(u128, 5)),
        Self::from_u128(interleave_mask_odd!(u128, 6)),
    ];

    #[inline(always)]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        unsafe {
            let (c, d) = interleave_bits(
                Into::<Self>::into(self).into(),
                Into::<Self>::into(other).into(),
                log_block_len,
            );
            (Self::from(c), Self::from(d))
        }
    }
}

impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: __m128i) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: u128) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
    fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
        value.to_underlier().into()
    }
}

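// For scalars narrower than a byte, broadcasting first replicates the scalar to a full byte by
// repeated doubling (`value |= value << (1 << n)`) and then uses `_mm_set1_epi8`; byte-sized and
// wider scalars use the matching `_mm_set1_*` intrinsic directly.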
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
    u128: From<Scalar::Underlier>,
{
    #[inline(always)]
    fn broadcast(scalar: Scalar) -> Self {
        let tower_level = Scalar::N_BITS.ilog2() as usize;
        match tower_level {
            0..=3 => {
                let mut value = u128::from(scalar.to_underlier()) as u8;
                for n in tower_level..3 {
                    value |= value << (1 << n);
                }

                unsafe { _mm_set1_epi8(value as i8) }.into()
            }
            4 => {
                let value = u128::from(scalar.to_underlier()) as u16;
                unsafe { _mm_set1_epi16(value as i16) }.into()
            }
            5 => {
                let value = u128::from(scalar.to_underlier()) as u32;
                unsafe { _mm_set1_epi32(value as i32) }.into()
            }
            6 => {
                let value = u128::from(scalar.to_underlier()) as u64;
                unsafe { _mm_set1_epi64x(value as i64) }.into()
            }
            7 => {
                let value = u128::from(scalar.to_underlier());
                value.into()
            }
            _ => {
                unreachable!("invalid tower level")
            }
        }
    }
}

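/// Interleaves `a` and `b` in blocks of `2^log_block_len` bits: the first output holds the
/// even-indexed blocks of `a` and `b` alternating, the second holds the odd-indexed blocks.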
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        0 => unsafe {
            let mask = _mm_set1_epi8(0x55i8);
            interleave_bits_imm::<1>(a, b, mask)
        },
        1 => unsafe {
            let mask = _mm_set1_epi8(0x33i8);
            interleave_bits_imm::<2>(a, b, mask)
        },
        2 => unsafe {
            let mask = _mm_set1_epi8(0x0fi8);
            interleave_bits_imm::<4>(a, b, mask)
        },
        3 => unsafe {
            let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi8(a, b);
            let b_prime = _mm_unpackhi_epi8(a, b);
            (a_prime, b_prime)
        },
        4 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi16(a, b);
            let b_prime = _mm_unpackhi_epi16(a, b);
            (a_prime, b_prime)
        },
        5 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi32(a, b);
            let b_prime = _mm_unpackhi_epi32(a, b);
            (a_prime, b_prime)
        },
        6 => unsafe {
            let a_prime = _mm_unpacklo_epi64(a, b);
            let b_prime = _mm_unpackhi_epi64(a, b);
            (a_prime, b_prime)
        },
        _ => panic!("unsupported block length"),
    }
}

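/// Delta-swap interleave for sub-byte blocks: `t` holds the XOR difference between the odd
/// blocks of `a` (shifted down by `BLOCK_LEN`) and the even blocks of `b`; XOR-ing it back into
/// both operands exchanges those blocks, so `a_prime` collects the even blocks of `a` and `b`
/// and `b_prime` the odd ones. `mask` selects the even blocks of width `BLOCK_LEN` in each byte.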
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    unsafe {
        let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
        let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
        let b_prime = _mm_xor_si128(b, t);
        (a_prime, b_prime)
    }
}

impl_iteration!(M128,
    @strategy BitIterationStrategy, U1,
    @strategy FallbackStrategy, U2, U4,
    @strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);

#[cfg(test)]
mod tests {
    use binius_utils::bytes::BytesMut;
    use proptest::{arbitrary::any, proptest};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::underlier::single_element_mask_bits;

    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        (value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }

        #[test]
        fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_lo_128b_lanes(b, height);
            for i in 0..128>>(height + 1) {
                assert_eq!(get(result, height, 2*i), get(a, height, i));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i));
            }
        }

        #[test]
        fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_hi_128b_lanes(b, height);
            let half_block_count = 128>>(height + 1);
            for i in 0..half_block_count {
                assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }

    #[test]
    fn test_serialize_and_deserialize_m128() {
        let mode = SerializationMode::Native;

        let mut rng = StdRng::from_seed([0; 32]);

        let original_value = M128::from(rng.random::<u128>());

        let mut buf = BytesMut::new();
        original_value.serialize(&mut buf, mode).unwrap();

        let deserialized_value = M128::deserialize(buf.freeze(), mode).unwrap();

        assert_eq!(original_value, deserialized_value);
    }
}