1use std::{
4 arch::x86_64::*,
5 array,
6 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use bytemuck::{must_cast, Pod, Zeroable};
10use rand::{Rng, RngCore};
11use seq_macro::seq;
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
13
14use crate::{
15 arch::{
16 binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
17 portable::{
18 packed::{impl_pack_scalar, PackedPrimitiveType},
19 packed_arithmetic::{
20 interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants,
21 },
22 },
23 },
24 arithmetic_traits::Broadcast,
25 underlier::{
26 impl_divisible, impl_iteration, spread_fallback, unpack_hi_128b_fallback,
27 unpack_lo_128b_fallback, NumCast, Random, SmallU, SpreadToByte, UnderlierType,
28 UnderlierWithBitOps, WithUnderlier, U1, U2, U4,
29 },
30 BinaryField,
31};
32
/// A 128-bit value stored in the SSE `__m128i` register type.
///
/// `#[repr(transparent)]` gives `M128` exactly the layout of the wrapped `__m128i`.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
37
38impl M128 {
39 #[inline(always)]
40 pub const fn from_u128(val: u128) -> Self {
41 let mut result = Self::ZERO;
42 unsafe {
43 result.0 = std::mem::transmute_copy(&val);
44 }
45
46 result
47 }
48}
49
50impl From<__m128i> for M128 {
51 #[inline(always)]
52 fn from(value: __m128i) -> Self {
53 Self(value)
54 }
55}
56
57impl From<u128> for M128 {
58 fn from(value: u128) -> Self {
59 Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
60 }
61}
62
63impl From<u64> for M128 {
64 fn from(value: u64) -> Self {
65 Self::from(value as u128)
66 }
67}
68
69impl From<u32> for M128 {
70 fn from(value: u32) -> Self {
71 Self::from(value as u128)
72 }
73}
74
75impl From<u16> for M128 {
76 fn from(value: u16) -> Self {
77 Self::from(value as u128)
78 }
79}
80
81impl From<u8> for M128 {
82 fn from(value: u8) -> Self {
83 Self::from(value as u128)
84 }
85}
86
87impl<const N: usize> From<SmallU<N>> for M128 {
88 fn from(value: SmallU<N>) -> Self {
89 Self::from(value.val() as u128)
90 }
91}
92
93impl From<M128> for u128 {
94 fn from(value: M128) -> Self {
95 let mut result = 0u128;
96 unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };
97
98 result
99 }
100}
101
102impl From<M128> for __m128i {
103 #[inline(always)]
104 fn from(value: M128) -> Self {
105 value.0
106 }
107}
108
// NOTE(review): presumably lets `M128` be viewed as arrays of each listed integer type
// (macro defined in `crate::underlier`) — confirm there.
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
// Scalar packing support for `M128` (macro defined in `portable::packed`).
impl_pack_scalar!(M128);
111
112impl<U: NumCast<u128>> NumCast<M128> for U {
113 #[inline(always)]
114 fn num_cast_from(val: M128) -> Self {
115 Self::num_cast_from(u128::from(val))
116 }
117}
118
119impl Default for M128 {
120 #[inline(always)]
121 fn default() -> Self {
122 Self(unsafe { _mm_setzero_si128() })
123 }
124}
125
126impl BitAnd for M128 {
127 type Output = Self;
128
129 #[inline(always)]
130 fn bitand(self, rhs: Self) -> Self::Output {
131 Self(unsafe { _mm_and_si128(self.0, rhs.0) })
132 }
133}
134
135impl BitAndAssign for M128 {
136 #[inline(always)]
137 fn bitand_assign(&mut self, rhs: Self) {
138 *self = *self & rhs
139 }
140}
141
142impl BitOr for M128 {
143 type Output = Self;
144
145 #[inline(always)]
146 fn bitor(self, rhs: Self) -> Self::Output {
147 Self(unsafe { _mm_or_si128(self.0, rhs.0) })
148 }
149}
150
151impl BitOrAssign for M128 {
152 #[inline(always)]
153 fn bitor_assign(&mut self, rhs: Self) {
154 *self = *self | rhs
155 }
156}
157
158impl BitXor for M128 {
159 type Output = Self;
160
161 #[inline(always)]
162 fn bitxor(self, rhs: Self) -> Self::Output {
163 Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
164 }
165}
166
167impl BitXorAssign for M128 {
168 #[inline(always)]
169 fn bitxor_assign(&mut self, rhs: Self) {
170 *self = *self ^ rhs;
171 }
172}
173
174impl Not for M128 {
175 type Output = Self;
176
177 fn not(self) -> Self::Output {
178 const ONES: __m128i = m128_from_u128!(u128::MAX);
179
180 self ^ Self(ONES)
181 }
182}
183
/// Returns the larger of `left` and `right`.
///
/// This is a `const fn` (unlike `Ord::max`), so it can appear in the constant shift-count
/// expressions generated by `bitshift_128b!`.
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    if right >= left {
        right
    } else {
        left
    }
}
192
/// Shifts a 128-bit `__m128i` by a runtime amount using 64-bit-lane shift intrinsics.
///
/// Arguments:
/// - `$val`: the `__m128i` to shift.
/// - `$shift`: the shift amount, compared against integer literals.
/// - `$byte_shift`: a whole-register byte shift.
/// - `$bit_shift_64` / `$bit_shift_64_opposite`: per-64-bit-lane bit shifts, in the shift
///   direction and in the opposite direction.
/// - `$or`: a bitwise OR.
///
/// The 64-bit shift intrinsics take their count as a compile-time constant, so `seq!` unrolls
/// one branch per shift amount:
/// - `64..128`: `$byte_shift($val, 8)` moves one 64-bit half across (the "carry"), and that
///   half is shifted by the remaining `N - 64` bits.
/// - `0..64`: each lane is shifted by `N`. The bits crossing the 64-bit boundary (the carry
///   shifted the opposite way by `64 - N`) are OR-ed back in.
///
/// Any other amount (128 or more) falls through to `Default::default()`.
///
/// The expansion contains `return` statements, so it must form the tail of a function body.
/// That function's return type must implement `Default` and be reachable from `__m128i`
/// via `.into()`.
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            let carry = $byte_shift($val, 8);
            seq!(N in 64..128 {
                if $shift == N {
                    // Only the carried half survives. `max_i32(.., 0)` clamps the
                    // constant count at zero (it is never negative in this range).
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            seq!(N in 0..64 {
                if $shift == N {
                    // Bits that move across the 64-bit lane boundary.
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            return Default::default()
        }
    };
}

pub(crate) use bitshift_128b;
227
228impl Shr<usize> for M128 {
229 type Output = Self;
230
231 #[inline(always)]
232 fn shr(self, rhs: usize) -> Self::Output {
233 bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
236 }
237}
238
239impl Shl<usize> for M128 {
240 type Output = Self;
241
242 #[inline(always)]
243 fn shl(self, rhs: usize) -> Self::Output {
244 bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
247 }
248}
249
250impl PartialEq for M128 {
251 fn eq(&self, other: &Self) -> bool {
252 unsafe {
253 let neq = _mm_xor_si128(self.0, other.0);
254 _mm_test_all_zeros(neq, neq) == 1
255 }
256 }
257}
258
259impl Eq for M128 {}
260
261impl PartialOrd for M128 {
262 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
263 Some(self.cmp(other))
264 }
265}
266
267impl Ord for M128 {
268 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
269 u128::from(*self).cmp(&u128::from(*other))
270 }
271}
272
273impl ConstantTimeEq for M128 {
274 fn ct_eq(&self, other: &Self) -> Choice {
275 unsafe {
276 let neq = _mm_xor_si128(self.0, other.0);
277 Choice::from(_mm_test_all_zeros(neq, neq) as u8)
278 }
279 }
280}
281
282impl ConditionallySelectable for M128 {
283 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
284 ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
285 }
286}
287
288impl Random for M128 {
289 fn random(mut rng: impl RngCore) -> Self {
290 let val: u128 = rng.gen();
291 val.into()
292 }
293}
294
295impl std::fmt::Display for M128 {
296 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297 let data: u128 = (*self).into();
298 write!(f, "{data:02X?}")
299 }
300}
301
302impl std::fmt::Debug for M128 {
303 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304 write!(f, "M128({})", self)
305 }
306}
307
/// A single `u128` stored with 16-byte alignment.
///
/// `m128_from_u128!` uses it so that reading the value back through an `__m128i` pointer is
/// an aligned read.
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

