1use std::{
4 arch::x86_64::*,
5 array,
6 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10 DeserializeBytes, SerializationError, SerializeBytes,
11 bytes::{Buf, BufMut},
12 serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{
16 Rng,
17 distr::{Distribution, StandardUniform},
18};
19use seq_macro::seq;
20
21use crate::{
22 BinaryField,
23 arch::portable::packed::PackedPrimitiveType,
24 underlier::{
25 Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, UnderlierWithBitOps,
26 impl_divisible_bitmask, mapget, spread_fallback,
27 },
28};
29
/// 128-bit SIMD value backed by a single SSE register (`__m128i`).
///
/// `repr(transparent)` guarantees the same layout and ABI as `__m128i`, which
/// the `Pod`/`Zeroable` impls and the raw-pointer casts below rely on.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);

impl M128 {
    /// Const-context constructor from a `u128`.
    ///
    /// SIMD load intrinsics are not callable in `const fn`, so the bits are
    /// copied with `transmute_copy`; `u128` and `__m128i` are both 16 bytes,
    /// which makes the copy sound.
    #[inline(always)]
    pub const fn from_u128(val: u128) -> Self {
        let mut result = Self::ZERO;
        unsafe {
            result.0 = std::mem::transmute_copy(&val);
        }

        result
    }
}
46
impl From<__m128i> for M128 {
    /// Wraps a raw SSE register value without any conversion.
    #[inline(always)]
    fn from(value: __m128i) -> Self {
        Self(value)
    }
}
53
54impl From<u128> for M128 {
55 fn from(value: u128) -> Self {
56 Self(unsafe {
57 assert_eq!(align_of::<u128>(), 16);
59 _mm_load_si128(&raw const value as *const __m128i)
60 })
61 }
62}
63
64impl From<u64> for M128 {
65 fn from(value: u64) -> Self {
66 Self::from(value as u128)
67 }
68}
69
70impl From<u32> for M128 {
71 fn from(value: u32) -> Self {
72 Self::from(value as u128)
73 }
74}
75
76impl From<u16> for M128 {
77 fn from(value: u16) -> Self {
78 Self::from(value as u128)
79 }
80}
81
82impl From<u8> for M128 {
83 fn from(value: u8) -> Self {
84 Self::from(value as u128)
85 }
86}
87
88impl<const N: usize> From<SmallU<N>> for M128 {
89 fn from(value: SmallU<N>) -> Self {
90 Self::from(value.val() as u128)
91 }
92}
93
94impl From<M128> for u128 {
95 fn from(value: M128) -> Self {
96 let mut result = 0u128;
97 unsafe {
98 assert_eq!(align_of::<u128>(), 16);
100 _mm_store_si128(&raw mut result as *mut __m128i, value.0)
101 };
102 result
103 }
104}
105
106impl From<M128> for __m128i {
107 #[inline(always)]
108 fn from(value: M128) -> Self {
109 value.0
110 }
111}
112
impl SerializeBytes for M128 {
    /// Serializes the value as 16 little-endian bytes.
    fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
        // Fail early if the buffer cannot hold all 16 bytes.
        assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;

        // Go through the scalar representation so the byte order is defined
        // (little-endian) independently of the register layout.
        let raw_value: u128 = (*self).into();

        write_buf.put_u128_le(raw_value);
        Ok(())
    }
}
123
impl DeserializeBytes for M128 {
    /// Deserializes 16 little-endian bytes, mirroring [`SerializeBytes`].
    fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
    where
        Self: Sized,
    {
        // Fail early if fewer than 16 bytes are available.
        assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;

        let raw_value = read_buf.get_u128_le();

        Ok(Self::from(raw_value))
    }
}
136
// Generate `Divisible` impls for 1-, 2-, and 4-bit sub-elements of `M128`
// using bit-masking (the byte-and-wider granularities are implemented
// manually below with intrinsics).
impl_divisible_bitmask!(M128, 1, 2, 4);

