1use std::{
4 arch::x86_64::*,
5 array,
6 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use bytemuck::{must_cast, Pod, Zeroable};
10use rand::{Rng, RngCore};
11use seq_macro::seq;
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
13
14use crate::{
15 arch::{
16 binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
17 portable::{
18 packed::{impl_pack_scalar, PackedPrimitiveType},
19 packed_arithmetic::{
20 interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants,
21 },
22 },
23 },
24 arithmetic_traits::Broadcast,
25 tower_levels::TowerLevel,
26 underlier::{
27 impl_divisible, impl_iteration, spread_fallback, transpose_128b_values,
28 unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, SpreadToByte,
29 UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4,
30 },
31 BinaryField,
32};
33
/// 128-bit SIMD value backed by the SSE register type [`__m128i`].
///
/// `#[repr(transparent)]` guarantees the exact layout of `__m128i`, which the
/// `bytemuck` impls and raw-pointer casts below rely on.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
38
impl M128 {
    /// Constructs an `M128` from a `u128` in a `const` context.
    ///
    /// `From<u128>` cannot be used here because the SIMD load intrinsic it
    /// calls is not a `const fn`.
    #[inline(always)]
    pub const fn from_u128(val: u128) -> Self {
        let mut result = Self::ZERO;
        // SAFETY: `u128` and `__m128i` are both 128-bit plain-data values with
        // no invalid bit patterns, so a bitwise copy between them is sound.
        unsafe {
            result.0 = std::mem::transmute_copy(&val);
        }

        result
    }
}
50
51impl From<__m128i> for M128 {
52 #[inline(always)]
53 fn from(value: __m128i) -> Self {
54 Self(value)
55 }
56}
57
impl From<u128> for M128 {
    fn from(value: u128) -> Self {
        // SAFETY: `value` is a live 16-byte stack local, and
        // `_mm_loadu_si128` performs an *unaligned* load, so no alignment
        // requirement is placed on `value`.
        Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
    }
}
63
64impl From<u64> for M128 {
65 fn from(value: u64) -> Self {
66 Self::from(value as u128)
67 }
68}
69
70impl From<u32> for M128 {
71 fn from(value: u32) -> Self {
72 Self::from(value as u128)
73 }
74}
75
76impl From<u16> for M128 {
77 fn from(value: u16) -> Self {
78 Self::from(value as u128)
79 }
80}
81
82impl From<u8> for M128 {
83 fn from(value: u8) -> Self {
84 Self::from(value as u128)
85 }
86}
87
88impl<const N: usize> From<SmallU<N>> for M128 {
89 fn from(value: SmallU<N>) -> Self {
90 Self::from(value.val() as u128)
91 }
92}
93
impl From<M128> for u128 {
    fn from(value: M128) -> Self {
        let mut result = 0u128;
        // SAFETY: `result` is a 16-byte stack local, and `_mm_storeu_si128`
        // performs an *unaligned* store, so no alignment requirement applies.
        unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };

        result
    }
}
102
103impl From<M128> for __m128i {
104 #[inline(always)]
105 fn from(value: M128) -> Self {
106 value.0
107 }
108}
109
// Derive splitting of `M128` into arrays of the smaller unsigned integer
// types, and the packed-scalar plumbing for the portable backend.
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);
112
113impl<U: NumCast<u128>> NumCast<M128> for U {
114 #[inline(always)]
115 fn num_cast_from(val: M128) -> Self {
116 Self::num_cast_from(u128::from(val))
117 }
118}
119
120impl Default for M128 {
121 #[inline(always)]
122 fn default() -> Self {
123 Self(unsafe { _mm_setzero_si128() })
124 }
125}
126
127impl BitAnd for M128 {
128 type Output = Self;
129
130 #[inline(always)]
131 fn bitand(self, rhs: Self) -> Self::Output {
132 Self(unsafe { _mm_and_si128(self.0, rhs.0) })
133 }
134}
135
136impl BitAndAssign for M128 {
137 #[inline(always)]
138 fn bitand_assign(&mut self, rhs: Self) {
139 *self = *self & rhs
140 }
141}
142
143impl BitOr for M128 {
144 type Output = Self;
145
146 #[inline(always)]
147 fn bitor(self, rhs: Self) -> Self::Output {
148 Self(unsafe { _mm_or_si128(self.0, rhs.0) })
149 }
150}
151
152impl BitOrAssign for M128 {
153 #[inline(always)]
154 fn bitor_assign(&mut self, rhs: Self) {
155 *self = *self | rhs
156 }
157}
158
159impl BitXor for M128 {
160 type Output = Self;
161
162 #[inline(always)]
163 fn bitxor(self, rhs: Self) -> Self::Output {
164 Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
165 }
166}
167
168impl BitXorAssign for M128 {
169 #[inline(always)]
170 fn bitxor_assign(&mut self, rhs: Self) {
171 *self = *self ^ rhs;
172 }
173}
174
175impl Not for M128 {
176 type Output = Self;
177
178 fn not(self) -> Self::Output {
179 const ONES: __m128i = m128_from_u128!(u128::MAX);
180
181 self ^ Self(ONES)
182 }
183}
184
/// Returns the larger of two `i32`s; usable in `const` contexts
/// (`i32::max` is not `const`).
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    if right > left {
        right
    } else {
        left
    }
}
193
/// Builds a full 128-bit shift out of 64-bit lane operations.
///
/// SSE has no single-instruction 128-bit bit shift, so this combines:
/// - `$byte_shift` — an 8-byte lane move (`_mm_bsrli_si128`/`_mm_bslli_si128`)
///   producing the "carry" half of the register,
/// - `$bit_shift_64` — the per-64-bit-lane shift in the requested direction,
/// - `$bit_shift_64_opposite` — the opposite-direction shift used to align
///   the carry bits, and
/// - `$or` — merges the shifted value with the aligned carry.
///
/// `seq!` unrolls a compile-time branch per possible shift amount so every
/// intrinsic receives a constant shift count. Shift amounts >= 128 fall
/// through to `Default::default()` (zero).
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            let carry = $byte_shift($val, 8);
            // Shifts of 64..128 bits: only the carried half contributes.
            seq!(N in 64..128 {
                if $shift == N {
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            // Shifts of 0..64 bits: combine the shifted lanes with the bits
            // carried across the 64-bit lane boundary.
            seq!(N in 0..64 {
                if $shift == N {
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            // Shift amount >= 128: result is all zeros.
            return Default::default()
        }
    };
}

pub(crate) use bitshift_128b;
228
impl Shr<usize> for M128 {
    type Output = Self;

    /// Logical right shift of the full 128-bit value by `rhs` bits.
    ///
    /// Shift amounts >= 128 yield zero (see `bitshift_128b`).
    #[inline(always)]
    fn shr(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
    }
}
239
impl Shl<usize> for M128 {
    type Output = Self;

    /// Logical left shift of the full 128-bit value by `rhs` bits.
    ///
    /// Shift amounts >= 128 yield zero (see `bitshift_128b`).
    #[inline(always)]
    fn shl(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
    }
}
250
impl PartialEq for M128 {
    fn eq(&self, other: &Self) -> bool {
        unsafe {
            // XOR is all-zero iff the operands are bitwise equal;
            // `_mm_test_all_zeros` (SSE4.1) checks the whole register at once.
            let neq = _mm_xor_si128(self.0, other.0);
            _mm_test_all_zeros(neq, neq) == 1
        }
    }
}

impl Eq for M128 {}
261
262impl PartialOrd for M128 {
263 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
264 Some(self.cmp(other))
265 }
266}
267
268impl Ord for M128 {
269 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
270 u128::from(*self).cmp(&u128::from(*other))
271 }
272}
273
impl ConstantTimeEq for M128 {
    /// Equality as a [`Choice`], mirroring the `PartialEq` implementation.
    ///
    /// NOTE(review): `subtle` expects this to be constant-time. The
    /// `_mm_xor_si128` + `_mm_test_all_zeros` sequence is branch-free, but
    /// confirm it meets the crate's constant-time requirements on all
    /// supported CPUs.
    fn ct_eq(&self, other: &Self) -> Choice {
        unsafe {
            // `_mm_test_all_zeros` returns 1 when `neq` is all-zero, i.e. equal.
            let neq = _mm_xor_si128(self.0, other.0);
            Choice::from(_mm_test_all_zeros(neq, neq) as u8)
        }
    }
}
282
283impl ConditionallySelectable for M128 {
284 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
285 ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
286 }
287}
288
289impl Random for M128 {
290 fn random(mut rng: impl RngCore) -> Self {
291 let val: u128 = rng.gen();
292 val.into()
293 }
294}
295
296impl std::fmt::Display for M128 {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 let data: u128 = (*self).into();
299 write!(f, "{data:02X?}")
300 }
301}
302
impl std::fmt::Debug for M128 {
    /// Debug output wraps the `Display` hex representation in the type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "M128({})", self)
    }
}
308
/// A single `u128` forced to 16-byte alignment so its address can be
/// reinterpreted as `*const __m128i` (used by [`m128_from_u128`]).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
311
/// Produces a `__m128i` from a `u128` expression, usable in `const` contexts
/// where no `const` SIMD load intrinsic exists.
macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        // SAFETY: `AlignedData` is 16-byte aligned and exactly 16 bytes long,
        // so the pointer is valid for an aligned `__m128i` read.
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}

