1use std::{
5 arch::x86_64::*,
6 array, mem,
7 ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
8};
9
10use binius_utils::{
11 DeserializeBytes, SerializationError, SerializeBytes,
12 bytes::{Buf, BufMut},
13 serialization::{assert_enough_data_for, assert_enough_space_for},
14};
15use bytemuck::{Pod, Zeroable};
16use rand::{
17 distr::{Distribution, StandardUniform},
18 prelude::*,
19};
20use seq_macro::seq;
21
22use crate::{
23 BinaryField,
24 arch::portable::packed::PackedPrimitiveType,
25 underlier::{
26 Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, impl_divisible_bitmask,
27 impl_divisible_self, mapget, spread_fallback,
28 },
29};
30
/// Reinterprets the bits of a `u128` as an `__m128i` in a `const` context.
///
/// Compilation fails if `u128` and `__m128i` do not share the same alignment.
pub const fn m128i_from_u128(x: u128) -> __m128i {
    // Compile-time guard: both array lengths must match for this item to type-check.
    const _ALIGNMENT_MATCHES: [(); align_of::<u128>()] = [(); align_of::<__m128i>()];

    // SAFETY: both types are plain 128-bit values; every bit pattern is valid for `__m128i`.
    unsafe { mem::transmute::<u128, __m128i>(x) }
}
36
/// 128-bit value stored in an SSE `__m128i` register type.
///
/// The wrapper is `#[repr(transparent)]`, so it has the same layout as the wrapped
/// `__m128i`. The inner field is only visible to the parent module.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
41
42impl M128 {
43 #[inline(always)]
44 pub const fn from_u128(val: u128) -> Self {
45 Self(m128i_from_u128(val))
46 }
47}
48
49impl From<__m128i> for M128 {
50 #[inline(always)]
51 fn from(value: __m128i) -> Self {
52 Self(value)
53 }
54}
55
56impl From<u128> for M128 {
57 fn from(value: u128) -> Self {
58 Self(m128i_from_u128(value))
59 }
60}
61
62impl From<u64> for M128 {
63 fn from(value: u64) -> Self {
64 Self::from(value as u128)
65 }
66}
67
68impl From<u32> for M128 {
69 fn from(value: u32) -> Self {
70 Self::from(value as u128)
71 }
72}
73
74impl From<u16> for M128 {
75 fn from(value: u16) -> Self {
76 Self::from(value as u128)
77 }
78}
79
80impl From<u8> for M128 {
81 fn from(value: u8) -> Self {
82 Self::from(value as u128)
83 }
84}
85
86impl<const N: usize> From<SmallU<N>> for M128 {
87 fn from(value: SmallU<N>) -> Self {
88 Self::from(value.val() as u128)
89 }
90}
91
92impl From<M128> for u128 {
93 fn from(value: M128) -> Self {
94 let mut result = 0u128;
95 unsafe {
96 assert_eq!(align_of::<u128>(), 16);
98 _mm_store_si128(&raw mut result as *mut __m128i, value.0)
99 };
100 result
101 }
102}
103
104impl From<M128> for __m128i {
105 #[inline(always)]
106 fn from(value: M128) -> Self {
107 value.0
108 }
109}
110
111impl SerializeBytes for M128 {
112 fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
113 assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
114
115 let raw_value: u128 = (*self).into();
116
117 write_buf.put_u128_le(raw_value);
118 Ok(())
119 }
120}
121
122impl DeserializeBytes for M128 {
123 fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
124 where
125 Self: Sized,
126 {
127 assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
128
129 let raw_value = read_buf.get_u128_le();
130
131 Ok(Self::from(raw_value))
132 }
133}
134
// NOTE(review): presumably generates the `Divisible` impls of `M128` for the sub-byte element
// widths 1, 2 and 4 bits — confirm against the definition of `impl_divisible_bitmask!`.
impl_divisible_bitmask!(M128, 1, 2, 4);
136
137impl<U: NumCast<u128>> NumCast<M128> for U {
138 #[inline(always)]
139 fn num_cast_from(val: M128) -> Self {
140 Self::num_cast_from(u128::from(val))
141 }
142}
143
144impl Default for M128 {
145 #[inline(always)]
146 fn default() -> Self {
147 Self(unsafe { _mm_setzero_si128() })
148 }
149}
150
151impl BitAnd for M128 {
152 type Output = Self;
153
154 #[inline(always)]
155 fn bitand(self, rhs: Self) -> Self::Output {
156 Self(unsafe { _mm_and_si128(self.0, rhs.0) })
157 }
158}
159
160impl BitAndAssign for M128 {
161 #[inline(always)]
162 fn bitand_assign(&mut self, rhs: Self) {
163 *self = *self & rhs
164 }
165}
166
167impl BitOr for M128 {
168 type Output = Self;
169
170 #[inline(always)]
171 fn bitor(self, rhs: Self) -> Self::Output {
172 Self(unsafe { _mm_or_si128(self.0, rhs.0) })
173 }
174}
175
176impl BitOrAssign for M128 {
177 #[inline(always)]
178 fn bitor_assign(&mut self, rhs: Self) {
179 *self = *self | rhs
180 }
181}
182
183impl BitXor for M128 {
184 type Output = Self;
185
186 #[inline(always)]
187 fn bitxor(self, rhs: Self) -> Self::Output {
188 Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
189 }
190}
191
192impl BitXorAssign for M128 {
193 #[inline(always)]
194 fn bitxor_assign(&mut self, rhs: Self) {
195 *self = *self ^ rhs;
196 }
197}
198
199impl Not for M128 {
200 type Output = Self;
201
202 fn not(self) -> Self::Output {
203 const ONES: M128 = M128::from_u128(u128::MAX);
204
205 self ^ ONES
206 }
207}
208
/// `const`-compatible maximum of two `i32` values.
const fn max_i32(left: i32, right: i32) -> i32 {
    if left > right {
        return left;
    }
    right
}
213
/// Expands to the body of a full 128-bit shift assembled from byte-wise and 64-bit-lane
/// shift intrinsics.
///
/// Arguments:
/// * `$val` - the value to shift (the call sites pass the raw `__m128i`).
/// * `$shift` - identifier holding the shift amount.
/// * `$byte_shift` - whole-register byte shift in the shift direction.
/// * `$bit_shift_64` - per-64-bit-lane bit shift in the shift direction.
/// * `$bit_shift_64_opposite` - per-64-bit-lane bit shift in the opposite direction.
/// * `$or` - bitwise OR intrinsic.
///
/// The expansion contains `return` statements, so it must be used as the body of the
/// function producing the result. Shift amounts outside `0..128` return `Default::default()`.
macro_rules! bitshift_128b {
    ($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
        unsafe {
            // `$val` moved by 8 bytes: the bits that cross the 64-bit lane boundary.
            let carry = $byte_shift($val, 8);
            // Shifts of 64..128: only the carried half contributes, shifted by `N - 64`.
            seq!(N in 64..128 {
                if $shift == N {
                    return $bit_shift_64(
                        carry,
                        // NOTE(review): `max_i32(.., 0)` presumably keeps the immediate
                        // expression non-negative — confirm.
                        crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
                    ).into();
                }
            });
            // Shifts of 0..64: shift each lane by `N` and OR in the bits that cross lanes,
            // obtained by shifting the carry `64 - N` bits in the opposite direction.
            seq!(N in 0..64 {
                if $shift == N {
                    let carry = $bit_shift_64_opposite(
                        carry,
                        crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
                    );

                    let val = $bit_shift_64($val, N);
                    return $or(val, carry).into();
                }
            });

            // No branch matched: the shift amount is 128 or more.
            return Default::default()
        }
    };
}
246
247impl Shr<usize> for M128 {
248 type Output = Self;
249
250 #[inline(always)]
251 fn shr(self, rhs: usize) -> Self::Output {
252 bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
255 }
256}
257
258impl Shl<usize> for M128 {
259 type Output = Self;
260
261 #[inline(always)]
262 fn shl(self, rhs: usize) -> Self::Output {
263 bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
266 }
267}
268
269impl PartialEq for M128 {
270 fn eq(&self, other: &Self) -> bool {
271 unsafe {
272 let neq = _mm_xor_si128(self.0, other.0);
273 _mm_test_all_zeros(neq, neq) == 1
274 }
275 }
276}
277
// Marker impl: the `PartialEq` impl for `M128` compares the bits of both values.
impl Eq for M128 {}
279
280impl PartialOrd for M128 {
281 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
282 Some(self.cmp(other))
283 }
284}
285
286impl Ord for M128 {
287 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
288 u128::from(*self).cmp(&u128::from(*other))
289 }
290}
291
292impl std::hash::Hash for M128 {
293 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
294 u128::from(*self).hash(state);
295 }
296}
297
298impl std::fmt::LowerHex for M128 {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 std::fmt::LowerHex::fmt(&u128::from(*self), f)
301 }
302}
303
304impl Distribution<M128> for StandardUniform {
305 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
306 M128(rng.random())
307 }
308}
309
310impl std::fmt::Display for M128 {
311 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312 let data: u128 = (*self).into();
313 write!(f, "{data:02X?}")
314 }
315}
316
317impl std::fmt::Debug for M128 {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 write!(f, "M128({self})")
320 }
321}
322
impl UnderlierType for M128 {
    // 2^7 = 128 bits.
    const LOG_BITS: usize = 7;
    const ZERO: Self = { Self::from_u128(0) };
    const ONE: Self = { Self::from_u128(1) };
    const ONES: Self = { Self::from_u128(u128::MAX) };

