1use std::{
4 arch::x86_64::*,
5 array,
6 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializationMode, SerializeBytes,
11 bytes::{Buf, BufMut},
12 serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable, must_cast};
15use rand::{Rng, RngCore};
16use seq_macro::seq;
17use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
18
19use crate::{
20 BinaryField,
21 arch::{
22 binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
23 portable::{
24 packed::{PackedPrimitiveType, impl_pack_scalar},
25 packed_arithmetic::{
26 UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
27 },
28 },
29 },
30 arithmetic_traits::Broadcast,
31 tower_levels::TowerLevel,
32 underlier::{
33 NumCast, Random, SmallU, SpreadToByte, U1, U2, U4, UnderlierType, UnderlierWithBitOps,
34 WithUnderlier, impl_divisible, impl_iteration, spread_fallback, transpose_128b_values,
35 unpack_hi_128b_fallback, unpack_lo_128b_fallback,
36 },
37};
38
/// 128-bit value used for 128-bit SIMD operations.
///
/// Transparent wrapper around the SSE register type `__m128i`, so it can be
/// passed anywhere the raw register is expected at zero cost.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
43
impl M128 {
    /// Const-context constructor from a `u128`.
    ///
    /// The load intrinsics are not `const fn`, so the bits are copied with
    /// `transmute_copy` instead.
    #[inline(always)]
    pub const fn from_u128(val: u128) -> Self {
        let mut result = Self::ZERO;
        // SAFETY: `u128` and `__m128i` are both 16 bytes and every bit
        // pattern is valid for both, so a bitwise copy is sound.
        unsafe {
            result.0 = std::mem::transmute_copy(&val);
        }

        result
    }
}
55
impl From<__m128i> for M128 {
    /// Zero-cost wrap of the raw SSE register.
    #[inline(always)]
    fn from(value: __m128i) -> Self {
        Self(value)
    }
}
62
impl From<u128> for M128 {
    fn from(value: u128) -> Self {
        // SAFETY: `value` occupies 16 readable bytes and `_mm_loadu_si128`
        // performs an unaligned load, so reading it as `__m128i` is sound.
        Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
    }
}
68
69impl From<u64> for M128 {
70 fn from(value: u64) -> Self {
71 Self::from(value as u128)
72 }
73}
74
75impl From<u32> for M128 {
76 fn from(value: u32) -> Self {
77 Self::from(value as u128)
78 }
79}
80
81impl From<u16> for M128 {
82 fn from(value: u16) -> Self {
83 Self::from(value as u128)
84 }
85}
86
87impl From<u8> for M128 {
88 fn from(value: u8) -> Self {
89 Self::from(value as u128)
90 }
91}
92
impl<const N: usize> From<SmallU<N>> for M128 {
    /// Zero-extends a sub-byte value (`N` < 8 bits) into the low lane.
    fn from(value: SmallU<N>) -> Self {
        Self::from(value.val() as u128)
    }
}
98
impl From<M128> for u128 {
    fn from(value: M128) -> Self {
        let mut result = 0u128;
        // SAFETY: `result` occupies 16 writable bytes and `_mm_storeu_si128`
        // performs an unaligned store, so writing the register into it is
        // sound.
        unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };

        result
    }
}
107
impl From<M128> for __m128i {
    /// Unwraps the raw SSE register.
    #[inline(always)]
    fn from(value: M128) -> Self {
        value.0
    }
}
114
115impl SerializeBytes for M128 {
116 fn serialize(
117 &self,
118 mut write_buf: impl BufMut,
119 _mode: SerializationMode,
120 ) -> Result<(), SerializationError> {
121 assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
122
123 let raw_value: u128 = (*self).into();
124
125 write_buf.put_u128_le(raw_value);
126 Ok(())
127 }
128}
129
130impl DeserializeBytes for M128 {
131 fn deserialize(
132 mut read_buf: impl Buf,
133 _mode: SerializationMode,
134 ) -> Result<Self, SerializationError>
135 where
136 Self: Sized,
137 {
138 assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
139
140 let raw_value = read_buf.get_u128_le();
141
142 Ok(Self::from(raw_value))
143 }
144}
145
// Allow `M128` to be reinterpreted as arrays of the smaller unsigned integer
// types, and let it act as the packing underlier for packed-field types.
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);
148
impl<U: NumCast<u128>> NumCast<M128> for U {
    /// Blanket cast routed through `u128`: anything castable from `u128` is
    /// castable from `M128`.
    #[inline(always)]
    fn num_cast_from(val: M128) -> Self {
        Self::num_cast_from(u128::from(val))
    }
}
155
impl Default for M128 {
    /// All-zero register.
    #[inline(always)]
    fn default() -> Self {
        // SAFETY: SSE2 is baseline on x86_64, where this module is compiled.
        Self(unsafe { _mm_setzero_si128() })
    }
}
162
// Bitwise AND/OR/XOR map 1:1 onto single SSE2 instructions; the `*Assign`
// variants delegate to the by-value operators.
impl BitAnd for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitand(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_and_si128(self.0, rhs.0) })
    }
}

impl BitAndAssign for M128 {
    #[inline(always)]
    fn bitand_assign(&mut self, rhs: Self) {
        *self = *self & rhs
    }
}

impl BitOr for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_or_si128(self.0, rhs.0) })
    }
}

impl BitOrAssign for M128 {
    #[inline(always)]
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs
    }
}

impl BitXor for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
    }
}

impl BitXorAssign for M128 {
    #[inline(always)]
    fn bitxor_assign(&mut self, rhs: Self) {
        *self = *self ^ rhs;
    }
}
210
211impl Not for M128 {
212 type Output = Self;
213
214 fn not(self) -> Self::Output {
215 const ONES: __m128i = m128_from_u128!(u128::MAX);
216
217 self ^ Self(ONES)
218 }
219}
220
/// Returns the larger of two `i32` values.
///
/// Hand-written as a `const fn` because `Ord::max` is not const-callable;
/// used to clamp shift immediates inside `bitshift_128b!`.
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    if left < right { right } else { left }
}
225
/// Expands to a full 128-bit shift built from 64-bit lane operations.
///
/// SSE2 has no single 128-bit bit shift, so this combines a whole-register
/// byte shift (to carry bits across the 64-bit lane boundary) with per-lane
/// 64-bit bit shifts. The per-lane shift intrinsics require an immediate
/// shift count, but `$shift` is a runtime value, so `seq!` unrolls one
/// branch per possible shift amount in `0..128`.
///
/// Shift amounts >= 128 fall through to `Default::default()` (zero).
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            // Bits that cross the 64-bit lane boundary, moved by whole bytes.
            let carry = $byte_shift($val, 8);
            seq!(N in 64..128 {
                if $shift == N {
                    // Shifts >= 64: only the carried lane contributes.
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            seq!(N in 0..64 {
                if $shift == N {
                    // Shifts < 64: combine the shifted value with the bits
                    // carried in from the neighboring lane.
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            return Default::default()
        }
    };
}

pub(crate) use bitshift_128b;
260
impl Shr<usize> for M128 {
    type Output = Self;

    /// Logical right shift of the full 128-bit value.
    ///
    /// Unlike `u128 >> rhs`, shift amounts >= 128 are well-defined here and
    /// yield zero (see `bitshift_128b!`).
    #[inline(always)]
    fn shr(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
    }
}
271
impl Shl<usize> for M128 {
    type Output = Self;

    /// Logical left shift of the full 128-bit value.
    ///
    /// Unlike `u128 << rhs`, shift amounts >= 128 are well-defined here and
    /// yield zero (see `bitshift_128b!`).
    #[inline(always)]
    fn shl(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
    }
}
282
impl PartialEq for M128 {
    fn eq(&self, other: &Self) -> bool {
        // XOR is zero iff every bit agrees; `_mm_test_all_zeros` (SSE4.1)
        // reduces the 128-bit difference to a single flag.
        unsafe {
            let neq = _mm_xor_si128(self.0, other.0);
            _mm_test_all_zeros(neq, neq) == 1
        }
    }
}

impl Eq for M128 {}
293
impl PartialOrd for M128 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for M128 {
    // Ordering is defined by the `u128` interpretation of the register.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        u128::from(*self).cmp(&u128::from(*other))
    }
}
305
impl ConstantTimeEq for M128 {
    /// Branch-free equality: the same XOR + `_mm_test_all_zeros` reduction
    /// as `PartialEq`, but returning a `Choice` instead of `bool`.
    fn ct_eq(&self, other: &Self) -> Choice {
        unsafe {
            let neq = _mm_xor_si128(self.0, other.0);
            Choice::from(_mm_test_all_zeros(neq, neq) as u8)
        }
    }
}
314
315impl ConditionallySelectable for M128 {
316 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
317 ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
318 }
319}
320
321impl Random for M128 {
322 fn random(mut rng: impl RngCore) -> Self {
323 let val: u128 = rng.r#gen();
324 val.into()
325 }
326}
327
328impl std::fmt::Display for M128 {
329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330 let data: u128 = (*self).into();
331 write!(f, "{data:02X?}")
332 }
333}
334
impl std::fmt::Debug for M128 {
    // Debug output reuses the `Display` formatting, wrapped in the type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "M128({self})")
    }
}
340
/// Helper with 16-byte alignment so a `u128` can be reinterpreted as a
/// `__m128i` through an aligned pointer read (see `m128_from_u128!`).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
343
/// Builds a `__m128i` from a `u128` expression, usable in const contexts
/// (unlike the load intrinsics): the value is placed in 16-byte-aligned
/// storage and read back as a register.
macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        // SAFETY: `AlignedData` is 16-byte aligned and 16 bytes long, so the
        // pointer read as `__m128i` is aligned and in-bounds.
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}