// Any type that can be numerically cast from `u128` can be cast from `M128`
// by first collapsing the register to its scalar representation.
impl<U: NumCast<u128>> NumCast<M128> for U {
    #[inline(always)]
    fn num_cast_from(val: M128) -> Self {
        Self::num_cast_from(u128::from(val))
    }
}
145
impl Default for M128 {
    /// Returns the all-zero register.
    #[inline(always)]
    fn default() -> Self {
        Self(unsafe { _mm_setzero_si128() })
    }
}
152
impl BitAnd for M128 {
    type Output = Self;

    /// Bitwise AND over all 128 bits.
    #[inline(always)]
    fn bitand(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_and_si128(self.0, rhs.0) })
    }
}

impl BitAndAssign for M128 {
    #[inline(always)]
    fn bitand_assign(&mut self, rhs: Self) {
        *self = *self & rhs
    }
}
168
impl BitOr for M128 {
    type Output = Self;

    /// Bitwise OR over all 128 bits.
    #[inline(always)]
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_or_si128(self.0, rhs.0) })
    }
}

impl BitOrAssign for M128 {
    #[inline(always)]
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs
    }
}
184
impl BitXor for M128 {
    type Output = Self;

    /// Bitwise XOR over all 128 bits.
    #[inline(always)]
    fn bitxor(self, rhs: Self) -> Self::Output {
        Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
    }
}

impl BitXorAssign for M128 {
    #[inline(always)]
    fn bitxor_assign(&mut self, rhs: Self) {
        *self = *self ^ rhs;
    }
}
200
impl Not for M128 {
    type Output = Self;

    /// Bitwise complement, implemented as XOR with the all-ones constant
    /// (SSE has no single NOT instruction).
    fn not(self) -> Self::Output {
        // `m128_from_u128!` is textually defined later in this file; it is in
        // scope here via the path-based `pub(super) use` re-export.
        const ONES: __m128i = m128_from_u128!(u128::MAX);

        self ^ Self(ONES)
    }
}
210
/// Const-evaluable maximum of two `i32` values.
///
/// Hand-rolled because `Ord::max` is not a `const fn`, and this helper is
/// needed in the const branches unrolled by `bitshift_128b!`.
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
    match left < right {
        true => right,
        false => left,
    }
}
215
/// Builds a full 128-bit bit shift out of SSE2's per-64-bit-lane shifts.
///
/// SSE2 has no single-instruction 128-bit bit shift — only bit shifts within
/// each 64-bit lane plus whole-byte shifts of the full register. The scheme:
/// - `$byte_shift($val, 8)` moves the opposite 64-bit lane into position so
///   the bits that cross the lane boundary can be recovered (`carry`).
/// - For shift amounts in `64..128`, the result is just the carry lane
///   shifted by `shift - 64` bits.
/// - For shift amounts in `0..64`, each lane is shifted by `N` bits and OR-ed
///   with the carry bits shifted the opposite way by `64 - N`.
///
/// `seq!` unrolls a branch per possible shift amount so that the lane-shift
/// counts are compile-time immediates. Any shift amount >= 128 falls through
/// to `Default::default()` (zero), matching scalar logical-shift semantics.
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            let carry = $byte_shift($val, 8);
            seq!(N in 64..128 {
                if $shift == N {
                    return $bit_shift_64(
                        carry,
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            seq!(N in 0..64 {
                if $shift == N {
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            return Default::default()
        }
    };
}

pub(crate) use bitshift_128b;
250
impl Shr<usize> for M128 {
    type Output = Self;

    /// Logical 128-bit right shift; shift amounts >= 128 yield zero.
    #[inline(always)]
    fn shr(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
    }
}
261
impl Shl<usize> for M128 {
    type Output = Self;

    /// Logical 128-bit left shift; shift amounts >= 128 yield zero.
    /// (The macro returns from the function on every path, so the trailing
    /// semicolon is inert.)
    #[inline(always)]
    fn shl(self, rhs: usize) -> Self::Output {
        bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
    }
}
272
impl PartialEq for M128 {
    fn eq(&self, other: &Self) -> bool {
        unsafe {
            // The XOR is zero iff the operands are bit-identical;
            // `_mm_test_all_zeros` (SSE4.1 `ptest`) checks all 128 bits in one
            // instruction.
            let neq = _mm_xor_si128(self.0, other.0);
            _mm_test_all_zeros(neq, neq) == 1
        }
    }
}

impl Eq for M128 {}
283
impl PartialOrd for M128 {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for M128 {
    /// Orders values by their 128-bit unsigned integer interpretation.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        u128::from(*self).cmp(&u128::from(*other))
    }
}
295
impl Distribution<M128> for StandardUniform {
    /// Samples 128 uniformly random bits.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
        M128(rng.random())
    }
}
301
impl std::fmt::Display for M128 {
    /// Formats the value as its `u128` interpretation in upper-case hex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let data: u128 = (*self).into();
        write!(f, "{data:02X?}")
    }
}

impl std::fmt::Debug for M128 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "M128({self})")
    }
}
314
/// 16-byte-aligned wrapper so a `u128` can be reinterpreted as `__m128i`
/// through an aligned pointer read in const context.
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