pub(super) use m128_from_u128;
320
impl UnderlierType for M128 {
    // 2^7 = 128 bits.
    const LOG_BITS: usize = 7;
}
324
impl UnderlierWithBitOps for M128 {
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    const ONE: Self = { Self(m128_from_u128!(1)) };
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

    /// Broadcasts a single bit: all-zeros for `val == 0`, all-ones for
    /// `val == 1`. Panics on any other input.
    #[inline(always)]
    fn fill_with_bit(val: u8) -> Self {
        assert!(val == 0 || val == 1);
        // `0u8.wrapping_neg() == 0x00`, `1u8.wrapping_neg() == 0xFF`.
        Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
    }

    /// Builds the vector element-wise from `f(i)`, where `i` is the element
    /// index and the element width is `T::BITS`.
    #[inline(always)]
    fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
    where
        T: UnderlierType,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                // Sub-byte elements: the helper packs several `T`s per byte.
                let mut f = make_func_to_i8::<T, Self>(f);

                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            8 => {
                let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
                unsafe {
                    _mm_set_epi8(
                        f(15),
                        f(14),
                        f(13),
                        f(12),
                        f(11),
                        f(10),
                        f(9),
                        f(8),
                        f(7),
                        f(6),
                        f(5),
                        f(4),
                        f(3),
                        f(2),
                        f(1),
                        f(0),
                    )
                }
                .into()
            }
            16 => {
                let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
                unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
            }
            32 => {
                let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
                unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
            }
            64 => {
                let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
                unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
            }
            128 => Self::from(f(0)),
            _ => panic!("unsupported bit count"),
        }
    }

    /// Reads the `i`-th `T`-wide element of the vector.
    ///
    /// # Safety
    /// `i` must be in bounds for `T::BITS`-wide elements of a 128-bit value;
    /// the lookups use `get_unchecked`.
    #[inline(always)]
    unsafe fn get_subvalue<T>(&self, i: usize) -> T
    where
        T: UnderlierType + NumCast<Self>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                // Sub-byte elements: extract the containing byte, then shift
                // the element down. `num_cast_from` truncates the high bits.
                let elements_in_8 = 8 / T::BITS;
                let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
                    *arr.get_unchecked(i / elements_in_8)
                });

                let shift = (i % elements_in_8) * T::BITS;
                value_u8 >>= shift;

                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            8 => {
                let value_u8 =
                    as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u8)))
            }
            16 => {
                let value_u16 =
                    as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u16)))
            }
            32 => {
                let value_u32 =
                    as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u32)))
            }
            64 => {
                let value_u64 =
                    as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
                T::from_underlier(T::num_cast_from(Self::from(value_u64)))
            }
            128 => T::from_underlier(T::num_cast_from(*self)),
            _ => panic!("unsupported bit count"),
        }
    }

    /// Overwrites the `i`-th `T`-wide element of the vector with `val`.
    ///
    /// # Safety
    /// `i` must be in bounds for `T::BITS`-wide elements of a 128-bit value;
    /// the lookups use `get_unchecked_mut`.
    #[inline(always)]
    unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
    where
        T: UnderlierWithBitOps,
        Self: From<T>,
    {
        match T::BITS {
            1 | 2 | 4 => {
                // Sub-byte elements: clear the target bits in the containing
                // byte with `mask`, then OR in the shifted value.
                let elements_in_8 = 8 / T::BITS;
                let mask = (1u8 << T::BITS) - 1;
                let shift = (i % elements_in_8) * T::BITS;
                let val = u8::num_cast_from(Self::from(val)) << shift;
                let mask = mask << shift;

                as_array_mut::<_, u8, 16>(self, |array| unsafe {
                    let element = array.get_unchecked_mut(i / elements_in_8);
                    *element &= !mask;
                    *element |= val;
                });
            }
            8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
            }),
            16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
            }),
            32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
            }),
            64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
                *array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
            }),
            128 => {
                *self = Self::from(val);
            }
            _ => panic!("unsupported bit count"),
        }
    }

    /// Spreads block `block_idx` of `2^log_block_len` consecutive `T`-wide
    /// elements so each element is repeated to fill the whole register.
    ///
    /// Dispatches on `T::LOG_BITS`; cases without an efficient SIMD path use
    /// `spread_fallback`.
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps + NumCast<Self>,
        Self: From<T>,
    {
        match T::LOG_BITS {
            // 1-bit elements: expand each selected bit to a full lane with
            // `fill_with_bit`.
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                }
                2 => {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                }
                3 => {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                }
                4 => {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                }
                _ => spread_fallback(self, log_block_len, block_idx),
            },
            // 2-bit elements (`U2`): expand each element to a byte, then
            // repeat the bytes to fill the register.
            1 => match log_block_len {
                0 => {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                }
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => spread_fallback(self, log_block_len, block_idx),
            },
            // 4-bit elements (`U4`): same byte-expansion strategy as `U2`.
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => spread_fallback(self, log_block_len, block_idx),
            },
            // Byte elements: a single `_mm_shuffle_epi8` with a precomputed
            // index mask (see `LOG_B8_*` and `precompute_spread_mask`).
            3 => match log_block_len {
                0 => _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into(),
                1 => _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into(),
                2 => _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into(),
                3 => _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into(),
                4 => self,
                _ => panic!("unsupported block length"),
            },
            // 16-bit elements.
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    _mm_set1_epi16(value as i16).into()
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            // 32-bit elements.
            5 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                }
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            // 64-bit elements.
            6 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                }
                1 => self,
                _ => panic!("unsupported block length"),
            },
            // 128-bit elements: the whole register is the single block.
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }

    #[inline]
    fn shl_128b_lanes(self, shift: usize) -> Self {
        // A single `M128` is exactly one 128-bit lane.
        self << shift
    }

    #[inline]
    fn shr_128b_lanes(self, shift: usize) -> Self {
        // A single `M128` is exactly one 128-bit lane.
        self >> shift
    }

    /// Interleaves the low halves of `self` and `other` at
    /// `2^log_block_len`-bit granularity; sub-byte sizes use the portable
    /// fallback.
    #[inline]
    fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

    /// Interleaves the high halves of `self` and `other` at
    /// `2^log_block_len`-bit granularity; sub-byte sizes use the portable
    /// fallback.
    #[inline]
    fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
        match log_block_len {
            0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
            3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
            4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
            5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
            6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
            _ => panic!("unsupported block length"),
        }
    }

    #[inline]
    fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        transpose_128b_values::<Self, TL>(values, 0);
    }

    /// Transposes into byte-sliced layout: first a per-register byte shuffle
    /// that groups bytes by destination, then the shared 128-bit transpose.
    #[inline]
    fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
    where
        u8: NumCast<Self>,
        Self: From<u8>,
    {
        if TL::LOG_WIDTH == 0 {
            return;
        }

        match TL::LOG_WIDTH {
            1 => unsafe {
                let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            2 => unsafe {
                let shuffle = _mm_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            3 => unsafe {
                let shuffle = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
                for v in values.as_mut().iter_mut() {
                    *v = _mm_shuffle_epi8(v.0, shuffle).into();
                }
            },
            4 => {}
            _ => unreachable!("Log width must be less than 5"),
        }

        transpose_128b_values::<_, TL>(values, 4 - TL::LOG_WIDTH);
    }
}
779
// SAFETY: `M128` is `#[repr(transparent)]` over `__m128i`, a 128-bit value
// for which the all-zero bit pattern is valid.
unsafe impl Zeroable for M128 {}

// SAFETY: `M128` is a plain 128-bit value with no padding, pointers, or
// invalid bit patterns.
unsafe impl Pod for M128 {}

// SAFETY: `M128` holds only an inline SIMD value — no references to shared
// state.
unsafe impl Send for M128 {}

// SAFETY: `M128` is plain data; sharing `&M128` across threads is safe.
unsafe impl Sync for M128 {}
787
// Shuffle masks used by `spread` for byte-sized elements (T::LOG_BITS == 3),
// indexed by `block_idx`; one table per `log_block_len` in 0..=3.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
792
/// Precomputes `_mm_shuffle_epi8` index masks for `spread` on byte-or-wider
/// elements.
///
/// `t_log_bits` is the log2 bit width of the element type (must be >= 3,
/// i.e. at least one byte), `log_block_len` is the log2 number of elements
/// per spread block, and the result holds one mask per possible `block_idx`.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
    log_block_len: usize,
    t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
    // Element width in bytes (log2 and absolute).
    let element_log_width = t_log_bits - 3;

    let element_width = 1 << element_log_width;

    // Bytes covered by one block, and how many times each element must be
    // repeated to fill the 16-byte register.
    let block_size = 1 << (log_block_len + element_log_width);
    let repeat = 1 << (4 - element_log_width - log_block_len);
    let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

    // `while` loops throughout: `for` is not allowed in `const fn`.
    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        let base = block_idx * block_size;
        let mut j = 0;
        while j < 16 {
            // Source byte index: element (j / element_width) / repeat of the
            // block, plus the byte offset within the element.
            masks[block_idx][j] =
                (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
            j += 1;
        }
        block_idx += 1;
    }
    let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
        block_idx += 1;
    }

    m128_masks
}
828
impl UnderlierWithBitConstants for M128 {
    // Bit masks selecting the even 2^i-bit blocks, indexed by log block size.
    const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_even!(u128, 0)),
        Self::from_u128(interleave_mask_even!(u128, 1)),
        Self::from_u128(interleave_mask_even!(u128, 2)),
        Self::from_u128(interleave_mask_even!(u128, 3)),
        Self::from_u128(interleave_mask_even!(u128, 4)),
        Self::from_u128(interleave_mask_even!(u128, 5)),
        Self::from_u128(interleave_mask_even!(u128, 6)),
    ];

    // Bit masks selecting the odd 2^i-bit blocks, indexed by log block size.
    const INTERLEAVE_ODD_MASK: &'static [Self] = &[
        Self::from_u128(interleave_mask_odd!(u128, 0)),
        Self::from_u128(interleave_mask_odd!(u128, 1)),
        Self::from_u128(interleave_mask_odd!(u128, 2)),
        Self::from_u128(interleave_mask_odd!(u128, 3)),
        Self::from_u128(interleave_mask_odd!(u128, 4)),
        Self::from_u128(interleave_mask_odd!(u128, 5)),
        Self::from_u128(interleave_mask_odd!(u128, 6)),
    ];

    /// Interleaves `2^log_block_len`-bit blocks of `self` and `other`; see
    /// [`interleave_bits`] for the exact output layout.
    #[inline(always)]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        unsafe {
            let (c, d) = interleave_bits(
                Into::<Self>::into(self).into(),
                Into::<Self>::into(other).into(),
                log_block_len,
            );
            (Self::from(c), Self::from(d))
        }
    }
}
862
863impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
864 fn from(value: __m128i) -> Self {
865 M128::from(value).into()
866 }
867}
868
869impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
870 fn from(value: u128) -> Self {
871 M128::from(value).into()
872 }
873}
874
875impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
876 fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
877 value.to_underlier().into()
878 }
879}
880
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
    u128: From<Scalar::Underlier>,
{
    /// Replicates `scalar` into every `Scalar`-wide slot of the register.
    #[inline(always)]
    fn broadcast(scalar: Scalar) -> Self {
        let tower_level = Scalar::N_BITS.ilog2() as usize;
        let mut value = u128::from(scalar.to_underlier());
        // Sub-byte scalars: replicate within one byte (1 -> 2 -> 4 -> 8 bits)
        // so the byte-level broadcast below finishes the job.
        for n in tower_level..3 {
            value |= value << (1 << n);
        }

        let value = must_cast(value);
        // NOTE(review): `_mm_broadcast*` intrinsics require AVX2 —
        // presumably guaranteed by this module's build configuration; confirm.
        let value = match tower_level {
            // Byte and sub-byte scalars: broadcast the low byte.
            0..=3 => unsafe { _mm_broadcastb_epi8(value) },
            4 => unsafe { _mm_broadcastw_epi16(value) },
            5 => unsafe { _mm_broadcastd_epi32(value) },
            6 => unsafe { _mm_broadcastq_epi64(value) },
            // A 128-bit scalar already fills the register.
            7 => value,
            _ => unreachable!(),
        };

        value.into()
    }
}
906
/// Interleaves `2^log_block_len`-bit blocks of two 128-bit vectors.
///
/// The first output takes block pairs `(a_i, b_i)` for even positions and
/// the second for odd positions; see the `test_interleave_bits` property
/// test below for the exact contract.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        // Sub-byte blocks: delta-swap within 64-bit lanes, with a mask
        // selecting the even blocks of each byte.
        0 => {
            let mask = _mm_set1_epi8(0x55i8);
            interleave_bits_imm::<1>(a, b, mask)
        }
        1 => {
            let mask = _mm_set1_epi8(0x33i8);
            interleave_bits_imm::<2>(a, b, mask)
        }
        2 => {
            let mask = _mm_set1_epi8(0x0fi8);
            interleave_bits_imm::<4>(a, b, mask)
        }
        // Byte and wider blocks: shuffle even/odd blocks into the low/high
        // halves of each register, then unpack.
        3 => {
            let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi8(a, b);
            let b_prime = _mm_unpackhi_epi8(a, b);
            (a_prime, b_prime)
        }
        4 => {
            let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi16(a, b);
            let b_prime = _mm_unpackhi_epi16(a, b);
            (a_prime, b_prime)
        }
        5 => {
            let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi32(a, b);
            let b_prime = _mm_unpackhi_epi32(a, b);
            (a_prime, b_prime)
        }
        // 64-bit blocks need no pre-shuffle — unpack alone suffices.
        6 => {
            let a_prime = _mm_unpacklo_epi64(a, b);
            let b_prime = _mm_unpackhi_epi64(a, b);
            (a_prime, b_prime)
        }
        _ => panic!("unsupported block length"),
    }
}
954
/// Interleaves sub-byte blocks of `BLOCK_LEN` bits via a delta-swap.
///
/// `mask` selects the even `BLOCK_LEN`-bit blocks of each byte; the
/// xor/shift/xor sequence exchanges the odd blocks of `a` with the even
/// blocks of `b` without any cross-lane movement.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
    let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
    let b_prime = _mm_xor_si128(b, t);
    (a_prime, b_prime)
}
966
// Scalar-iteration strategies over an `M128`: bit-level iteration for `U1`,
// the fallback path for `U2`/`U4`, and direct division for byte-and-wider
// element types.
impl_iteration!(M128,
    @strategy BitIterationStrategy, U1,
    @strategy FallbackStrategy, U2, U4,
    @strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
972
#[cfg(test)]
mod tests {
    use proptest::{arbitrary::any, proptest};

    use super::*;
    use crate::underlier::single_element_mask_bits;

    // Asserts that converting to `T` and back is lossless.
    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    // Extracts the `index`-th `2^log_block_len`-bit block of `value`.
    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        (value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        // Shifts must agree with the scalar `u128` shifts for all in-range
        // shift amounts.
        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        // Checks the block-pairing contract of `interleave_bits` at every
        // supported block height.
        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }

        #[test]
        fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_lo_128b_lanes(b, height);
            for i in 0..128>>(height + 1) {
                assert_eq!(get(result, height, 2*i), get(a, height, i));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i));
            }
        }

        #[test]
        fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let result = a.unpack_hi_128b_lanes(b, height);
            let half_block_count = 128>>(height + 1);
            for i in 0..half_block_count {
                assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
                assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }
}