    /// Interleaves blocks of `2^log_block_len` bits of `self` and `other` by delegating to
    /// `interleave_bits`, which panics for `log_block_len > 6`.
    #[inline(always)]
    fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
        unsafe {
            let (c, d) = interleave_bits(
                Into::<Self>::into(self).into(),
                Into::<Self>::into(other).into(),
                log_block_len,
            );
            (Self::from(c), Self::from(d))
        }
    }

    /// Takes the block of `2^log_block_len` elements of type `T` at block index `block_idx`
    /// and repeats each of its elements so that the block fills all 128 bits.
    ///
    /// The outer `match` selects the element width (`T::LOG_BITS`), the inner one the block
    /// length. Sub-byte combinations not listed fall back to `spread_fallback`; for
    /// byte-or-wider elements an unlisted block length panics.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably `block_idx` must address an existing block for the given
    /// `T` and `log_block_len` — confirm against the `UnderlierType::spread` contract.
    #[inline(always)]
    unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
    where
        T: UnderlierType,
        Self: Divisible<T>,
    {
        match T::LOG_BITS {
            // 1-bit elements: every selected bit is expanded to an all-ones/all-zeros lane.
            0 => match log_block_len {
                0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
                1 => unsafe {
                    let bits: [u8; 2] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

                    // `_mm_set_*` takes the most significant lane first, hence reverse order.
                    _mm_set_epi64x(
                        u64::fill_with_bit(bits[1]) as i64,
                        u64::fill_with_bit(bits[0]) as i64,
                    )
                    .into()
                },
                2 => unsafe {
                    let bits: [u8; 4] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

                    _mm_set_epi32(
                        u32::fill_with_bit(bits[3]) as i32,
                        u32::fill_with_bit(bits[2]) as i32,
                        u32::fill_with_bit(bits[1]) as i32,
                        u32::fill_with_bit(bits[0]) as i32,
                    )
                    .into()
                },
                3 => unsafe {
                    let bits: [u8; 8] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

                    _mm_set_epi16(
                        u16::fill_with_bit(bits[7]) as i16,
                        u16::fill_with_bit(bits[6]) as i16,
                        u16::fill_with_bit(bits[5]) as i16,
                        u16::fill_with_bit(bits[4]) as i16,
                        u16::fill_with_bit(bits[3]) as i16,
                        u16::fill_with_bit(bits[2]) as i16,
                        u16::fill_with_bit(bits[1]) as i16,
                        u16::fill_with_bit(bits[0]) as i16,
                    )
                    .into()
                },
                4 => unsafe {
                    let bits: [u8; 16] =
                        array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

                    _mm_set_epi8(
                        u8::fill_with_bit(bits[15]) as i8,
                        u8::fill_with_bit(bits[14]) as i8,
                        u8::fill_with_bit(bits[13]) as i8,
                        u8::fill_with_bit(bits[12]) as i8,
                        u8::fill_with_bit(bits[11]) as i8,
                        u8::fill_with_bit(bits[10]) as i8,
                        u8::fill_with_bit(bits[9]) as i8,
                        u8::fill_with_bit(bits[8]) as i8,
                        u8::fill_with_bit(bits[7]) as i8,
                        u8::fill_with_bit(bits[6]) as i8,
                        u8::fill_with_bit(bits[5]) as i8,
                        u8::fill_with_bit(bits[4]) as i8,
                        u8::fill_with_bit(bits[3]) as i8,
                        u8::fill_with_bit(bits[2]) as i8,
                        u8::fill_with_bit(bits[1]) as i8,
                        u8::fill_with_bit(bits[0]) as i8,
                    )
                    .into()
                },
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 2-bit elements: each element is first spread to a byte via `spread_to_byte`,
            // then that byte is repeated over the lanes belonging to the element.
            1 => match log_block_len {
                0 => unsafe {
                    let value =
                        U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

                    _mm_set1_epi8(value as i8).into()
                },
                1 => {
                    let bytes: [u8; 2] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 8])
                }
                2 => {
                    let bytes: [u8; 4] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 4])
                }
                3 => {
                    let bytes: [u8; 8] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i / 2])
                }
                4 => {
                    let bytes: [u8; 16] = array::from_fn(|i| {
                        U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| bytes[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 4-bit elements: same scheme as for 2-bit elements, using `U4`.
            2 => match log_block_len {
                0 => {
                    let value =
                        U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

                    unsafe { _mm_set1_epi8(value as i8).into() }
                }
                1 => {
                    let values: [u8; 2] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 8])
                }
                2 => {
                    let values: [u8; 4] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 4])
                }
                3 => {
                    let values: [u8; 8] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i / 2])
                }
                4 => {
                    let values: [u8; 16] = array::from_fn(|i| {
                        U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
                            .spread_to_byte()
                    });

                    Self::from_fn::<u8>(|i| values[i])
                }
                _ => unsafe { spread_fallback(self, log_block_len, block_idx) },
            },
            // 8-bit elements: a single byte shuffle with a precomputed control mask.
            // Indexing the mask tables is bounds-checked and panics for an out-of-range
            // `block_idx`.
            3 => match log_block_len {
                0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
                1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
                2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
                3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
                // The block already covers the whole value.
                4 => self,
                _ => panic!("unsupported block length"),
            },
            // 16-bit elements.
            4 => match log_block_len {
                0 => {
                    let value = (u128::from(self) >> (block_idx * 16)) as u16;

                    unsafe { _mm_set1_epi16(value as i16).into() }
                }
                1 => {
                    let values: [u16; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 4])
                }
                2 => {
                    let values: [u16; 4] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

                    Self::from_fn::<u16>(|i| values[i / 2])
                }
                3 => self,
                _ => panic!("unsupported block length"),
            },
            // 32-bit elements.
            5 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 32)) as u32;

                    _mm_set1_epi32(value as i32).into()
                },
                1 => {
                    let values: [u32; 2] =
                        array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

                    Self::from_fn::<u32>(|i| values[i / 2])
                }
                2 => self,
                _ => panic!("unsupported block length"),
            },
            // 64-bit elements.
            6 => match log_block_len {
                0 => unsafe {
                    let value = (u128::from(self) >> (block_idx * 64)) as u64;

                    _mm_set1_epi64x(value as i64).into()
                },
                1 => self,
                _ => panic!("unsupported block length"),
            },
            // A single 128-bit element is the value itself.
            7 => self,
            _ => panic!("unsupported bit length"),
        }
    }
}
551
// NOTE(review): `M128` is a `#[repr(transparent)]` wrapper around `__m128i`, so the impls
// below rely on `__m128i` being valid for every bit pattern (including all zeros) and
// having no thread affinity — confirm against bytemuck's requirements for `Zeroable`/`Pod`.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
559
// Byte-shuffle control masks passed to `_mm_shuffle_epi8` by `spread` for 8-bit elements
// (`t_log_bits == 3`). `LOG_B8_k` corresponds to `log_block_len == k` and holds one mask
// per block index.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
564
565const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
566 log_block_len: usize,
567 t_log_bits: usize,
568) -> [M128; BLOCK_IDX_AMOUNT] {
569 let element_log_width = t_log_bits - 3;
570
571 let element_width = 1 << element_log_width;
572
573 let block_size = 1 << (log_block_len + element_log_width);
574 let repeat = 1 << (4 - element_log_width - log_block_len);
575 let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];
576
577 let mut block_idx = 0;
578
579 while block_idx < BLOCK_IDX_AMOUNT {
580 let base = block_idx * block_size;
581 let mut j = 0;
582 while j < 16 {
583 masks[block_idx][j] =
584 (base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
585 j += 1;
586 }
587 block_idx += 1;
588 }
589 let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];
590
591 let mut block_idx = 0;
592
593 while block_idx < BLOCK_IDX_AMOUNT {
594 m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
595 block_idx += 1;
596 }
597
598 m128_masks
599}
600
601impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
602 fn from(value: __m128i) -> Self {
603 M128::from(value).into()
604 }
605}
606
607impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
608 fn from(value: u128) -> Self {
609 M128::from(value).into()
610 }
611}
612
613impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
614 fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
615 value.to_underlier().into()
616 }
617}
618
619#[inline]
620unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
621 match log_block_len {
622 0 => unsafe {
623 let mask = _mm_set1_epi8(0x55i8);
624 interleave_bits_imm::<1>(a, b, mask)
625 },
626 1 => unsafe {
627 let mask = _mm_set1_epi8(0x33i8);
628 interleave_bits_imm::<2>(a, b, mask)
629 },
630 2 => unsafe {
631 let mask = _mm_set1_epi8(0x0fi8);
632 interleave_bits_imm::<4>(a, b, mask)
633 },
634 3 => unsafe {
635 let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
636 let a = _mm_shuffle_epi8(a, shuffle);
637 let b = _mm_shuffle_epi8(b, shuffle);
638 let a_prime = _mm_unpacklo_epi8(a, b);
639 let b_prime = _mm_unpackhi_epi8(a, b);
640 (a_prime, b_prime)
641 },
642 4 => unsafe {
643 let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
644 let a = _mm_shuffle_epi8(a, shuffle);
645 let b = _mm_shuffle_epi8(b, shuffle);
646 let a_prime = _mm_unpacklo_epi16(a, b);
647 let b_prime = _mm_unpackhi_epi16(a, b);
648 (a_prime, b_prime)
649 },
650 5 => unsafe {
651 let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
652 let a = _mm_shuffle_epi8(a, shuffle);
653 let b = _mm_shuffle_epi8(b, shuffle);
654 let a_prime = _mm_unpacklo_epi32(a, b);
655 let b_prime = _mm_unpackhi_epi32(a, b);
656 (a_prime, b_prime)
657 },
658 6 => unsafe {
659 let a_prime = _mm_unpacklo_epi64(a, b);
660 let b_prime = _mm_unpackhi_epi64(a, b);
661 (a_prime, b_prime)
662 },
663 _ => panic!("unsupported block length"),
664 }
665}
666
/// Swaps the masked `BLOCK_LEN`-bit blocks between `a` and `b` using the XOR-swap trick.
///
/// `mask` selects the low block of every block pair. After the call, the high block of
/// each pair in the first result comes from the low block of `b`, and the low block of
/// each pair in the second result comes from the high block of `a`.
///
/// # Safety
///
/// Only SSE2 intrinsics are used; the running CPU must support them.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
    a: __m128i,
    b: __m128i,
    mask: __m128i,
) -> (__m128i, __m128i) {
    unsafe {
        // Difference between the high blocks of `a` and the low blocks of `b`.
        let a_high_blocks = _mm_srli_epi64::<BLOCK_LEN>(a);
        let delta = _mm_and_si128(_mm_xor_si128(a_high_blocks, b), mask);

        // XOR-ing the difference back in exchanges those blocks.
        let delta_high = _mm_slli_epi64::<BLOCK_LEN>(delta);
        (_mm_xor_si128(a, delta_high), _mm_xor_si128(b, delta))
    }
}
680
// NOTE(review): presumably generates the trivial `Divisible<M128> for M128` impl — confirm
// against the definition of `impl_divisible_self!`.
impl_divisible_self!(M128);
685
686impl Divisible<u128> for M128 {
687 const LOG_N: usize = 0;
688
689 #[inline]
690 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
691 std::iter::once(u128::from(value))
692 }
693
694 #[inline]
695 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
696 std::iter::once(u128::from(*value))
697 }
698
699 #[inline]
700 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
701 slice.iter().map(|&v| u128::from(v))
702 }
703
704 #[inline]
705 unsafe fn get_unchecked(&self, _index: usize) -> u128 {
706 u128::from(*self)
707 }
708
709 #[inline]
710 unsafe fn set_unchecked(&mut self, _index: usize, val: u128) {
711 *self = Self::from(val);
712 }
713
714 #[inline]
715 fn broadcast(val: u128) -> Self {
716 Self::from(val)
717 }
718
719 #[inline]
720 fn from_iter(mut iter: impl Iterator<Item = u128>) -> Self {
721 iter.next().map(Self::from).unwrap_or(Self::ZERO)
722 }
723}
724
725impl Divisible<u64> for M128 {
726 const LOG_N: usize = 1;
727
728 #[inline]
729 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
730 mapget::value_iter(value)
731 }
732
733 #[inline]
734 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
735 mapget::value_iter(*value)
736 }
737
738 #[inline]
739 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
740 mapget::slice_iter(slice)
741 }
742
743 #[inline]
744 unsafe fn get_unchecked(&self, index: usize) -> u64 {
745 unsafe {
746 match index {
747 0 => _mm_extract_epi64(self.0, 0) as u64,
748 1 => _mm_extract_epi64(self.0, 1) as u64,
749 _ => core::hint::unreachable_unchecked(),
750 }
751 }
752 }
753
754 #[inline]
755 unsafe fn set_unchecked(&mut self, index: usize, val: u64) {
756 *self = unsafe {
757 match index {
758 0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
759 1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
760 _ => core::hint::unreachable_unchecked(),
761 }
762 };
763 }
764
765 #[inline]
766 fn broadcast(val: u64) -> Self {
767 unsafe { Self(_mm_set1_epi64x(val as i64)) }
768 }
769
770 #[inline]
771 fn from_iter(iter: impl Iterator<Item = u64>) -> Self {
772 let mut result = Self::ZERO;
773 let arr: &mut [u64; 2] = bytemuck::cast_mut(&mut result);
774 for (i, val) in iter.take(2).enumerate() {
775 arr[i] = val;
776 }
777 result
778 }
779}
780
781impl Divisible<u32> for M128 {
782 const LOG_N: usize = 2;
783
784 #[inline]
785 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
786 mapget::value_iter(value)
787 }
788
789 #[inline]
790 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
791 mapget::value_iter(*value)
792 }
793
794 #[inline]
795 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
796 mapget::slice_iter(slice)
797 }
798
799 #[inline]
800 unsafe fn get_unchecked(&self, index: usize) -> u32 {
801 unsafe {
802 match index {
803 0 => _mm_extract_epi32(self.0, 0) as u32,
804 1 => _mm_extract_epi32(self.0, 1) as u32,
805 2 => _mm_extract_epi32(self.0, 2) as u32,
806 3 => _mm_extract_epi32(self.0, 3) as u32,
807 _ => core::hint::unreachable_unchecked(),
808 }
809 }
810 }
811
812 #[inline]
813 unsafe fn set_unchecked(&mut self, index: usize, val: u32) {
814 *self = unsafe {
815 match index {
816 0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
817 1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
818 2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
819 3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
820 _ => core::hint::unreachable_unchecked(),
821 }
822 };
823 }
824
825 #[inline]
826 fn broadcast(val: u32) -> Self {
827 unsafe { Self(_mm_set1_epi32(val as i32)) }
828 }
829
830 #[inline]
831 fn from_iter(iter: impl Iterator<Item = u32>) -> Self {
832 let mut result = Self::ZERO;
833 let arr: &mut [u32; 4] = bytemuck::cast_mut(&mut result);
834 for (i, val) in iter.take(4).enumerate() {
835 arr[i] = val;
836 }
837 result
838 }
839}
840
841impl Divisible<u16> for M128 {
842 const LOG_N: usize = 3;
843
844 #[inline]
845 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
846 mapget::value_iter(value)
847 }
848
849 #[inline]
850 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
851 mapget::value_iter(*value)
852 }
853
854 #[inline]
855 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
856 mapget::slice_iter(slice)
857 }
858
859 #[inline]
860 unsafe fn get_unchecked(&self, index: usize) -> u16 {
861 unsafe {
862 match index {
863 0 => _mm_extract_epi16(self.0, 0) as u16,
864 1 => _mm_extract_epi16(self.0, 1) as u16,
865 2 => _mm_extract_epi16(self.0, 2) as u16,
866 3 => _mm_extract_epi16(self.0, 3) as u16,
867 4 => _mm_extract_epi16(self.0, 4) as u16,
868 5 => _mm_extract_epi16(self.0, 5) as u16,
869 6 => _mm_extract_epi16(self.0, 6) as u16,
870 7 => _mm_extract_epi16(self.0, 7) as u16,
871 _ => core::hint::unreachable_unchecked(),
872 }
873 }
874 }
875
876 #[inline]
877 unsafe fn set_unchecked(&mut self, index: usize, val: u16) {
878 *self = unsafe {
879 match index {
880 0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
881 1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
882 2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
883 3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
884 4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
885 5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
886 6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
887 7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
888 _ => core::hint::unreachable_unchecked(),
889 }
890 };
891 }
892
893 #[inline]
894 fn broadcast(val: u16) -> Self {
895 unsafe { Self(_mm_set1_epi16(val as i16)) }
896 }
897
898 #[inline]
899 fn from_iter(iter: impl Iterator<Item = u16>) -> Self {
900 let mut result = Self::ZERO;
901 let arr: &mut [u16; 8] = bytemuck::cast_mut(&mut result);
902 for (i, val) in iter.take(8).enumerate() {
903 arr[i] = val;
904 }
905 result
906 }
907}
908
909impl Divisible<u8> for M128 {
910 const LOG_N: usize = 4;
911
912 #[inline]
913 fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
914 mapget::value_iter(value)
915 }
916
917 #[inline]
918 fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
919 mapget::value_iter(*value)
920 }
921
922 #[inline]
923 fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
924 mapget::slice_iter(slice)
925 }
926
927 #[inline]
928 unsafe fn get_unchecked(&self, index: usize) -> u8 {
929 unsafe {
930 match index {
931 0 => _mm_extract_epi8(self.0, 0) as u8,
932 1 => _mm_extract_epi8(self.0, 1) as u8,
933 2 => _mm_extract_epi8(self.0, 2) as u8,
934 3 => _mm_extract_epi8(self.0, 3) as u8,
935 4 => _mm_extract_epi8(self.0, 4) as u8,
936 5 => _mm_extract_epi8(self.0, 5) as u8,
937 6 => _mm_extract_epi8(self.0, 6) as u8,
938 7 => _mm_extract_epi8(self.0, 7) as u8,
939 8 => _mm_extract_epi8(self.0, 8) as u8,
940 9 => _mm_extract_epi8(self.0, 9) as u8,
941 10 => _mm_extract_epi8(self.0, 10) as u8,
942 11 => _mm_extract_epi8(self.0, 11) as u8,
943 12 => _mm_extract_epi8(self.0, 12) as u8,
944 13 => _mm_extract_epi8(self.0, 13) as u8,
945 14 => _mm_extract_epi8(self.0, 14) as u8,
946 15 => _mm_extract_epi8(self.0, 15) as u8,
947 _ => core::hint::unreachable_unchecked(),
948 }
949 }
950 }
951
952 #[inline]
953 unsafe fn set_unchecked(&mut self, index: usize, val: u8) {
954 *self = unsafe {
955 match index {
956 0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
957 1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
958 2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
959 3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
960 4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
961 5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
962 6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
963 7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
964 8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
965 9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
966 10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
967 11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
968 12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
969 13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
970 14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
971 15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
972 _ => core::hint::unreachable_unchecked(),
973 }
974 };
975 }
976
977 #[inline]
978 fn broadcast(val: u8) -> Self {
979 unsafe { Self(_mm_set1_epi8(val as i8)) }
980 }
981
982 #[inline]
983 fn from_iter(iter: impl Iterator<Item = u8>) -> Self {
984 let mut result = Self::ZERO;
985 let arr: &mut [u8; 16] = bytemuck::cast_mut(&mut result);
986 for (i, val) in iter.take(16).enumerate() {
987 arr[i] = val;
988 }
989 result
990 }
991}
992
#[cfg(test)]
mod tests {
    use binius_utils::bytes::BytesMut;
    use proptest::{arbitrary::any, proptest};
    use rand::prelude::*;