/// Builds a `__m128i` constant from a `u128` expression in const context,
/// where SIMD load intrinsics are unavailable. The `repr(align(16))` wrapper
/// guarantees the dereference is aligned.
macro_rules! m128_from_u128 {
    ($val:expr) => {{
        let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
        unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
    }};
}

pub(super) use m128_from_u128;
326
impl UnderlierType for M128 {
    // 2^7 = 128 bits.
    const LOG_BITS: usize = 7;
}
330
impl UnderlierWithBitOps for M128 {
    // Constants are built via the const-context macro because SIMD intrinsics
    // cannot be called in `const` items.
    const ZERO: Self = { Self(m128_from_u128!(0)) };
    const ONE: Self = { Self(m128_from_u128!(1)) };
    const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

    /// Interleaves blocks of `2^log_block_len` bits of `self` and `other`;
    /// delegates to the SSE implementation in `interleave_bits`.
    #[inline(always)]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        unsafe {
            let (c, d) = interleave_bits(
                Into::<Self>::into(self).into(),
                Into::<Self>::into(other).into(),
                log_block_len,
            );
            (Self::from(c), Self::from(d))
        }
    }

    /// Takes the block of `2^log_block_len` elements of type `T` at
    /// `block_idx` and spreads it across the register, each element repeated
    /// to fill an equal share of the 128 bits.
    ///
    /// The outer match is on the element width (`T::LOG_BITS`), the inner one
    /// on the block length. Sub-byte widths not handled explicitly go through
    /// `spread_fallback`; unsupported wider combinations panic.
    ///
    /// NOTE(review): assumes the trait contract guarantees
    /// `T::LOG_BITS + log_block_len <= Self::LOG_BITS` and `block_idx` in
    /// range — confirm against the `UnderlierWithBitOps::spread` docs.
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierWithBitOps,
        Self: Divisible<T>,
    {
        match T::LOG_BITS {
            // 1-bit elements: each selected bit is expanded to an all-ones or
            // all-zeros lane via `fill_with_bit`.
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => unsafe {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                },
                2 => unsafe {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                },
                3 => unsafe {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                },
                4 => unsafe {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                },
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 2-bit elements: each element is first spread into a full byte,
            // then the bytes are replicated across the register.
            1 => match log_block_len {
                0 => unsafe {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                },
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 4-bit elements: same byte-spreading scheme as the 2-bit case.
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    unsafe { _mm_set1_epi8(value as i8).into() }
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 8-bit elements: a single `pshufb` with a precomputed mask does
            // the whole spread (see the LOG_B8_* tables below).
            3 => match log_block_len {
                0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
                1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
                2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
                3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
                4 => self,
                _ => panic!("unsupported block length"),
            },
            // 16-bit elements: broadcast / replicate via `_mm_set1_epi16` or
            // the generic `from_fn` constructor.
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    unsafe { _mm_set1_epi16(value as i16).into() }
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            // 32-bit elements.
            5 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                },
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            // 64-bit elements.
            6 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                },
                1 => self,
                _ => panic!("unsupported block length"),
            },
            // 128-bit elements: the whole register is the single element.
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }
}
558
// SAFETY: `M128` is `repr(transparent)` over `__m128i`, a 16-byte plain-data
// register type with no padding or invalid bit patterns.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// SAFETY: `M128` is plain data with no interior mutability or thread affinity;
// the explicit impls are needed only because `__m128i` does not carry them.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
566
// Precomputed `_mm_shuffle_epi8` masks used by `spread` for byte-sized (2^3-bit)
// elements: one table per supported `log_block_len` in 0..=3, indexed by
// `block_idx`. The array length is the number of valid block indices,
// `16 >> log_block_len`.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
571
/// Precomputes `pshufb` shuffle masks that implement `spread` for byte-granular
/// element types.
///
/// For each `block_idx`, byte `j` of the mask holds the source-byte index that
/// should land at position `j`, so that each element of the selected block is
/// repeated `repeat` times in order. Written with `while` loops because `for`
/// is not allowed in `const fn`.
///
/// NOTE(review): assumes `t_log_bits >= 3` (element at least one byte wide) —
/// `element_log_width` would underflow otherwise; callers above only pass 3.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
    log_block_len: usize,
    t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
    // Element width in bytes (log and absolute).
    let element_log_width = t_log_bits - 3;

    let element_width = 1 << element_log_width;

    // Bytes covered by one block, and how many times each element repeats to
    // fill the 16-byte register.
    let block_size = 1 << (log_block_len + element_log_width);
    let repeat = 1 << (4 - element_log_width - log_block_len);
    let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    while block_idx < BLOCK_IDX_AMOUNT {
        let base = block_idx * block_size;
        let mut j = 0;
        while j < 16 {
            // Destination byte j selects element (j / element_width) / repeat
            // of the block, offset by the byte position inside the element.
            masks[block_idx][j] =
                (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
            j += 1;
        }
        block_idx += 1;
    }
    let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

    let mut block_idx = 0;

    // Pack the byte tables into M128 values (little-endian byte order matches
    // the in-register byte layout).
    while block_idx < BLOCK_IDX_AMOUNT {
        m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
        block_idx += 1;
    }

    m128_masks
}
607
// Convenience conversions so packed fields over `M128` can be built directly
// from raw registers and scalars (all routed through the `M128` impls above).
impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: __m128i) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
    fn from(value: u128) -> Self {
        M128::from(value).into()
    }
}

impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
    fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
        value.to_underlier().into()
    }
}
625
/// Interleaves blocks of `2^log_block_len` bits of `a` and `b`.
///
/// Sub-byte block sizes (0..=2) use the masked delta-swap in
/// `interleave_bits_imm`; byte and wider blocks (3..=5) first group the even
/// and odd blocks into the low/high halves with `pshufb`, then interleave with
/// the corresponding `unpacklo`/`unpackhi`; 64-bit blocks need only the
/// unpacks. Panics for `log_block_len > 6`.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
    match log_block_len {
        // Single bits: mask 0x55 selects every other bit.
        0 => unsafe {
            let mask = _mm_set1_epi8(0x55i8);
            interleave_bits_imm::<1>(a, b, mask)
        },
        // Bit pairs: mask 0x33.
        1 => unsafe {
            let mask = _mm_set1_epi8(0x33i8);
            interleave_bits_imm::<2>(a, b, mask)
        },
        // Nibbles: mask 0x0f.
        2 => unsafe {
            let mask = _mm_set1_epi8(0x0fi8);
            interleave_bits_imm::<4>(a, b, mask)
        },
        // Bytes: gather even bytes into the low half, odd into the high half,
        // then unpack to interleave.
        3 => unsafe {
            let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi8(a, b);
            let b_prime = _mm_unpackhi_epi8(a, b);
            (a_prime, b_prime)
        },
        // 16-bit words.
        4 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi16(a, b);
            let b_prime = _mm_unpackhi_epi16(a, b);
            (a_prime, b_prime)
        },
        // 32-bit dwords.
        5 => unsafe {
            let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
            let a = _mm_shuffle_epi8(a, shuffle);
            let b = _mm_shuffle_epi8(b, shuffle);
            let a_prime = _mm_unpacklo_epi32(a, b);
            let b_prime = _mm_unpackhi_epi32(a, b);
            (a_prime, b_prime)
        },
        // 64-bit qwords: the unpacks alone perform the interleave.
        6 => unsafe {
            let a_prime = _mm_unpacklo_epi64(a, b);
            let b_prime = _mm_unpackhi_epi64(a, b);
            (a_prime, b_prime)
        },
        _ => panic!("unsupported block length"),
    }
}
673
/// Interleaves `BLOCK_LEN`-bit blocks (sub-64-bit) of `a` and `b` using the
/// classic masked delta-swap:
///   t  = ((a >> k) ^ b) & mask
///   a' = a ^ (t << k)
///   b' = b ^ t
/// where `mask` selects the odd-position blocks. Works entirely within 64-bit
/// lanes, which is sufficient because `BLOCK_LEN <= 4` never crosses a lane
/// boundary.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    unsafe {
        let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
        let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
        let b_prime = _mm_xor_si128(b, t);
        (a_prime, b_prime)
    }
}
687
// Trivial division: an M128 is exactly one u128, so every accessor is the
// scalar round-trip conversion and the only valid index is 0.
impl Divisible<u128> for M128 {
    const LOG_N: usize = 0;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
        std::iter::once(u128::from(value))
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
        std::iter::once(u128::from(*value))
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
        slice.iter().map(|&v| u128::from(v))
    }

    #[inline]
    fn get(self, index: usize) -> u128 {
        assert!(index == 0, "index out of bounds");
        u128::from(self)
    }

    #[inline]
    fn set(self, index: usize, val: u128) -> Self {
        assert!(index == 0, "index out of bounds");
        Self::from(val)
    }

    #[inline]
    fn broadcast(val: u128) -> Self {
        Self::from(val)
    }

    #[inline]
    fn from_iter(mut iter: impl Iterator<Item = u128>) -> Self {
        // Missing element yields zero, matching the zero-fill convention of
        // the other Divisible impls below.
        iter.next().map(Self::from).unwrap_or(Self::ZERO)
    }
}
730
// View the register as 2 x u64 lanes. `get`/`set` dispatch on a literal index
// because the SSE extract/insert intrinsics require a const lane immediate.
impl Divisible<u64> for M128 {
    const LOG_N: usize = 1;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u64 {
        unsafe {
            match index {
                0 => _mm_extract_epi64(self.0, 0) as u64,
                1 => _mm_extract_epi64(self.0, 1) as u64,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u64) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
                1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u64) -> Self {
        unsafe { Self(_mm_set1_epi64x(val as i64)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u64>) -> Self {
        // Fill lanes in order from the iterator; missing lanes stay zero.
        let mut result = Self::ZERO;
        let arr: &mut [u64; 2] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(2).enumerate() {
            arr[i] = val;
        }
        result
    }
}
786
// View the register as 4 x u32 lanes; same structure as the u64 impl above.
impl Divisible<u32> for M128 {
    const LOG_N: usize = 2;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u32 {
        unsafe {
            match index {
                0 => _mm_extract_epi32(self.0, 0) as u32,
                1 => _mm_extract_epi32(self.0, 1) as u32,
                2 => _mm_extract_epi32(self.0, 2) as u32,
                3 => _mm_extract_epi32(self.0, 3) as u32,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u32) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u32) -> Self {
        unsafe { Self(_mm_set1_epi32(val as i32)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u32>) -> Self {
        // Fill lanes in order from the iterator; missing lanes stay zero.
        let mut result = Self::ZERO;
        let arr: &mut [u32; 4] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(4).enumerate() {
            arr[i] = val;
        }
        result
    }
}
846
// View the register as 8 x u16 lanes; same structure as the u64 impl above.
impl Divisible<u16> for M128 {
    const LOG_N: usize = 3;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u16 {
        unsafe {
            match index {
                0 => _mm_extract_epi16(self.0, 0) as u16,
                1 => _mm_extract_epi16(self.0, 1) as u16,
                2 => _mm_extract_epi16(self.0, 2) as u16,
                3 => _mm_extract_epi16(self.0, 3) as u16,
                4 => _mm_extract_epi16(self.0, 4) as u16,
                5 => _mm_extract_epi16(self.0, 5) as u16,
                6 => _mm_extract_epi16(self.0, 6) as u16,
                7 => _mm_extract_epi16(self.0, 7) as u16,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u16) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
                4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
                5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
                6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
                7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u16) -> Self {
        unsafe { Self(_mm_set1_epi16(val as i16)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u16>) -> Self {
        // Fill lanes in order from the iterator; missing lanes stay zero.
        let mut result = Self::ZERO;
        let arr: &mut [u16; 8] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(8).enumerate() {
            arr[i] = val;
        }
        result
    }
}
914
// View the register as 16 x u8 lanes; same structure as the u64 impl above.
impl Divisible<u8> for M128 {
    const LOG_N: usize = 4;

    #[inline]
    fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
        mapget::value_iter(value)
    }

    #[inline]
    fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
        mapget::value_iter(*value)
    }

    #[inline]
    fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
        mapget::slice_iter(slice)
    }

    #[inline]
    fn get(self, index: usize) -> u8 {
        unsafe {
            match index {
                0 => _mm_extract_epi8(self.0, 0) as u8,
                1 => _mm_extract_epi8(self.0, 1) as u8,
                2 => _mm_extract_epi8(self.0, 2) as u8,
                3 => _mm_extract_epi8(self.0, 3) as u8,
                4 => _mm_extract_epi8(self.0, 4) as u8,
                5 => _mm_extract_epi8(self.0, 5) as u8,
                6 => _mm_extract_epi8(self.0, 6) as u8,
                7 => _mm_extract_epi8(self.0, 7) as u8,
                8 => _mm_extract_epi8(self.0, 8) as u8,
                9 => _mm_extract_epi8(self.0, 9) as u8,
                10 => _mm_extract_epi8(self.0, 10) as u8,
                11 => _mm_extract_epi8(self.0, 11) as u8,
                12 => _mm_extract_epi8(self.0, 12) as u8,
                13 => _mm_extract_epi8(self.0, 13) as u8,
                14 => _mm_extract_epi8(self.0, 14) as u8,
                15 => _mm_extract_epi8(self.0, 15) as u8,
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn set(self, index: usize, val: u8) -> Self {
        unsafe {
            match index {
                0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
                1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
                2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
                3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
                4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
                5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
                6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
                7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
                8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
                9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
                10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
                11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
                12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
                13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
                14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
                15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
                _ => panic!("index out of bounds"),
            }
        }
    }

    #[inline]
    fn broadcast(val: u8) -> Self {
        unsafe { Self(_mm_set1_epi8(val as i8)) }
    }

    #[inline]
    fn from_iter(iter: impl Iterator<Item = u8>) -> Self {
        // Fill lanes in order from the iterator; missing lanes stay zero.
        let mut result = Self::ZERO;
        let arr: &mut [u8; 16] = bytemuck::cast_mut(&mut result);
        for (i, val) in iter.take(16).enumerate() {
            arr[i] = val;
        }
        result
    }
}
998
#[cfg(test)]
mod tests {
    use binius_utils::bytes::BytesMut;
    use proptest::{arbitrary::any, proptest};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Asserts that converting to `T` and back is lossless.
    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    /// Extracts the `index`-th block of `2^log_block_len` bits of `value`.
    ///
    /// Fix: the mask previously was `1u128 << log_block_len`, a single bit,
    /// so interleave results were only spot-checked on one bit per block.
    /// Now the full block is masked (valid for `log_block_len <= 6`, i.e.
    /// blocks up to 64 bits, which covers the tested range).
    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        let block_bits = 1usize << log_block_len;
        (value >> (index << log_block_len)) & M128::from((1u128 << block_bits) - 1)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        // Checks interleave against the scalar definition: for every pair of
        // adjacent blocks, the outputs carry the inputs' blocks in
        // alternating order.
        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }

    #[test]
    fn test_serialize_and_deserialize_m128() {
        let mut rng = StdRng::from_seed([0; 32]);

        let original_value = M128::from(rng.random::<u128>());

        let mut buf = BytesMut::new();
        original_value.serialize(&mut buf).unwrap();

        let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

        assert_eq!(original_value, deserialized_value);
    }
}