1use std::{
4 arch::x86_64::*,
5 array,
6 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializeBytes,
11 bytes::{Buf, BufMut},
12 serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform},
18};
19use seq_macro::seq;
20
21use crate::{
22 BinaryField,
23 arch::portable::packed::PackedPrimitiveType,
24 underlier::{
25 Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, UnderlierWithBitOps,
26 impl_divisible_bitmask, mapget, spread_fallback,
27 },
28};
29
/// Thin wrapper around the SSE `__m128i` 128-bit vector register type.
///
/// `#[repr(transparent)]` guarantees identical layout to `__m128i`, so values
/// can be converted in both directions at zero cost (see the `From` impls).
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
34
impl M128 {
    /// Creates an `M128` from a `u128` in a `const` context.
    ///
    /// The SSE load intrinsics are not `const fn`, so this uses
    /// `transmute_copy` instead: both `u128` and `__m128i` are plain 128-bit
    /// data with no invalid bit patterns, making the bit copy sound.
    #[inline(always)]
    pub const fn from_u128(val: u128) -> Self {
        let mut result = Self::ZERO;
        unsafe {
            result.0 = std::mem::transmute_copy(&val);
        }

        result
    }
}
46
impl From<__m128i> for M128 {
    /// Zero-cost wrap of the raw vector register.
    #[inline(always)]
    fn from(value: __m128i) -> Self {
        Self(value)
    }
}
53
54impl From<u128> for M128 {
55 fn from(value: u128) -> Self {
56 Self(unsafe {
57 assert_eq!(align_of::<u128>(), 16);
59 _mm_load_si128(&raw const value as *const __m128i)
60 })
61 }
62}
63
64impl From<u64> for M128 {
65 fn from(value: u64) -> Self {
66 Self::from(value as u128)
67 }
68}
69
70impl From<u32> for M128 {
71 fn from(value: u32) -> Self {
72 Self::from(value as u128)
73 }
74}
75
76impl From<u16> for M128 {
77 fn from(value: u16) -> Self {
78 Self::from(value as u128)
79 }
80}
81
82impl From<u8> for M128 {
83 fn from(value: u8) -> Self {
84 Self::from(value as u128)
85 }
86}
87
impl<const N: usize> From<SmallU<N>> for M128 {
    /// Zero-extends an `N`-bit small integer into the low bits of the vector.
    fn from(value: SmallU<N>) -> Self {
        Self::from(value.val() as u128)
    }
}
93
94impl From<M128> for u128 {
95 fn from(value: M128) -> Self {
96 let mut result = 0u128;
97 unsafe {
98 assert_eq!(align_of::<u128>(), 16);
100 _mm_store_si128(&raw mut result as *mut __m128i, value.0)
101 };
102 result
103 }
104}
105
impl From<M128> for __m128i {
    /// Zero-cost unwrap back to the raw vector register.
    #[inline(always)]
    fn from(value: M128) -> Self {
        value.0
    }
}
112
113impl SerializeBytes for M128 {
114 fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
115 assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
116
117 let raw_value: u128 = (*self).into();
118
119 write_buf.put_u128_le(raw_value);
120 Ok(())
121 }
122}
123
124impl DeserializeBytes for M128 {
125 fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
126 where
127 Self: Sized,
128 {
129 assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
130
131 let raw_value = read_buf.get_u128_le();
132
133 Ok(Self::from(raw_value))
134 }
135}
136
// Implement `Divisible` over the 1-, 2- and 4-bit `SmallU` sub-byte types via
// bit masking (the larger widths get hand-written impls below).
impl_divisible_bitmask!(M128, 1, 2, 4);
138
// Anything that can be numerically cast from `u128` can also be cast from
// `M128`, by first extracting the integer value.
impl<U: NumCast<u128>> NumCast<M128> for U {
    #[inline(always)]
    fn num_cast_from(val: M128) -> Self {
        Self::num_cast_from(u128::from(val))
    }
}
145
impl Default for M128 {
    /// Returns the all-zero vector.
    #[inline(always)]
    fn default() -> Self {
        // SAFETY: SSE2 is part of the x86-64 baseline.
        Self(unsafe { _mm_setzero_si128() })
    }
}
152
153impl BitAnd for M128 {
154 type Output = Self;
155
156 #[inline(always)]
157 fn bitand(self, rhs: Self) -> Self::Output {
158 Self(unsafe { _mm_and_si128(self.0, rhs.0) })
159 }
160}
161
162impl BitAndAssign for M128 {
163 #[inline(always)]
164 fn bitand_assign(&mut self, rhs: Self) {
165 *self = *self & rhs
166 }
167}
168
169impl BitOr for M128 {
170 type Output = Self;
171
172 #[inline(always)]
173 fn bitor(self, rhs: Self) -> Self::Output {
174 Self(unsafe { _mm_or_si128(self.0, rhs.0) })
175 }
176}
177
178impl BitOrAssign for M128 {
179 #[inline(always)]
180 fn bitor_assign(&mut self, rhs: Self) {
181 *self = *self | rhs
182 }
183}
184
185impl BitXor for M128 {
186 type Output = Self;
187
188 #[inline(always)]
189 fn bitxor(self, rhs: Self) -> Self::Output {
190 Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
191 }
192}
193
194impl BitXorAssign for M128 {
195 #[inline(always)]
196 fn bitxor_assign(&mut self, rhs: Self) {
197 *self = *self ^ rhs;
198 }
199}
200
201impl Not for M128 {
202 type Output = Self;
203
204 fn not(self) -> Self::Output {
205 const ONES: __m128i = m128_from_u128!(u128::MAX);
206
207 self ^ Self(ONES)
208 }
209}
210
/// `const` maximum of two `i32`s (usable where `Ord::max` is not `const`).
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    if left < right { right } else { left }
}
215
/// Expands to a full 128-bit shift built from 64-bit lane shifts.
///
/// SSE2 has no single 128-bit bit shift: `$byte_shift`
/// (`_mm_bslli_si128`/`_mm_bsrli_si128`) moves whole bytes across the 64-bit
/// lane boundary, while `$bit_shift_64`/`$bit_shift_64_opposite`
/// (`_mm_slli_epi64`/`_mm_srli_epi64`) shift bits within each lane. The
/// byte-shifted copy supplies the bits that cross the lane boundary, and
/// `$or` recombines the two halves.
///
/// `seq!` unrolls one branch per possible shift amount so the lane-shift
/// intrinsics receive compile-time immediates.
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            // The value moved 8 bytes (64 bits) toward the shift direction:
            // source of all bits that cross the lane boundary.
            let carry = $byte_shift($val, 8);
            // Shifts of 64..128 bits: only the carried half contributes.
            seq!(N in 64..128 {
                if $shift == N {
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            // Shifts of 0..64 bits: combine the in-lane shift with the carry
            // shifted the opposite way by the complementary amount.
            seq!(N in 0..64 {
                if $shift == N {
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            // Shift amounts >= 128 produce zero.
            return Default::default()
        }
    };
}
248
impl Shr<usize> for M128 {
    type Output = Self;

    /// Logical right shift of the full 128-bit value; shifts >= 128 yield zero.
    #[inline(always)]
    fn shr(self, rhs: usize) -> Self::Output {
        // Right shift: bytes carried downward, lanes shifted right, carry
        // pre-shifted left to fill the vacated high bits of the low lane.
        bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
    }
}
259
impl Shl<usize> for M128 {
    type Output = Self;

    /// Logical left shift of the full 128-bit value; shifts >= 128 yield zero.
    #[inline(always)]
    fn shl(self, rhs: usize) -> Self::Output {
        // Mirror image of `Shr` above. (The trailing `;` is harmless: every
        // path inside the macro returns explicitly.)
        bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
    }
}
270
impl PartialEq for M128 {
    fn eq(&self, other: &Self) -> bool {
        unsafe {
            // XOR is zero iff every bit matches; `_mm_test_all_zeros`
            // (SSE4.1 `ptest`) checks for the all-zero vector in one
            // instruction.
            let neq = _mm_xor_si128(self.0, other.0);
            _mm_test_all_zeros(neq, neq) == 1
        }
    }
}
279
// Bitwise equality above is reflexive, symmetric, and transitive.
impl Eq for M128 {}
281
impl PartialOrd for M128 {
    /// Canonical delegation to `Ord::cmp` (the ordering is total).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
287
288impl Ord for M128 {
289 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
290 u128::from(*self).cmp(&u128::from(*other))
291 }
292}
293
impl Distribution<M128> for StandardUniform {
    /// Samples a uniformly random 128-bit vector.
    // NOTE(review): relies on rand's `StandardUniform` impl for `__m128i`
    // (available via rand's SIMD support) — confirm the enabled features.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
        M128(rng.random())
    }
}
299
300impl std::fmt::Display for M128 {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 let data: u128 = (*self).into();
303 write!(f, "{data:02X?}")
304 }
305}
306
impl std::fmt::Debug for M128 {
    /// Wraps the `Display` form in `M128(...)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "M128({self})")
    }
}
312
/// A `u128` forced to 16-byte alignment so its address can be reinterpreted
/// as `*const __m128i` (used by [`m128_from_u128!`] in `const` initializers).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
315
/// Builds a `__m128i` value from a `u128` expression; works in `const`
/// contexts where the SSE load intrinsics are unavailable.
macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        // SAFETY: `AlignedData` is 16-byte aligned and exactly 16 bytes long.
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}
322
323pub(super) use m128_from_u128;
324
impl UnderlierType for M128 {
    // 2^7 = 128 bits.
    const LOG_BITS: usize = 7;
}
328
impl UnderlierWithBitOps for M128 {
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    const ONE: Self = { Self(m128_from_u128!(1)) };
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