    use super::*;

    /// Asserts that converting `val` to `T` and back yields the same value.
    fn check_roundtrip<T>(val: M128)
    where
        T: From<M128>,
        M128: From<T>,
    {
        assert_eq!(M128::from(T::from(val)), val);
    }

    #[test]
    fn test_constants() {
        assert_eq!(M128::default(), M128::ZERO);
        assert_eq!(M128::from(0u128), M128::ZERO);
        assert_eq!(M128::from(1u128), M128::ONE);
    }

    /// Returns block number `index` of `value`, right-aligned, where a block is
    /// `2^log_block_len` bits wide.
    fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
        let block_bits = 1u32 << log_block_len;
        // Mask with the low `block_bits` bits set. `checked_shl` returns `None` when the
        // block spans all 128 bits, in which case every bit is kept.
        let mask = 1u128
            .checked_shl(block_bits)
            .map_or(u128::MAX, |bit| bit - 1);

        (value >> (index << log_block_len)) & M128::from(mask)
    }

    proptest! {
        #[test]
        fn test_conversion(a in any::<u128>()) {
            check_roundtrip::<u128>(a.into());
            check_roundtrip::<__m128i>(a.into());
        }

        #[test]
        fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
            assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
            assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
            assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
        }

        #[test]
        fn test_negate(a in any::<u128>()) {
            assert_eq!(M128::from(!a), !M128::from(a))
        }

        #[test]
        fn test_shifts(a in any::<u128>(), b in 0..128usize) {
            assert_eq!(M128::from(a << b), M128::from(a) << b);
            assert_eq!(M128::from(a >> b), M128::from(a) >> b);
        }

        #[test]
        fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
            let a = M128::from(a);
            let b = M128::from(b);

            let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
            let (c, d) = (M128::from(c), M128::from(d));

            // Every pair of blocks must be interleaved: `c` receives the even blocks of
            // `a` and `b`, `d` the odd ones.
            for i in (0..128>>height).step_by(2) {
                assert_eq!(get(c, height, i), get(a, height, i));
                assert_eq!(get(c, height, i+1), get(b, height, i));
                assert_eq!(get(d, height, i), get(a, height, i+1));
                assert_eq!(get(d, height, i+1), get(b, height, i+1));
            }
        }
    }

    #[test]
    fn test_fill_with_bit() {
        assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
        assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
    }

    #[test]
    fn test_eq() {
        let a = M128::from(0u128);
        let b = M128::from(42u128);
        let c = M128::from(u128::MAX);

        assert_eq!(a, a);
        assert_eq!(b, b);
        assert_eq!(c, c);

        assert_ne!(a, b);
        assert_ne!(a, c);
        assert_ne!(b, c);
    }

    #[test]
    fn test_serialize_and_deserialize_m128() {
        let mut rng = StdRng::from_seed([0; 32]);

        let original_value = M128::from(rng.random::<u128>());

        let mut buf = BytesMut::new();
        original_value.serialize(&mut buf).unwrap();

        let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

        assert_eq!(original_value, deserialized_value);
    }
}