/// Builds an `__m128i` from a `u128` expression; usable in `const` items.
///
/// The value is placed in an `AlignedData` buffer and read back as an `__m128i`.
macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}

pub(super) use m128_from_u128;
319
320impl UnderlierType for M128 {
321 const LOG_BITS: usize = 7;
322}
323
// Bit-manipulation primitives for `M128`. Most methods dispatch on the bit width of the
// element type `T` and panic for widths they do not support.
impl UnderlierWithBitOps for M128 {
    // Compile-time constants: all zeros, only the least-significant bit set, all ones.
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    const ONE: Self = { Self(m128_from_u128!(1)) };
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

    /// Returns a register in which every bit equals `val`.
    ///
    /// # Panics
    /// Panics if `val` is neither 0 nor 1.
    #[inline(always)]
    fn fill_with_bit(val: u8) -> Self {
        assert!(val == 0 || val == 1);
        // `wrapping_neg` maps 0 -> 0x00 and 1 -> 0xFF. That byte is broadcast to all 16 lanes.
        Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
    }

    /// Builds a register from `128 / T::BITS` elements, where element `i` is `f(i)`.
    ///
    /// For 8-bit and wider elements, element 0 occupies the least-significant bits: the
    /// `_mm_set_*` intrinsics list lanes from the highest down to lane 0.
    ///
    /// # Panics
    /// Panics if `T::BITS` is not 1, 2, 4, 8, 16, 32, 64 or 128.
    #[inline(always)]
    fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                // NOTE(review): `make_func_to_i8` presumably returns, for byte index `i`, the
                // byte packing every sub-byte element that falls in that byte — confirm.
                let mut f = make_func_to_i8::<T, Self>(f);

                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            8 => {
                // Narrow each element to a byte via an `M128` round-trip.
                let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            16 => {
                let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
                unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
            }
            32 => {
                let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
                unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
            }
            64 => {
                let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
                unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
            }
            // A single element covers the whole register.
            128 => Self::from(f(0)),
            _ => panic!("unsupported bit count"),
        }
    }

    /// Returns element `i` of the register viewed (through `as_array_ref`) as an array of
    /// `T::BITS`-wide values.
    ///
    /// # Safety
    /// The index is not bounds-checked (`get_unchecked`). Presumably the caller must
    /// guarantee `i < 128 / T::BITS` — confirm against the trait contract.
    ///
    /// # Panics
    /// Panics if `T::BITS` is not 1, 2, 4, 8, 16, 32, 64 or 128.
    #[inline(always)]
    unsafe fn get_subvalue<T>(&self, i: usize) -> T
    where
        T: UnderlierType + NumCast<Self>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                // Load the byte holding element `i`, then shift that element down to bit 0.
                // NOTE(review): higher bits of the byte remain. `T::num_cast_from` is
                // presumably responsible for truncating to `T::BITS` — confirm.
                let elements_in_8 = 8 / T::BITS;
                let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
                    *arr.get_unchecked(i / elements_in_8)
                });

                let shift = (i % elements_in_8) * T::BITS;
                value_u8 >>= shift;

                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            8 => {
                let value_u8 =
                    as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            16 => {
                let value_u16 =
                    as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u16)))
            }
            32 => {
                let value_u32 =
                    as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u32)))
            }
            64 => {
                let value_u64 =
                    as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u64)))
            }
            128 => T::from_underlier(T::num_cast_from(*self)),
            _ => panic!("unsupported bit count"),
        }
    }

    /// Overwrites element `i` (of width `T::BITS`, addressed through `as_array_mut`) with `val`.
    ///
    /// # Safety
    /// The index is not bounds-checked (`get_unchecked_mut`). Presumably the caller must
    /// guarantee `i < 128 / T::BITS` — confirm against the trait contract.
    ///
    /// # Panics
    /// Panics if `T::BITS` is not 1, 2, 4, 8, 16, 32, 64 or 128.
    #[inline(always)]
    unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
    where
        T: UnderlierWithBitOps,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                // Read-modify-write of the containing byte: clear the element's bits, then OR
                // in the new value. NOTE(review): assumes `val` has no bits above `T::BITS`
                // set; otherwise neighbouring elements would be overwritten.
                let elements_in_8 = 8 / T::BITS;
                let mask = (1u8 << T::BITS) - 1;
                let shift = (i % elements_in_8) * T::BITS;
                let val = u8::num_cast_from(Self::from(val)) << shift;
                let mask = mask << shift;

                as_array_mut::<_, u8, 16>(self, |array| unsafe {
                    let element = array.get_unchecked_mut(i / elements_in_8);
                    *element &= !mask;
                    *element |= val;
                });
            }
            8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
            }),
            16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
            }),
            32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
            }),
            64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
            }),
            128 => {
                *self = Self::from(val);
            }
            _ => panic!("unsupported bit count"),
        }
    }

    /// Broadcasts one block of elements across the whole register.
    ///
    /// The register is viewed as elements of type `T`. Block `block_idx` is the run of
    /// `2^log_block_len` consecutive elements starting at element
    /// `block_idx << log_block_len`. Each element of that block is repeated so that together
    /// they fill all 128 bits, with the block's first element ending up in the
    /// least-significant position. Sub-byte combinations not special-cased below use
    /// `spread_fallback`.
    ///
    /// # Safety
    /// NOTE(review): `block_idx` is used to shift the `u128` value and to index the
    /// `LOG_B8_*` tables without a range check in this function. Presumably the caller keeps
    /// it below `128 / (T::BITS << log_block_len)` — confirm.
    ///
    /// # Panics
    /// Panics for element widths or block lengths that are not handled below.
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps + NumCast<Self>,
        Self: From<T>,
    {
        match T::LOG_BITS {
            // 1-bit elements: each selected bit becomes a lane of all ones or all zeros.
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                }
                2 => {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                }
                3 => {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                }
                4 => {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                }
                _ => spread_fallback(self, log_block_len, block_idx),
            },
            // 2-bit elements. NOTE(review): `spread_to_byte` presumably replicates a `U2`
            // four times within a byte — confirm. Each resulting byte is then repeated over
            // `16 >> log_block_len` byte lanes.
            1 => match log_block_len {
                0 => {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                }
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => spread_fallback(self, log_block_len, block_idx),
            },
            // 4-bit elements: same scheme as the 2-bit case, using `U4`.
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => spread_fallback(self, log_block_len, block_idx),
            },
            // 8-bit elements: one byte shuffle with a precomputed `LOG_B8_*` control mask.
            3 => match log_block_len {
                0 => _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into(),
                1 => _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into(),
                2 => _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into(),
                3 => _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into(),
                4 => self,
                _ => panic!("unsupported block length"),
            },
            // 16-bit elements: broadcast one, or repeat each of 2 / 4 across equal slots.
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    _mm_set1_epi16(value as i16).into()
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            // 32-bit elements.
            5 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                }
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            // 64-bit elements.
            6 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                }
                1 => self,
                _ => panic!("unsupported block length"),
            },
            // A single 128-bit element is already spread over the register.
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }

    /// The register is a single 128-bit lane, so a lane-wise shift is a full-width shift.
    #[inline]
    fn shl_128b_lanes(self, shift: usize) -> Self {
        self << shift
    }

    /// The register is a single 128-bit lane, so a lane-wise shift is a full-width shift.
    #[inline]
    fn shr_128b_lanes(self, shift: usize) -> Self {
        self >> shift
    }

    /// Interleaves the low halves of `self` and `other` in blocks of `2^log_block_len` bits.
    /// Sub-byte blocks use the portable fallback; byte and wider blocks use SSE2 unpacks.
    ///
    /// # Panics
    /// Panics if `log_block_len > 6`.
    #[inline]
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

    /// Interleaves the high halves of `self` and `other` in blocks of `2^log_block_len` bits.
    /// Sub-byte blocks use the portable fallback; byte and wider blocks use SSE2 unpacks.
    ///
    /// # Panics
    /// Panics if `log_block_len > 6`.
    #[inline]
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }
}
733
734unsafe impl Zeroable for M128 {}
735
736unsafe impl Pod for M128 {}
737
738unsafe impl Send for M128 {}
739
740unsafe impl Sync for M128 {}
741
742static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
743static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
744static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
745static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
746
747const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
748 log_block_len: usize,
749 t_log_bits: usize,
750) -> [M128; BLOCK_IDX_AMOUNT] {
751 let element_log_width = t_log_bits - 3;
752
753 let element_width = 1 << element_log_width;
754
755 let block_size = 1 << (log_block_len + element_log_width);
756 let repeat = 1 << (4 - element_log_width - log_block_len);
757 let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];
758
759 let mut block_idx = 0;
760
761 while block_idx < BLOCK_IDX_AMOUNT {
762 let base = block_idx * block_size;
763 let mut j = 0;
764 while j < 16 {
765 masks[block_idx][j] =
766 (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
767 j += 1;
768 }
769 block_idx += 1;
770 }
771 let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];
772
773 let mut block_idx = 0;
774
775 while block_idx < BLOCK_IDX_AMOUNT {
776 m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
777 block_idx += 1;
778 }
779
780 m128_masks
781}
782
783impl UnderlierWithBitConstants for M128 {
784 const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
785 Self::from_u128(interleave_mask_even!(u128, 0)),
786 Self::from_u128(interleave_mask_even!(u128, 1)),
787 Self::from_u128(interleave_mask_even!(u128, 2)),
788 Self::from_u128(interleave_mask_even!(u128, 3)),
789 Self::from_u128(interleave_mask_even!(u128, 4)),
790 Self::from_u128(interleave_mask_even!(u128, 5)),
791 Self::from_u128(interleave_mask_even!(u128, 6)),
792 ];
793
794 const INTERLEAVE_ODD_MASK: &'static [Self] = &[
795 Self::from_u128(interleave_mask_odd!(u128, 0)),
796 Self::from_u128(interleave_mask_odd!(u128, 1)),
797 Self::from_u128(interleave_mask_odd!(u128, 2)),
798 Self::from_u128(interleave_mask_odd!(u128, 3)),
799 Self::from_u128(interleave_mask_odd!(u128, 4)),
800 Self::from_u128(interleave_mask_odd!(u128, 5)),
801 Self::from_u128(interleave_mask_odd!(u128, 6)),
802 ];
803
804 #[inline(always)]
805 fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
806 unsafe {
807 let (c, d) = interleave_bits(
808 Into::<Self>::into(self).into(),
809 Into::<Self>::into(other).into(),
810 log_block_len,
811 );
812 (Self::from(c), Self::from(d))
813 }
814 }
815}
816
817impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
818 fn from(value: __m128i) -> Self {
819 M128::from(value).into()
820 }
821}
822
823impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
824 fn from(value: u128) -> Self {
825 M128::from(value).into()
826 }
827}
828
829impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
830 fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
831 value.to_underlier().into()
832 }
833}
834
835impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
836where
837 u128: From<Scalar::Underlier>,
838{
839 #[inline(always)]
840 fn broadcast(scalar: Scalar) -> Self {
841 let tower_level = Scalar::N_BITS.ilog2() as usize;
842 let mut value = u128::from(scalar.to_underlier());
843 for n in tower_level..3 {
844 value |= value << (1 << n);
845 }
846
847 let value = must_cast(value);
848 let value = match tower_level {
849 0..=3 => unsafe { _mm_broadcastb_epi8(value) },
850 4 => unsafe { _mm_broadcastw_epi16(value) },
851 5 => unsafe { _mm_broadcastd_epi32(value) },
852 6 => unsafe { _mm_broadcastq_epi64(value) },
853 7 => value,
854 _ => unreachable!(),
855 };
856
857 value.into()
858 }
859}
860
/// Interleaves the blocks of `a` and `b`, where a block is `2^log_block_len` bits wide.
///
/// Returns `(c, d)`: `c` holds the even-indexed blocks (`a[0], b[0], a[2], b[2], ...`) and `d`
/// the odd-indexed ones (`a[1], b[1], a[3], b[3], ...`), block 0 being the least significant.
///
/// # Safety
/// Uses SSE2 and SSSE3 (`_mm_shuffle_epi8`) intrinsics; the CPU must support them.
///
/// # Panics
/// Panics if `log_block_len > 6`.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        // Sub-byte blocks: a masked delta-swap inside each 64-bit lane.
        0 => interleave_bits_imm::<1>(a, b, _mm_set1_epi8(0x55i8)),
        1 => interleave_bits_imm::<2>(a, b, _mm_set1_epi8(0x33i8)),
        2 => interleave_bits_imm::<4>(a, b, _mm_set1_epi8(0x0fi8)),
        // Byte-or-wider blocks: gather even blocks into the low half and odd blocks into the
        // high half of each input, then unpack the halves pairwise.
        3 => {
            let gather = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let (x, y) = (_mm_shuffle_epi8(a, gather), _mm_shuffle_epi8(b, gather));
            (_mm_unpacklo_epi8(x, y), _mm_unpackhi_epi8(x, y))
        }
        4 => {
            let gather = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let (x, y) = (_mm_shuffle_epi8(a, gather), _mm_shuffle_epi8(b, gather));
            (_mm_unpacklo_epi16(x, y), _mm_unpackhi_epi16(x, y))
        }
        5 => {
            let gather = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let (x, y) = (_mm_shuffle_epi8(a, gather), _mm_shuffle_epi8(b, gather));
            (_mm_unpacklo_epi32(x, y), _mm_unpackhi_epi32(x, y))
        }
        // 64-bit blocks are already in place: unpack directly.
        6 => (_mm_unpacklo_epi64(a, b), _mm_unpackhi_epi64(a, b)),
        _ => panic!("unsupported block length"),
    }
}