pub(super) use m128_from_u128;
352
impl UnderlierType for M128 {
    // 2^7 = 128 bits.
    const LOG_BITS: usize = 7;
}
356
impl UnderlierWithBitOps for M128 {
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    const ONE: Self = { Self(m128_from_u128!(1)) };
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

    /// Broadcasts a single bit value (must be 0 or 1) to all 128 bits.
    #[inline(always)]
    fn fill_with_bit(val: u8) -> Self {
        assert!(val == 0 || val == 1);
        // 1u8.wrapping_neg() == 0xFF and 0u8.wrapping_neg() == 0, so every
        // byte ends up holding the repeated bit.
        Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
    }

    /// Constructs the register element-wise: element `i` (of width
    /// `T::BITS`, little-endian element order) is `f(i)`.
    ///
    /// NOTE: the `_mm_set_*` intrinsics take arguments from the most
    /// significant element down, hence the reversed call order below.
    #[inline(always)]
    fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        match T::BITS {
            // Sub-byte elements: `make_func_to_i8` packs groups of elements
            // into whole bytes first.
            1 | 2 | 4 => {
                let mut f = make_func_to_i8::<T, Self>(f);

                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            8 => {
                let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            16 => {
                let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
                unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
            }
            32 => {
                let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
                unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
            }
            64 => {
                let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
                unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
            }
            128 => Self::from(f(0)),
            _ => panic!("unsupported bit count"),
        }
    }

    /// Extracts the `i`-th `T::BITS`-wide element.
    ///
    /// # Safety
    /// `i` must be in range for `T::BITS`-wide elements of a 128-bit value;
    /// indexing is unchecked.
    #[inline(always)]
    unsafe fn get_subvalue<T>(&self, i: usize) -> T
    where
        T: UnderlierType + NumCast<Self>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                let elements_in_8 = 8 / T::BITS;
                // Load the byte containing the element, shift the element
                // down; `num_cast_from` truncates to `T::BITS` bits.
                let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
                    *arr.get_unchecked(i / elements_in_8)
                });

                let shift = (i % elements_in_8) * T::BITS;
                value_u8 >>= shift;

                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            8 => {
                let value_u8 =
                    as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            16 => {
                let value_u16 =
                    as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u16)))
            }
            32 => {
                let value_u32 =
                    as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u32)))
            }
            64 => {
                let value_u64 =
                    as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u64)))
            }
            128 => T::from_underlier(T::num_cast_from(*self)),
            _ => panic!("unsupported bit count"),
        }
    }

    /// Overwrites the `i`-th `T::BITS`-wide element with `val`.
    ///
    /// # Safety
    /// `i` must be in range for `T::BITS`-wide elements of a 128-bit value;
    /// indexing is unchecked.
    #[inline(always)]
    unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
    where
        T: UnderlierWithBitOps,
        Self: From<T>,
    {
        match T::BITS {
            // Sub-byte elements: read-modify-write of the containing byte.
            1 | 2 | 4 => {
                let elements_in_8 = 8 / T::BITS;
                let mask = (1u8 << T::BITS) - 1;
                let shift = (i % elements_in_8) * T::BITS;
                let val = u8::num_cast_from(Self::from(val)) << shift;
                let mask = mask << shift;

                as_array_mut::<_, u8, 16>(self, |array| unsafe {
                    let element = array.get_unchecked_mut(i / elements_in_8);
                    *element &= !mask;
                    *element |= val;
                });
            }
            8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
            }),
            16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
            }),
            32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
            }),
            64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
            }),
            128 => {
                *self = Self::from(val);
            }
            _ => panic!("unsupported bit count"),
        }
    }

    /// Spreads the block of `2^log_block_len` `T`-elements at `block_idx`
    /// across the whole register: each element of the block is repeated to
    /// fill an equal share of the 128 bits.
    ///
    /// Dispatches on the element width (`T::LOG_BITS`); small widths are
    /// built from scalar extraction, byte-wide elements use precomputed
    /// shuffle masks, and wider elements use broadcast-style `_mm_set*`.
    ///
    /// # Safety
    /// `log_block_len` and `block_idx` must describe a block that lies
    /// within the register.
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps + NumCast<Self>,
        Self: From<T>,
    {
        match T::LOG_BITS {
            // 1-bit elements: each selected bit is expanded to an
            // all-ones/all-zeros lane via `fill_with_bit`.
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => unsafe {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                },
                2 => unsafe {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                },
                3 => unsafe {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                },
                4 => unsafe {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                },
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 2-bit elements: each element is expanded to a byte with
            // `spread_to_byte`, then bytes are replicated.
            1 => match log_block_len {
                0 => unsafe {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                },
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 4-bit elements: same strategy as the 2-bit case, via `U4`.
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    unsafe { _mm_set1_epi8(value as i8).into() }
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 8-bit elements: a single byte shuffle with a precomputed mask.
            3 => match log_block_len {
                0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
                1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
                2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
                3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
                4 => self,
                _ => panic!("unsupported block length"),
            },
            // 16-bit elements.
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    unsafe { _mm_set1_epi16(value as i16).into() }
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            // 32-bit elements.
            5 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                },
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            // 64-bit elements.
            6 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                },
                1 => self,
                _ => panic!("unsupported block length"),
            },
            // 128-bit element: the whole register is the single block.
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }

    // A single 128-bit lane, so the lane-wise shifts are the plain shifts.
    #[inline]
    fn shl_128b_lanes(self, shift: usize) -> Self {
        self << shift
    }

    #[inline]
    fn shr_128b_lanes(self, shift: usize) -> Self {
        self >> shift
    }

    #[inline]
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            // Sub-byte blocks have no native unpack instruction.
            0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

    #[inline]
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            // Sub-byte blocks have no native unpack instruction.
            0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

    #[inline]
    fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        transpose_128b_values::<Self, TL>(values, 0);
    }

    #[inline]
    fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        if TL::LOG_WIDTH == 0 {
            return;
        }

        // Regroup bytes within each register with a shuffle before the
        // cross-register transpose below.
        match TL::LOG_WIDTH {
            1 => unsafe {
                let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            2 => unsafe {
                let shuffle = _mm_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            3 => unsafe {
                let shuffle = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            4 => {}
            _ => unreachable!("Log width must be less than 5"),
        }

        transpose_128b_values::<_, TL>(values, 4 - TL::LOG_WIDTH);
    }
}
811
// SAFETY: `M128` is a `#[repr(transparent)]` wrapper over `__m128i`; every
// 16-byte bit pattern is a valid value, so it is zeroable plain-old-data.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// SAFETY: `M128` holds no pointers and no interior mutability.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
819
// Precomputed `_mm_shuffle_epi8` masks used by `spread` for byte-wide
// elements (T::LOG_BITS == 3), one table per block length, indexed by block.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
824
/// Precomputes byte-shuffle masks for `spread`.
///
/// For each `block_idx`, yields the 16-entry byte-index vector that repeats
/// every element of the selected block `repeat` times across the output.
/// `t_log_bits` is the log2 bit width of the element type (must be >= 3,
/// i.e. at least byte-sized, so elements land on byte boundaries).
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
    log_block_len: usize,
    t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
    let element_log_width = t_log_bits - 3;

    let element_width = 1 << element_log_width;

    let block_size = 1 << (log_block_len + element_log_width);
    let repeat = 1 << (4 - element_log_width - log_block_len);
    let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

    // `while` loops because `for` is not allowed in const fns.
    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        let base = block_idx * block_size;
        let mut j = 0;
        while j < 16 {
            // Output byte j picks source byte: block base + the element it
            // belongs to (each repeated `repeat` times) + byte within element.
            masks[block_idx][j] =
                (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
            j += 1;
        }
        block_idx += 1;
    }
    let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
        block_idx += 1;
    }

    m128_masks
}
860
impl UnderlierWithBitConstants for M128 {
    // Index i holds the mask selecting the even/odd blocks of width 2^i bits.
    const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_even!(u128, 0)),
        Self::from_u128(interleave_mask_even!(u128, 1)),
        Self::from_u128(interleave_mask_even!(u128, 2)),
        Self::from_u128(interleave_mask_even!(u128, 3)),
        Self::from_u128(interleave_mask_even!(u128, 4)),
        Self::from_u128(interleave_mask_even!(u128, 5)),
        Self::from_u128(interleave_mask_even!(u128, 6)),
    ];

    const INTERLEAVE_ODD_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_odd!(u128, 0)),
        Self::from_u128(interleave_mask_odd!(u128, 1)),
        Self::from_u128(interleave_mask_odd!(u128, 2)),
        Self::from_u128(interleave_mask_odd!(u128, 3)),
        Self::from_u128(interleave_mask_odd!(u128, 4)),
        Self::from_u128(interleave_mask_odd!(u128, 5)),
        Self::from_u128(interleave_mask_odd!(u128, 6)),
    ];

    // Delegates to the SIMD `interleave_bits` implementation below.
    #[inline(always)]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        unsafe {
            let (c, d) = interleave_bits(
                Into::<Self>::into(self).into(),
                Into::<Self>::into(other).into(),
                log_block_len,
            );
            (Self::from(c), Self::from(d))
        }
    }
}
894
// Conversions between raw register / integer values and the packed-field
// wrapper, routed through `M128`.
impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: __m128i) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: u128) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
    fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
        value.to_underlier().into()
    }
}
912
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
    u128: From<Scalar::Underlier>,
{
    /// Repeats `scalar` into every element slot of the packed value.
    #[inline(always)]
    fn broadcast(scalar: Scalar) -> Self {
        let tower_level = Scalar::N_BITS.ilog2() as usize;
        let mut value = u128::from(scalar.to_underlier());
        // Sub-byte scalars: double the occupied bits until a full byte is
        // filled, so the byte-granularity broadcast below covers them.
        for n in tower_level..3 {
            value |= value << (1 << n);
        }

        let value = must_cast(value);
        // Broadcast at the smallest granularity the hardware supports.
        let value = match tower_level {
            0..=3 => unsafe { _mm_broadcastb_epi8(value) },
            4 => unsafe { _mm_broadcastw_epi16(value) },
            5 => unsafe { _mm_broadcastd_epi32(value) },
            6 => unsafe { _mm_broadcastq_epi64(value) },
            7 => value,
            _ => unreachable!(),
        };

        value.into()
    }
}
938
/// Interleaves `2^log_block_len`-bit blocks of `a` and `b`.
///
/// Returns `(c, d)` where `c` carries the even-indexed blocks of both inputs
/// interleaved and `d` the odd-indexed ones (see `test_interleave_bits` for
/// the exact contract).
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        // Blocks smaller than a byte: masked delta-swap within 64-bit lanes.
        0 => unsafe {
            let mask = _mm_set1_epi8(0x55i8);
            interleave_bits_imm::<1>(a, b, mask)
        },
        1 => unsafe {
            let mask = _mm_set1_epi8(0x33i8);
            interleave_bits_imm::<2>(a, b, mask)
        },
        2 => unsafe {
            let mask = _mm_set1_epi8(0x0fi8);
            interleave_bits_imm::<4>(a, b, mask)
        },
        // Byte-sized and larger blocks: pre-shuffle even/odd blocks into the
        // two register halves, then use the native unpack instructions.
        3 => unsafe {
            let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi8(a, b);
            let b_prime = _mm_unpackhi_epi8(a, b);
            (a_prime, b_prime)
        },
        4 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi16(a, b);
            let b_prime = _mm_unpackhi_epi16(a, b);
            (a_prime, b_prime)
        },
        5 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi32(a, b);
            let b_prime = _mm_unpackhi_epi32(a, b);
            (a_prime, b_prime)
        },
        // 64-bit blocks need no pre-shuffle.
        6 => unsafe {
            let a_prime = _mm_unpacklo_epi64(a, b);
            let b_prime = _mm_unpackhi_epi64(a, b);
            (a_prime, b_prime)
        },
        _ => panic!("unsupported block length"),
    }
}
986
/// Interleaves `BLOCK_LEN`-bit blocks of `a` and `b` within 64-bit lanes
/// using the masked delta-swap trick: `t` holds exactly the bits that the
/// two inputs must exchange, and XOR-ing `t` back into both performs the
/// swap branch-free.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    unsafe {
        let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
        let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
        let b_prime = _mm_xor_si128(b, t);
        (a_prime, b_prime)
    }
}
1000
// Iteration strategies over the packed elements of `M128`, per element type.
impl_iteration!(M128,
    @strategy BitIterationStrategy, U1,
    @strategy FallbackStrategy, U2, U4,
    @strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
1006
#[cfg(test)]
mod tests {
    use binius_utils::bytes::BytesMut;
    use proptest::{arbitrary::any, proptest};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::underlier::single_element_mask_bits;

    /// Asserts that converting to `T` and back is lossless.
    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    /// Extracts block `index` of width `2^log_block_len` bits as a value in
    /// the low bits.
    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        (value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            // Even output blocks come from `a`, odd blocks from `b`.
            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }

        #[test]
        fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_lo_128b_lanes(b, height);
            for i in 0..128>>(height + 1) {
                assert_eq!(get(result, height, 2*i), get(a, height, i));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i));
            }
        }

        #[test]
        fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_hi_128b_lanes(b, height);
            let half_block_count = 128>>(height + 1);
            for i in 0..half_block_count {
                assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }

    #[test]
    fn test_serialize_and_deserialize_m128() {
        let mode = SerializationMode::Native;

        let mut rng = StdRng::from_seed([0; 32]);

        let original_value = M128::from(rng.r#gen::<u128>());

        let mut buf = BytesMut::new();
        original_value.serialize(&mut buf, mode).unwrap();

        let deserialized_value = M128::deserialize(buf.freeze(), mode).unwrap();

        assert_eq!(original_value, deserialized_value);
    }
}