    /// Interleaves `2^log_block_len`-bit blocks of `self` and `other`;
    /// see [`interleave_bits`] for the per-width implementation.
    #[inline(always)]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        unsafe {
            let (c, d) = interleave_bits(
                Into::<Self>::into(self).into(),
                Into::<Self>::into(other).into(),
                log_block_len,
            );
            (Self::from(c), Self::from(d))
        }
    }

    /// Broadcasts block `block_idx` (of `2^log_block_len` elements of type
    /// `T`) across the whole vector, repeating it to fill all 128 bits.
    ///
    /// Dispatches on `T::LOG_BITS` (element width) and then on the block
    /// length; each arm extracts the block via `u128` shifts and rebuilds the
    /// vector with a `_mm_set*` intrinsic, a precomputed byte shuffle, or the
    /// generic `spread_fallback`.
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps,
        Self: Divisible<T>,
    {
        match T::LOG_BITS {
            // T = 1-bit elements: each selected bit is expanded to a full
            // all-zeros/all-ones region via `fill_with_bit`.
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => unsafe {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                },
                2 => unsafe {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                },
                3 => unsafe {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                },
                4 => unsafe {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                },
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // T = 2-bit elements: each U2 is first spread to a full byte.
            1 => match log_block_len {
                0 => unsafe {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                },
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // T = 4-bit elements: each U4 is first spread to a full byte.
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    unsafe { _mm_set1_epi8(value as i8).into() }
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // T = bytes: a single `pshufb` with a precomputed mask table
            // (see `LOG_B8_*` below) does the whole spread.
            3 => match log_block_len {
                0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
                1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
                2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
                3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
                4 => self,
                _ => panic!("unsupported block length"),
            },
            // T = u16.
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    unsafe { _mm_set1_epi16(value as i16).into() }
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            // T = u32.
            5 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                },
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            // T = u64.
            6 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                },
                1 => self,
                _ => panic!("unsupported block length"),
            },
            // T = u128: the block is the whole vector.
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }
}
556
// SAFETY: `M128` is `#[repr(transparent)]` over `__m128i`, a 16-byte plain-
// data vector type with no padding and no invalid bit patterns.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// SAFETY: `M128` contains no pointers or interior mutability.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
564
// Precomputed `_mm_shuffle_epi8` masks used by `spread` for byte elements
// (T::LOG_BITS == 3): one table per `log_block_len` in 0..=3, indexed by
// `block_idx`.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
569
/// Computes, at compile time, the `pshufb` byte-selector masks that spread a
/// block of `2^log_block_len` elements (each `2^t_log_bits` bits, at least a
/// byte) across the whole 16-byte vector.
///
/// For each `block_idx`, output byte `j` selects input byte
/// `base + repeated-element-offset`, so each element of the block is repeated
/// `repeat` times in order. Written with `while` loops because `for` is not
/// allowed in `const fn`.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
    log_block_len: usize,
    t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
    // Element width in bytes (log and linear).
    let element_log_width = t_log_bits - 3;

    let element_width = 1 << element_log_width;

    // Block size in bytes, and how many times each element is repeated.
    let block_size = 1 << (log_block_len + element_log_width);
    let repeat = 1 << (4 - element_log_width - log_block_len);
    let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        let base = block_idx * block_size;
        let mut j = 0;
        while j < 16 {
            masks[block_idx][j] =
                (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
            j += 1;
        }
        block_idx += 1;
    }
    // Repack the byte tables as little-endian `M128` constants.
    let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
        block_idx += 1;
    }

    m128_masks
}
605
// Convenience conversion: wrap a raw register directly into a packed field.
impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: __m128i) -> Self {
        M128::from(value).into()
    }
}
611
// Convenience conversion: build a packed field from a `u128` bit pattern.
impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: u128) -> Self {
        M128::from(value).into()
    }
}
617
// Convenience conversion: unwrap a packed field back to the raw register.
impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
    fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
        value.to_underlier().into()
    }
}
623
/// Interleaves `2^log_block_len`-bit blocks of `a` and `b`: the returned pair
/// holds (even-indexed block pairs, odd-indexed block pairs).
///
/// Sub-byte blocks (0..=2) use a masked delta swap; byte and larger blocks
/// (3..=5) pre-shuffle even/odd blocks into the low/high halves and then use
/// the SSE unpack instructions; 64-bit blocks (6) are a plain unpack.
///
/// # Safety
/// Caller must ensure SSSE3 is available (for `_mm_shuffle_epi8`).
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        // 1-bit blocks: delta swap on the even-bit mask.
        0 => unsafe {
            let mask = _mm_set1_epi8(0x55i8);
            interleave_bits_imm::<1>(a, b, mask)
        },
        // 2-bit blocks.
        1 => unsafe {
            let mask = _mm_set1_epi8(0x33i8);
            interleave_bits_imm::<2>(a, b, mask)
        },
        // 4-bit blocks.
        2 => unsafe {
            let mask = _mm_set1_epi8(0x0fi8);
            interleave_bits_imm::<4>(a, b, mask)
        },
        // 8-bit blocks: gather even bytes into the low half, odd bytes into
        // the high half, then unpack.
        3 => unsafe {
            let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi8(a, b);
            let b_prime = _mm_unpackhi_epi8(a, b);
            (a_prime, b_prime)
        },
        // 16-bit blocks.
        4 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi16(a, b);
            let b_prime = _mm_unpackhi_epi16(a, b);
            (a_prime, b_prime)
        },
        // 32-bit blocks.
        5 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi32(a, b);
            let b_prime = _mm_unpackhi_epi32(a, b);
            (a_prime, b_prime)
        },
        // 64-bit blocks: unpack alone suffices.
        6 => unsafe {
            let a_prime = _mm_unpacklo_epi64(a, b);
            let b_prime = _mm_unpackhi_epi64(a, b);
            (a_prime, b_prime)
        },
        _ => panic!("unsupported block length"),
    }
}
671
/// Interleaves `BLOCK_LEN`-bit blocks of `a` and `b` with a masked delta swap
/// (the classic XOR-exchange trick), for block lengths below a byte.
/// `mask` must select the even-indexed blocks within each byte.
///
/// # Safety
/// Caller must ensure SSE2 is available.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    unsafe {
        // Positions (restricted to `mask`) where the down-shifted `a` and `b`
        // differ; XOR-ing `t` into both swaps exactly those blocks.
        let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
        let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
        let b_prime = _mm_xor_si128(b, t);
        (a_prime, b_prime)
    }
}
685
// One `u128` per vector (LOG_N = 0): all accessors are trivial conversions.
impl Divisible<u128> for M128 {
    const LOG_N: usize = 0;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
        std::iter::once(u128::from(value))
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
        std::iter::once(u128::from(*value))
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
        slice.iter().map(|&v| u128::from(v))
    }

    #[inline]
    fn get(self, index: usize) -> u128 {
        assert!(index == 0, "index out of bounds");
        u128::from(self)
    }

    #[inline]
    fn set(self, index: usize, val: u128) -> Self {
        assert!(index == 0, "index out of bounds");
        // The single element is the whole value, so the receiver is replaced.
        Self::from(val)
    }

    #[inline]
    fn broadcast(val: u128) -> Self {
        Self::from(val)
    }

    #[inline]
    fn from_iter(mut iter: impl Iterator<Item = u128>) -> Self {
        iter.next().map(Self::from).unwrap_or(Self::ZERO)
    }
}
728
// Two `u64` lanes per vector. `get`/`set` must match on the index because the
// extract/insert intrinsics require a compile-time-constant lane number.
impl Divisible<u64> for M128 {
    const LOG_N: usize = 1;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u64 {
        unsafe {
            match index {
                0 => _mm_extract_epi64(self.0, 0) as u64,
                1 => _mm_extract_epi64(self.0, 1) as u64,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u64) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
                1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u64) -> Self {
        unsafe { Self(_mm_set1_epi64x(val as i64)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u64>) -> Self {
        // Missing elements are left zero; extra elements are ignored.
        let mut result = Self::ZERO;
        let arr: &mut [u64; 2] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(2).enumerate() {
            arr[i] = val;
        }
        result
    }
}
784
// Four `u32` lanes per vector; same const-lane-index pattern as the u64 impl.
impl Divisible<u32> for M128 {
    const LOG_N: usize = 2;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u32 {
        unsafe {
            match index {
                0 => _mm_extract_epi32(self.0, 0) as u32,
                1 => _mm_extract_epi32(self.0, 1) as u32,
                2 => _mm_extract_epi32(self.0, 2) as u32,
                3 => _mm_extract_epi32(self.0, 3) as u32,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u32) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u32) -> Self {
        unsafe { Self(_mm_set1_epi32(val as i32)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u32>) -> Self {
        // Missing elements are left zero; extra elements are ignored.
        let mut result = Self::ZERO;
        let arr: &mut [u32; 4] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(4).enumerate() {
            arr[i] = val;
        }
        result
    }
}
844
// Eight `u16` lanes per vector; same const-lane-index pattern as above.
impl Divisible<u16> for M128 {
    const LOG_N: usize = 3;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u16 {
        unsafe {
            match index {
                0 => _mm_extract_epi16(self.0, 0) as u16,
                1 => _mm_extract_epi16(self.0, 1) as u16,
                2 => _mm_extract_epi16(self.0, 2) as u16,
                3 => _mm_extract_epi16(self.0, 3) as u16,
                4 => _mm_extract_epi16(self.0, 4) as u16,
                5 => _mm_extract_epi16(self.0, 5) as u16,
                6 => _mm_extract_epi16(self.0, 6) as u16,
                7 => _mm_extract_epi16(self.0, 7) as u16,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u16) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
                4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
                5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
                6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
                7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u16) -> Self {
        unsafe { Self(_mm_set1_epi16(val as i16)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u16>) -> Self {
        // Missing elements are left zero; extra elements are ignored.
        let mut result = Self::ZERO;
        let arr: &mut [u16; 8] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(8).enumerate() {
            arr[i] = val;
        }
        result
    }
}
912
// Sixteen `u8` lanes per vector; same const-lane-index pattern as above.
impl Divisible<u8> for M128 {
    const LOG_N: usize = 4;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u8 {
        unsafe {
            match index {
                0 => _mm_extract_epi8(self.0, 0) as u8,
                1 => _mm_extract_epi8(self.0, 1) as u8,
                2 => _mm_extract_epi8(self.0, 2) as u8,
                3 => _mm_extract_epi8(self.0, 3) as u8,
                4 => _mm_extract_epi8(self.0, 4) as u8,
                5 => _mm_extract_epi8(self.0, 5) as u8,
                6 => _mm_extract_epi8(self.0, 6) as u8,
                7 => _mm_extract_epi8(self.0, 7) as u8,
                8 => _mm_extract_epi8(self.0, 8) as u8,
                9 => _mm_extract_epi8(self.0, 9) as u8,
                10 => _mm_extract_epi8(self.0, 10) as u8,
                11 => _mm_extract_epi8(self.0, 11) as u8,
                12 => _mm_extract_epi8(self.0, 12) as u8,
                13 => _mm_extract_epi8(self.0, 13) as u8,
                14 => _mm_extract_epi8(self.0, 14) as u8,
                15 => _mm_extract_epi8(self.0, 15) as u8,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u8) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
                4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
                5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
                6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
                7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
                8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
                9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
                10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
                11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
                12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
                13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
                14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
                15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u8) -> Self {
        unsafe { Self(_mm_set1_epi8(val as i8)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u8>) -> Self {
        // Missing elements are left zero; extra elements are ignored.
        let mut result = Self::ZERO;
        let arr: &mut [u8; 16] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(16).enumerate() {
            arr[i] = val;
        }
        result
    }
}
996
#[cfg(test)]
mod tests {
    use binius_utils::bytes::BytesMut;
    use proptest::{arbitrary::any, proptest};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Asserts that the conversion `M128 -> T -> M128` is lossless.
    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    /// Extracts block `index` of width `2^log_block_len` bits from `value`.
    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        // Fix: mask the FULL block. The previous mask, `1 << log_block_len`,
        // selected only a single bit of each block, so the interleave test
        // was comparing just one bit per block instead of all of them.
        let block_mask = (1u128 << (1 << log_block_len)) - 1;
        (value >> (index << log_block_len)) & M128::from(block_mask)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            // Blocks 2i and 2i+1 of the inputs must land in blocks i of the
            // two outputs, pairwise interleaved.
            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }

    #[test]
    fn test_serialize_and_deserialize_m128() {
        let mut rng = StdRng::from_seed([0; 32]);

        let original_value = M128::from(rng.random::<u128>());

        let mut buf = BytesMut::new();
        original_value.serialize(&mut buf).unwrap();

        let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

        assert_eq!(original_value, deserialized_value);
    }
}