/// Delta-swap kernel for sub-byte interleaving.
///
/// Moves the even `BLOCK_LEN`-bit blocks of `b` into the odd slots of `a`, and the odd blocks
/// of `a` into the even slots of `b`. `mask` must select the even blocks.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    let a_odd_down = _mm_srli_epi64::<BLOCK_LEN>(a);
    let swap = _mm_and_si128(_mm_xor_si128(a_odd_down, b), mask);
    (_mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(swap)), _mm_xor_si128(b, swap))
}
920
// Iteration support for `M128`, selecting one iteration strategy per element type. The
// strategies are defined in `crate::underlier`; only the mapping is visible here.
impl_iteration!(M128,
    @strategy BitIterationStrategy, U1,
    @strategy FallbackStrategy, U2, U4,
    @strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
926
#[cfg(test)]
mod tests {
    use proptest::{arbitrary::any, proptest};

    use super::*;
    use crate::underlier::single_element_mask_bits;

    /// Asserts that converting `val` to `T` and back to `M128` is lossless.
    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    /// Returns block `index` (of width `2^log_block_len` bits) of `value`, shifted down to
    /// bit 0 and masked with `single_element_mask_bits`. NOTE(review): that helper presumably
    /// yields a mask of the low `1 << log_block_len` bits — confirm.
    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        (value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
    }

    proptest! {
        // Conversions to `u128` and `__m128i` round-trip.
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        // SIMD bitwise operators agree with the scalar `u128` ones.
        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        // Every in-range shift amount matches the scalar `u128` shift.
        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        // `c` must hold the even-indexed blocks of `a` and `b` interleaved; `d` the odd ones.
        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }

        // The low halves of `a` and `b` are interleaved block by block.
        #[test]
        fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_lo_128b_lanes(b, height);
            for i in 0..128>>(height + 1) {
                assert_eq!(get(result, height, 2*i), get(a, height, i));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i));
            }
        }

        // The high halves of `a` and `b` are interleaved block by block.
        #[test]
        fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_hi_128b_lanes(b, height);
            let half_block_count = 128>>(height + 1);
            for i in 0..half_block_count {
                assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }
}
1040}