use std::{
    arch::x86_64::*,
    array,
    ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
};

use binius_utils::{
    DeserializeBytes, SerializationError, SerializeBytes,
    bytes::{Buf, BufMut},
    serialization::{assert_enough_data_for, assert_enough_space_for},
};
use bytemuck::{Pod, Zeroable};
use rand::{
    Rng,
    distr::{Distribution, StandardUniform},
};
use seq_macro::seq;

use crate::{
    BinaryField,
    arch::{
        binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
        portable::{
            packed::{PackedPrimitiveType, impl_pack_scalar},
            packed_arithmetic::{
                UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
            },
        },
    },
    arithmetic_traits::Broadcast,
    underlier::{
        NumCast, SmallU, SpreadToByte, U1, U2, U4, UnderlierType, UnderlierWithBitOps,
        WithUnderlier, impl_divisible, impl_iteration, spread_fallback, unpack_hi_128b_fallback,
        unpack_lo_128b_fallback,
    },
};

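/// A 128-bit SIMD value backed by an SSE `__m128i` register, used as the underlier for
/// 128-bit packed binary field types.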
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);

impl M128 {
    #[inline(always)]
    pub const fn from_u128(val: u128) -> Self {
        let mut result = Self::ZERO;
        unsafe {
            result.0 = std::mem::transmute_copy(&val);
        }

        result
    }
}

impl From<__m128i> for M128 {
    #[inline(always)]
    fn from(value: __m128i) -> Self {
        Self(value)
    }
}

impl From<u128> for M128 {
    fn from(value: u128) -> Self {
        Self(unsafe {
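            // The aligned load below requires 16-byte alignment; the assertion checks that
            // `u128` provides it on this target.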
            assert_eq!(align_of::<u128>(), 16);
            _mm_load_si128(&raw const value as *const __m128i)
        })
    }
}

impl From<u64> for M128 {
    fn from(value: u64) -> Self {
        Self::from(value as u128)
    }
}

impl From<u32> for M128 {
    fn from(value: u32) -> Self {
        Self::from(value as u128)
    }
}

impl From<u16> for M128 {
    fn from(value: u16) -> Self {
        Self::from(value as u128)
    }
}

impl From<u8> for M128 {
    fn from(value: u8) -> Self {
        Self::from(value as u128)
    }
}

impl<const N: usize> From<SmallU<N>> for M128 {
    fn from(value: SmallU<N>) -> Self {
        Self::from(value.val() as u128)
    }
}

impl From<M128> for u128 {
    fn from(value: M128) -> Self {
        let mut result = 0u128;
        unsafe {
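            // Mirror of the conversion above: the aligned store requires `result` to be
            // 16-byte aligned, which the assertion checks.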
            assert_eq!(align_of::<u128>(), 16);
            _mm_store_si128(&raw mut result as *mut __m128i, value.0)
        };
        result
    }
}

impl From<M128> for __m128i {
    #[inline(always)]
    fn from(value: M128) -> Self {
        value.0
    }
}

impl SerializeBytes for M128 {
    fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
        assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;

        let raw_value: u128 = (*self).into();

        write_buf.put_u128_le(raw_value);
        Ok(())
    }
}

impl DeserializeBytes for M128 {
    fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
    where
        Self: Sized,
    {
        assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;

        let raw_value = read_buf.get_u128_le();

        Ok(Self::from(raw_value))
    }
}

impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);

impl<U: NumCast<u128>> NumCast<M128> for U {
    #[inline(always)]
    fn num_cast_from(val: M128) -> Self {
        Self::num_cast_from(u128::from(val))
    }
}

impl Default for M128 {
    #[inline(always)]
    fn default() -> Self {
        Self(unsafe { _mm_setzero_si128() })
    }
}

impl BitAnd for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitand(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_and_si128(self.0, rhs.0) })
    }
}

impl BitAndAssign for M128 {
    #[inline(always)]
    fn bitand_assign(&mut self, rhs: Self) {
        *self = *self & rhs
    }
}

impl BitOr for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_or_si128(self.0, rhs.0) })
    }
}

impl BitOrAssign for M128 {
    #[inline(always)]
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs
    }
}

impl BitXor for M128 {
    type Output = Self;

    #[inline(always)]
    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
    }
}

impl BitXorAssign for M128 {
    #[inline(always)]
    fn bitxor_assign(&mut self, rhs: Self) {
        *self = *self ^ rhs;
    }
}

impl Not for M128 {
    type Output = Self;

    fn not(self) -> Self::Output {
        const ONES: __m128i = m128_from_u128!(u128::MAX);

        self ^ Self(ONES)
    }
}

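/// `const`-compatible maximum of two `i32` values, used by `bitshift_128b!` below.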
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    if left > right { left } else { right }
}

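/// Emulates a full 128-bit shift out of 64-bit lane shifts: bits that cross the 64-bit lane
/// boundary are recovered from a byte-shifted copy (`carry`) and OR-ed back in. The `seq!`
/// expansions compare the shift amount against every possible constant, which lets the
/// compiler collapse the branches when the shift is known at compile time.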
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            let carry = $byte_shift($val, 8);
            seq!(N in 64..128 {
                if $shift == N {
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            seq!(N in 0..64 {
                if $shift == N {
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            return Default::default()
        }
    };
}

pub(crate) use bitshift_128b;

impl Shr<usize> for M128 {
    type Output = Self;

    #[inline(always)]
    fn shr(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
    }
}

impl Shl<usize> for M128 {
    type Output = Self;

    #[inline(always)]
    fn shl(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
    }
}

impl PartialEq for M128 {
    fn eq(&self, other: &Self) -> bool {
        unsafe {
            let neq = _mm_xor_si128(self.0, other.0);
            _mm_test_all_zeros(neq, neq) == 1
        }
    }
}

impl Eq for M128 {}

impl PartialOrd for M128 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for M128 {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        u128::from(*self).cmp(&u128::from(*other))
    }
}

impl Distribution<M128> for StandardUniform {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
        // Sample the bits as a `u128`, since `StandardUniform` has no implementation for the
        // raw `__m128i` type.
        M128::from(rng.random::<u128>())
    }
}

impl std::fmt::Display for M128 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data: u128 = (*self).into();
        write!(f, "{data:02X?}")
    }
}

impl std::fmt::Debug for M128 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "M128({self})")
    }
}

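/// 16-byte-aligned wrapper that lets `m128_from_u128!` perform an aligned pointer read in
/// `const` contexts.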
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

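// Builds an `__m128i` constant from a `u128` expression; usable in `const` initializers such
// as `M128::ZERO`/`ONE`/`ONES` below.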
macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}

pub(super) use m128_from_u128;

impl UnderlierType for M128 {
    const LOG_BITS: usize = 7;
}

impl UnderlierWithBitOps for M128 {
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    const ONE: Self = { Self(m128_from_u128!(1)) };
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

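    // Broadcasts a single bit to every bit of the register: 0 becomes all zeros, 1 becomes all
    // ones (`wrapping_neg` maps 1u8 to 0xFF).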
    #[inline(always)]
    fn fill_with_bit(val: u8) -> Self {
        assert!(val == 0 || val == 1);
        Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
    }

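    // Builds the vector element-wise from `f`; sub-byte element types are first packed into
    // bytes by `make_func_to_i8`.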
    #[inline(always)]
    fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                let mut f = make_func_to_i8::<T, Self>(f);

                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            8 => {
                let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            16 => {
                let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
                unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
            }
            32 => {
                let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
                unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
            }
            64 => {
                let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
                unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
            }
            128 => Self::from(f(0)),
            _ => panic!("unsupported bit count"),
        }
    }

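    // Reads the `i`-th `T`-wide element; sub-byte elements are extracted from their containing
    // byte by shifting, wider elements by viewing the register as an array of that width.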
    #[inline(always)]
    unsafe fn get_subvalue<T>(&self, i: usize) -> T
    where
        T: UnderlierType + NumCast<Self>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                let elements_in_8 = 8 / T::BITS;
                let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
                    *arr.get_unchecked(i / elements_in_8)
                });

                let shift = (i % elements_in_8) * T::BITS;
                value_u8 >>= shift;

                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            8 => {
                let value_u8 =
                    as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            16 => {
                let value_u16 =
                    as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u16)))
            }
            32 => {
                let value_u32 =
                    as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u32)))
            }
            64 => {
                let value_u64 =
                    as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u64)))
            }
            128 => T::from_underlier(T::num_cast_from(*self)),
            _ => panic!("unsupported bit count"),
        }
    }

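    // Writes the `i`-th `T`-wide element; for sub-byte elements the containing byte is masked
    // out and the new value OR-ed in at the right offset.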
    #[inline(always)]
    unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
    where
        T: UnderlierWithBitOps,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                let elements_in_8 = 8 / T::BITS;
                let mask = (1u8 << T::BITS) - 1;
                let shift = (i % elements_in_8) * T::BITS;
                let val = u8::num_cast_from(Self::from(val)) << shift;
                let mask = mask << shift;

                as_array_mut::<_, u8, 16>(self, |array| unsafe {
                    let element = array.get_unchecked_mut(i / elements_in_8);
                    *element &= !mask;
                    *element |= val;
                });
            }
            8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
            }),
            16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
            }),
            32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
            }),
            64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
            }),
            128 => {
                *self = Self::from(val);
            }
            _ => panic!("unsupported bit count"),
        }
    }

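    // Takes the block of `2^log_block_len` elements of type `T` at `block_idx` and spreads it
    // across the whole register, repeating each element an equal number of times. Small element
    // types are rebuilt scalar-by-scalar; byte-wide elements use the precomputed shuffle masks
    // defined below.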
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps + NumCast<Self>,
        Self: From<T>,
    {
        match T::LOG_BITS {
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => unsafe {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                },
                2 => unsafe {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                },
                3 => unsafe {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                },
                4 => unsafe {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                },
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            1 => match log_block_len {
                0 => unsafe {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                },
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    unsafe { _mm_set1_epi8(value as i8).into() }
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            3 => match log_block_len {
                0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
                1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
                2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
                3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
                4 => self,
                _ => panic!("unsupported block length"),
            },
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    unsafe { _mm_set1_epi16(value as i16).into() }
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            5 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                },
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            6 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                },
                1 => self,
                _ => panic!("unsupported block length"),
            },
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }

    #[inline]
    fn shl_128b_lanes(self, shift: usize) -> Self {
        self << shift
    }

    #[inline]
    fn shr_128b_lanes(self, shift: usize) -> Self {
        self >> shift
    }

    #[inline]
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

    #[inline]
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }
}

unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}

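// Precomputed `_mm_shuffle_epi8` masks used by `spread` for byte-wide elements
// (`T::LOG_BITS == 3`), one table per supported `log_block_len`.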
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);

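/// Computes, for every block index, the byte-shuffle mask that repeats each element of the
/// selected block across the 16 bytes of the register. `t_log_bits` is the log2 bit width of
/// the element type and must be at least 3 (byte-sized or wider).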
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
    log_block_len: usize,
    t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
    let element_log_width = t_log_bits - 3;

    let element_width = 1 << element_log_width;

    let block_size = 1 << (log_block_len + element_log_width);
    let repeat = 1 << (4 - element_log_width - log_block_len);
    let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        let base = block_idx * block_size;
        let mut j = 0;
        while j < 16 {
            masks[block_idx][j] =
                (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
            j += 1;
        }
        block_idx += 1;
    }
    let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
        block_idx += 1;
    }

    m128_masks
}

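// Masks selecting the even- and odd-indexed 2^n-bit blocks of a 128-bit value, indexed by n,
// as required by the `UnderlierWithBitConstants` trait.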
impl UnderlierWithBitConstants for M128 {
    const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_even!(u128, 0)),
        Self::from_u128(interleave_mask_even!(u128, 1)),
        Self::from_u128(interleave_mask_even!(u128, 2)),
        Self::from_u128(interleave_mask_even!(u128, 3)),
        Self::from_u128(interleave_mask_even!(u128, 4)),
        Self::from_u128(interleave_mask_even!(u128, 5)),
        Self::from_u128(interleave_mask_even!(u128, 6)),
    ];

    const INTERLEAVE_ODD_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_odd!(u128, 0)),
        Self::from_u128(interleave_mask_odd!(u128, 1)),
        Self::from_u128(interleave_mask_odd!(u128, 2)),
        Self::from_u128(interleave_mask_odd!(u128, 3)),
        Self::from_u128(interleave_mask_odd!(u128, 4)),
        Self::from_u128(interleave_mask_odd!(u128, 5)),
        Self::from_u128(interleave_mask_odd!(u128, 6)),
    ];

    #[inline(always)]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        unsafe {
            let (c, d) = interleave_bits(
                Into::<Self>::into(self).into(),
                Into::<Self>::into(other).into(),
                log_block_len,
            );
            (Self::from(c), Self::from(d))
        }
    }
}

impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: __m128i) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: u128) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
    fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
        value.to_underlier().into()
    }
}

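// Broadcast fills the register with copies of the scalar: fields narrower than a byte are
// first replicated up to 8 bits by repeated doubling, then `_mm_set1_*` repeats the resulting
// 8/16/32/64-bit value across all lanes.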
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
    u128: From<Scalar::Underlier>,
{
    #[inline(always)]
    fn broadcast(scalar: Scalar) -> Self {
        let tower_level = Scalar::N_BITS.ilog2() as usize;
        match tower_level {
            0..=3 => {
                let mut value = u128::from(scalar.to_underlier()) as u8;
                for n in tower_level..3 {
                    value |= value << (1 << n);
                }

                unsafe { _mm_set1_epi8(value as i8) }.into()
            }
            4 => {
                let value = u128::from(scalar.to_underlier()) as u16;
                unsafe { _mm_set1_epi16(value as i16) }.into()
            }
            5 => {
                let value = u128::from(scalar.to_underlier()) as u32;
                unsafe { _mm_set1_epi32(value as i32) }.into()
            }
            6 => {
                let value = u128::from(scalar.to_underlier()) as u64;
                unsafe { _mm_set1_epi64x(value as i64) }.into()
            }
            7 => {
                let value = u128::from(scalar.to_underlier());
                value.into()
            }
            _ => {
                unreachable!("invalid tower level")
            }
        }
    }
}

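// Interleaves `a` and `b` at the given block size: blocks of 1, 2 or 4 bits use a masked
// delta-swap, while byte-and-wider blocks shuffle the inputs so that the SSE unpack
// instructions produce the two interleaved halves.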
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        0 => unsafe {
            let mask = _mm_set1_epi8(0x55i8);
            interleave_bits_imm::<1>(a, b, mask)
        },
        1 => unsafe {
            let mask = _mm_set1_epi8(0x33i8);
            interleave_bits_imm::<2>(a, b, mask)
        },
        2 => unsafe {
            let mask = _mm_set1_epi8(0x0fi8);
            interleave_bits_imm::<4>(a, b, mask)
        },
        3 => unsafe {
            let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi8(a, b);
            let b_prime = _mm_unpackhi_epi8(a, b);
            (a_prime, b_prime)
        },
        4 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi16(a, b);
            let b_prime = _mm_unpackhi_epi16(a, b);
            (a_prime, b_prime)
        },
        5 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi32(a, b);
            let b_prime = _mm_unpackhi_epi32(a, b);
            (a_prime, b_prime)
        },
        6 => unsafe {
            let a_prime = _mm_unpacklo_epi64(a, b);
            let b_prime = _mm_unpackhi_epi64(a, b);
            (a_prime, b_prime)
        },
        _ => panic!("unsupported block length"),
    }
}

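// Delta-swap interleave for sub-byte blocks: `t` marks the positions where the shifted `a`
// differs from `b` (restricted to `mask`), and XOR-ing `t` back into both operands swaps
// those blocks between them.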
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    unsafe {
        let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
        let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
        let b_prime = _mm_xor_si128(b, t);
        (a_prime, b_prime)
    }
}

impl_iteration!(M128,
    @strategy BitIterationStrategy, U1,
    @strategy FallbackStrategy, U2, U4,
    @strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);

#[cfg(test)]
mod tests {
    use binius_utils::bytes::BytesMut;
    use proptest::{arbitrary::any, proptest};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::underlier::single_element_mask_bits;

    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        (value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe { interleave_bits(a.0, b.0, height) };
            let (c, d) = (M128::from(c), M128::from(d));

            for i in (0..128 >> height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i + 1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i + 1));
                assert_eq!(get(d, height, i + 1), get(b, height, i + 1));
            }
        }

        #[test]
        fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_lo_128b_lanes(b, height);
            for i in 0..128 >> (height + 1) {
                assert_eq!(get(result, height, 2 * i), get(a, height, i));
                assert_eq!(get(result, height, 2 * i + 1), get(b, height, i));
            }
        }

        #[test]
        fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_hi_128b_lanes(b, height);
            let half_block_count = 128 >> (height + 1);
            for i in 0..half_block_count {
                assert_eq!(get(result, height, 2 * i), get(a, height, i + half_block_count));
                assert_eq!(get(result, height, 2 * i + 1), get(b, height, i + half_block_count));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }

    #[test]
    fn test_serialize_and_deserialize_m128() {
        let mut rng = StdRng::from_seed([0; 32]);

        let original_value = M128::from(rng.random::<u128>());

        let mut buf = BytesMut::new();
        original_value.serialize(&mut buf).unwrap();

        let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

        assert_eq!(original_value, deserialized_value);
    }
}