binius_field/arch/x86_64/
m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{
16	Rng,
17	distr::{Distribution, StandardUniform},
18};
19use seq_macro::seq;
20
21use crate::{
22	BinaryField,
23	arch::portable::packed::PackedPrimitiveType,
24	underlier::{
25		Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, UnderlierWithBitOps,
26		impl_divisible_bitmask, mapget, spread_fallback,
27	},
28};
29
/// 128-bit value that is used for 128-bit SIMD operations
///
/// Transparent newtype over the SSE register type [`__m128i`], so it can carry trait impls
/// (bit ops, conversions, serialization) while remaining layout-identical to the raw register.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
34
impl M128 {
	/// Const-context constructor from a `u128`.
	///
	/// The SIMD load intrinsics are not `const fn`, so the bits are copied with
	/// `transmute_copy`; `u128` and `__m128i` are both 16 bytes, making the bit-copy sound.
	#[inline(always)]
	pub const fn from_u128(val: u128) -> Self {
		let mut result = Self::ZERO;
		unsafe {
			result.0 = std::mem::transmute_copy(&val);
		}

		result
	}
}
46
47impl From<__m128i> for M128 {
48	#[inline(always)]
49	fn from(value: __m128i) -> Self {
50		Self(value)
51	}
52}
53
impl From<u128> for M128 {
	/// Moves the 128 bits of `value` into an SSE register.
	///
	/// Uses the aligned load `_mm_load_si128`; the assert documents the invariant that
	/// `u128` has 16-byte alignment on this target, making the aligned load valid.
	fn from(value: u128) -> Self {
		Self(unsafe {
			// Safety: u128 is 16-byte aligned
			assert_eq!(align_of::<u128>(), 16);
			_mm_load_si128(&raw const value as *const __m128i)
		})
	}
}
63
64impl From<u64> for M128 {
65	fn from(value: u64) -> Self {
66		Self::from(value as u128)
67	}
68}
69
70impl From<u32> for M128 {
71	fn from(value: u32) -> Self {
72		Self::from(value as u128)
73	}
74}
75
76impl From<u16> for M128 {
77	fn from(value: u16) -> Self {
78		Self::from(value as u128)
79	}
80}
81
82impl From<u8> for M128 {
83	fn from(value: u8) -> Self {
84		Self::from(value as u128)
85	}
86}
87
88impl<const N: usize> From<SmallU<N>> for M128 {
89	fn from(value: SmallU<N>) -> Self {
90		Self::from(value.val() as u128)
91	}
92}
93
impl From<M128> for u128 {
	/// Copies the register's 128 bits back into an integer.
	///
	/// Uses the aligned store `_mm_store_si128` into a local `u128`; the assert documents
	/// the invariant that `u128` has 16-byte alignment on this target.
	fn from(value: M128) -> Self {
		let mut result = 0u128;
		unsafe {
			// Safety: u128 is 16-byte aligned
			assert_eq!(align_of::<u128>(), 16);
			_mm_store_si128(&raw mut result as *mut __m128i, value.0)
		};
		result
	}
}
105
106impl From<M128> for __m128i {
107	#[inline(always)]
108	fn from(value: M128) -> Self {
109		value.0
110	}
111}
112
113impl SerializeBytes for M128 {
114	fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
115		assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
116
117		let raw_value: u128 = (*self).into();
118
119		write_buf.put_u128_le(raw_value);
120		Ok(())
121	}
122}
123
124impl DeserializeBytes for M128 {
125	fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
126	where
127		Self: Sized,
128	{
129		assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
130
131		let raw_value = read_buf.get_u128_le();
132
133		Ok(Self::from(raw_value))
134	}
135}
136
137impl_divisible_bitmask!(M128, 1, 2, 4);
138
139impl<U: NumCast<u128>> NumCast<M128> for U {
140	#[inline(always)]
141	fn num_cast_from(val: M128) -> Self {
142		Self::num_cast_from(u128::from(val))
143	}
144}
145
146impl Default for M128 {
147	#[inline(always)]
148	fn default() -> Self {
149		Self(unsafe { _mm_setzero_si128() })
150	}
151}
152
153impl BitAnd for M128 {
154	type Output = Self;
155
156	#[inline(always)]
157	fn bitand(self, rhs: Self) -> Self::Output {
158		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
159	}
160}
161
162impl BitAndAssign for M128 {
163	#[inline(always)]
164	fn bitand_assign(&mut self, rhs: Self) {
165		*self = *self & rhs
166	}
167}
168
169impl BitOr for M128 {
170	type Output = Self;
171
172	#[inline(always)]
173	fn bitor(self, rhs: Self) -> Self::Output {
174		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
175	}
176}
177
178impl BitOrAssign for M128 {
179	#[inline(always)]
180	fn bitor_assign(&mut self, rhs: Self) {
181		*self = *self | rhs
182	}
183}
184
185impl BitXor for M128 {
186	type Output = Self;
187
188	#[inline(always)]
189	fn bitxor(self, rhs: Self) -> Self::Output {
190		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
191	}
192}
193
194impl BitXorAssign for M128 {
195	#[inline(always)]
196	fn bitxor_assign(&mut self, rhs: Self) {
197		*self = *self ^ rhs;
198	}
199}
200
201impl Not for M128 {
202	type Output = Self;
203
204	fn not(self) -> Self::Output {
205		const ONES: __m128i = m128_from_u128!(u128::MAX);
206
207		self ^ Self(ONES)
208	}
209}
210
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	if right < left { left } else { right }
}
215
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed
/// as constant and Rust currently doesn't allow passing expressions (`count - 64`) where variable
/// is a generic constant parameter. Source: <https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688>
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			// Value moved a whole 64-bit lane over: holds the bits that cross the lane
			// boundary for any shift amount.
			let carry = $byte_shift($val, 8);
			// Shifts of 64..=127 bits: only the carried lane contributes.
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			// Shifts of 0..=63 bits: combine the in-lane shift with the cross-lane carry.
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// Shift amounts of 128 or more produce zero.
			return Default::default()
		}
	};
}
248
impl Shr<usize> for M128 {
	type Output = Self;

	/// Logical right shift of the full 128-bit value; shifts of 128 or more yield zero
	/// (the macro's `Default::default()` fallthrough).
	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
259
260impl Shl<usize> for M128 {
261	type Output = Self;
262
263	#[inline(always)]
264	fn shl(self, rhs: usize) -> Self::Output {
265		// This implementation is effective when `rhs` is known at compile-time.
266		// In our code this is always the case.
267		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
268	}
269}
270
impl PartialEq for M128 {
	fn eq(&self, other: &Self) -> bool {
		unsafe {
			// XOR is all-zero iff the operands are bit-identical; `_mm_test_all_zeros`
			// (SSE4.1 `ptest`) checks the whole register in one instruction.
			let neq = _mm_xor_si128(self.0, other.0);
			_mm_test_all_zeros(neq, neq) == 1
		}
	}
}

impl Eq for M128 {}
281
282impl PartialOrd for M128 {
283	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
284		Some(self.cmp(other))
285	}
286}
287
288impl Ord for M128 {
289	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
290		u128::from(*self).cmp(&u128::from(*other))
291	}
292}
293
impl Distribution<M128> for StandardUniform {
	/// Samples a uniformly random 128-bit register.
	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
		// NOTE(review): relies on `StandardUniform` being implemented for `__m128i`
		// (rand's SIMD support) — confirm the required rand feature is enabled.
		M128(rng.random())
	}
}
299
300impl std::fmt::Display for M128 {
301	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302		let data: u128 = (*self).into();
303		write!(f, "{data:02X?}")
304	}
305}
306
impl std::fmt::Debug for M128 {
	/// Reuses the `Display` hex rendering, wrapped in the type name.
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		write!(f, "M128({self})")
	}
}
312
/// Helper for `m128_from_u128!`: forces 16-byte alignment so the pointer cast to
/// `*const __m128i` inside the macro is a valid aligned read.
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
315
/// Const-friendly construction of a `__m128i` from a `u128` expression.
///
/// Used where intrinsics are unavailable in const context (the `ZERO`/`ONE`/`ONES`
/// constants and `Not::not`'s all-ones constant).
macro_rules! m128_from_u128 {
	($val:expr) => {{
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}

pub(super) use m128_from_u128;
324
impl UnderlierType for M128 {
	// 2^7 = 128 bits.
	const LOG_BITS: usize = 7;
}
328
impl UnderlierWithBitOps for M128 {
	// Built via `m128_from_u128!` because SIMD intrinsics cannot run in const context.
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

	/// Interleaves blocks of `2^log_block_len` bits of `self` and `other`; see
	/// `interleave_bits` below for the per-block-size strategy.
	#[inline(always)]
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		unsafe {
			let (c, d) = interleave_bits(
				Into::<Self>::into(self).into(),
				Into::<Self>::into(other).into(),
				log_block_len,
			);
			(Self::from(c), Self::from(d))
		}
	}

	/// Broadcasts the `block_idx`-th block of `2^log_block_len` `T`-elements across the
	/// whole register. Dispatch is on the element width (`T::LOG_BITS`) first, then on the
	/// block length; uncovered combinations use `spread_fallback` or panic.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps,
		Self: Divisible<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: each selected bit is expanded to an all-0/all-1 lane.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => unsafe {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				},
				2 => unsafe {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				},
				3 => unsafe {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				},
				4 => unsafe {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				},
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 2-bit elements (`U2`): each element is first expanded to a full byte.
			1 => match log_block_len {
				0 => unsafe {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				},
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 4-bit elements (`U4`): same byte-expansion approach as the 2-bit case.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					unsafe { _mm_set1_epi8(value as i8).into() }
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 8-bit elements: a single `pshufb` using the precomputed index tables below.
			3 => match log_block_len {
				0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
				1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
				2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
				3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements: extract lanes and rebuild with `set1`/`from_fn`.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					unsafe { _mm_set1_epi16(value as i16).into() }
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				},
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				},
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// 128-bit element: the whole register is the single block.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}
}
556
// SAFETY: `M128` is `#[repr(transparent)]` over `__m128i` — a 16-byte vector type with no
// padding for which all bit patterns are valid — so zero-init and byte-level casts are sound.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// NOTE(review): `__m128i` is already `Send + Sync` in std, so these impls look redundant —
// confirm whether they can be dropped.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
564
// Precomputed `_mm_shuffle_epi8` index tables used by `spread` for byte-wide elements
// (`T::LOG_BITS == 3`); one table per supported `log_block_len`, indexed by `block_idx`.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
569
/// Builds `_mm_shuffle_epi8` index tables for `spread`.
///
/// For each `block_idx`, byte `j` of the mask holds the index of the source byte to copy
/// into output lane `j`, so one `pshufb` replicates the selected block across the register.
/// `t_log_bits` is the element size in bits (log2); the tables above use byte elements
/// (`t_log_bits == 3`). Written with `while` loops because `for` is not allowed in `const fn`.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	// Element width in bytes (log2, then linear).
	let element_log_width = t_log_bits - 3;

	let element_width = 1 << element_log_width;

	// Bytes per block, and how many times each block element repeats across the 16 lanes.
	let block_size = 1 << (log_block_len + element_log_width);
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
605
606impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
607	fn from(value: __m128i) -> Self {
608		M128::from(value).into()
609	}
610}
611
612impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
613	fn from(value: u128) -> Self {
614		M128::from(value).into()
615	}
616}
617
618impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
619	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
620		value.to_underlier().into()
621	}
622}
623
/// Interleaves blocks of `2^log_block_len` bits between two 128-bit registers, returning
/// `(a', b')` (see `test_interleave_bits` for the exact block layout).
///
/// Strategy by block size:
/// - 1/2/4-bit blocks: masked delta-swap within 64-bit lanes (`interleave_bits_imm`);
/// - 8/16/32-bit blocks: `pshufb` pre-permutation, then `unpacklo`/`unpackhi`;
/// - 64-bit blocks: plain `unpacklo`/`unpackhi`.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		0 => unsafe {
			// 0x55 = 0b01010101: selects every other bit.
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		},
		1 => unsafe {
			// 0x33 = 0b00110011: selects every other 2-bit block.
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		},
		2 => unsafe {
			// 0x0f: selects every other nibble.
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		},
		3 => unsafe {
			// Gather even bytes into the low half and odd bytes into the high half,
			// then unpack to interleave across the two registers.
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		},
		4 => unsafe {
			// Same idea with 16-bit units.
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		},
		5 => unsafe {
			// Same idea with 32-bit units.
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		},
		6 => unsafe {
			// 64-bit blocks need no pre-shuffle.
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		},
		_ => panic!("unsupported block length"),
	}
}
671
/// Delta-swap interleave for sub-byte blocks within each 64-bit lane.
///
/// `t` marks the positions (selected by `mask`) where `a`'s upper `BLOCK_LEN`-bit block
/// differs from `b`'s lower block; XORing `t` back into both operands swaps exactly those
/// block pairs. `mask` must select alternating `BLOCK_LEN`-bit blocks (0x55/0x33/0x0f).
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	unsafe {
		let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
		let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
		let b_prime = _mm_xor_si128(b, t);
		(a_prime, b_prime)
	}
}
685
686// Divisible implementations using SIMD extract/insert intrinsics
687
688impl Divisible<u128> for M128 {
689	const LOG_N: usize = 0;
690
691	#[inline]
692	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
693		std::iter::once(u128::from(value))
694	}
695
696	#[inline]
697	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
698		std::iter::once(u128::from(*value))
699	}
700
701	#[inline]
702	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
703		slice.iter().map(|&v| u128::from(v))
704	}
705
706	#[inline]
707	fn get(self, index: usize) -> u128 {
708		assert!(index == 0, "index out of bounds");
709		u128::from(self)
710	}
711
712	#[inline]
713	fn set(self, index: usize, val: u128) -> Self {
714		assert!(index == 0, "index out of bounds");
715		Self::from(val)
716	}
717
718	#[inline]
719	fn broadcast(val: u128) -> Self {
720		Self::from(val)
721	}
722
723	#[inline]
724	fn from_iter(mut iter: impl Iterator<Item = u128>) -> Self {
725		iter.next().map(Self::from).unwrap_or(Self::ZERO)
726	}
727}
728
// Two 64-bit lanes. The extract/insert intrinsics require the lane index to be a
// compile-time constant, hence the literal match tables in `get`/`set`.
impl Divisible<u64> for M128 {
	const LOG_N: usize = 1;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// Panics if `index >= 2`.
	#[inline]
	fn get(self, index: usize) -> u64 {
		unsafe {
			match index {
				0 => _mm_extract_epi64(self.0, 0) as u64,
				1 => _mm_extract_epi64(self.0, 1) as u64,
				_ => panic!("index out of bounds"),
			}
		}
	}

	// Returns a copy with lane `index` replaced; panics if `index >= 2`.
	#[inline]
	fn set(self, index: usize, val: u64) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
				1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u64) -> Self {
		unsafe { Self(_mm_set1_epi64x(val as i64)) }
	}

	// Fills at most 2 lanes; lanes the iterator doesn't provide stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u64>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u64; 2] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(2).enumerate() {
			arr[i] = val;
		}
		result
	}
}
784
// Four 32-bit lanes; literal match tables for the const lane indices, as above.
impl Divisible<u32> for M128 {
	const LOG_N: usize = 2;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// Panics if `index >= 4`.
	#[inline]
	fn get(self, index: usize) -> u32 {
		unsafe {
			match index {
				0 => _mm_extract_epi32(self.0, 0) as u32,
				1 => _mm_extract_epi32(self.0, 1) as u32,
				2 => _mm_extract_epi32(self.0, 2) as u32,
				3 => _mm_extract_epi32(self.0, 3) as u32,
				_ => panic!("index out of bounds"),
			}
		}
	}

	// Returns a copy with lane `index` replaced; panics if `index >= 4`.
	#[inline]
	fn set(self, index: usize, val: u32) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u32) -> Self {
		unsafe { Self(_mm_set1_epi32(val as i32)) }
	}

	// Fills at most 4 lanes; lanes the iterator doesn't provide stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u32>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u32; 4] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(4).enumerate() {
			arr[i] = val;
		}
		result
	}
}
844
// Eight 16-bit lanes; literal match tables for the const lane indices, as above.
impl Divisible<u16> for M128 {
	const LOG_N: usize = 3;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// Panics if `index >= 8`.
	#[inline]
	fn get(self, index: usize) -> u16 {
		unsafe {
			match index {
				0 => _mm_extract_epi16(self.0, 0) as u16,
				1 => _mm_extract_epi16(self.0, 1) as u16,
				2 => _mm_extract_epi16(self.0, 2) as u16,
				3 => _mm_extract_epi16(self.0, 3) as u16,
				4 => _mm_extract_epi16(self.0, 4) as u16,
				5 => _mm_extract_epi16(self.0, 5) as u16,
				6 => _mm_extract_epi16(self.0, 6) as u16,
				7 => _mm_extract_epi16(self.0, 7) as u16,
				_ => panic!("index out of bounds"),
			}
		}
	}

	// Returns a copy with lane `index` replaced; panics if `index >= 8`.
	#[inline]
	fn set(self, index: usize, val: u16) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
				4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
				5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
				6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
				7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u16) -> Self {
		unsafe { Self(_mm_set1_epi16(val as i16)) }
	}

	// Fills at most 8 lanes; lanes the iterator doesn't provide stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u16>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u16; 8] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(8).enumerate() {
			arr[i] = val;
		}
		result
	}
}
912
// Sixteen 8-bit lanes; literal match tables for the const lane indices, as above.
impl Divisible<u8> for M128 {
	const LOG_N: usize = 4;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// Panics if `index >= 16`.
	#[inline]
	fn get(self, index: usize) -> u8 {
		unsafe {
			match index {
				0 => _mm_extract_epi8(self.0, 0) as u8,
				1 => _mm_extract_epi8(self.0, 1) as u8,
				2 => _mm_extract_epi8(self.0, 2) as u8,
				3 => _mm_extract_epi8(self.0, 3) as u8,
				4 => _mm_extract_epi8(self.0, 4) as u8,
				5 => _mm_extract_epi8(self.0, 5) as u8,
				6 => _mm_extract_epi8(self.0, 6) as u8,
				7 => _mm_extract_epi8(self.0, 7) as u8,
				8 => _mm_extract_epi8(self.0, 8) as u8,
				9 => _mm_extract_epi8(self.0, 9) as u8,
				10 => _mm_extract_epi8(self.0, 10) as u8,
				11 => _mm_extract_epi8(self.0, 11) as u8,
				12 => _mm_extract_epi8(self.0, 12) as u8,
				13 => _mm_extract_epi8(self.0, 13) as u8,
				14 => _mm_extract_epi8(self.0, 14) as u8,
				15 => _mm_extract_epi8(self.0, 15) as u8,
				_ => panic!("index out of bounds"),
			}
		}
	}

	// Returns a copy with lane `index` replaced; panics if `index >= 16`.
	#[inline]
	fn set(self, index: usize, val: u8) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
				4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
				5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
				6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
				7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
				8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
				9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
				10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
				11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
				12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
				13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
				14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
				15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u8) -> Self {
		unsafe { Self(_mm_set1_epi8(val as i8)) }
	}

	// Fills at most 16 lanes; lanes the iterator doesn't provide stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u8>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u8; 16] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(16).enumerate() {
			arr[i] = val;
		}
		result
	}
}
996
#[cfg(test)]
mod tests {
	use binius_utils::bytes::BytesMut;
	use proptest::{arbitrary::any, proptest};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	/// Asserts that converting to `T` and back is lossless.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	/// Extracts the `index`-th block of `2^log_block_len` bits of `value`.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		// Fix: the mask must cover the whole block. The previous mask,
		// `1u128 << log_block_len`, selected only a single bit per block, so
		// `test_interleave_bits` silently ignored most bits of each block.
		let block_mask = (1u128 << (1 << log_block_len)) - 1;
		(value >> (index << log_block_len)) & M128::from(block_mask)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			// Even blocks of the first output come from `a`, odd from `b`; the second
			// output carries the odd-indexed blocks of both inputs.
			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}

	#[test]
	fn test_serialize_and_deserialize_m128() {
		let mut rng = StdRng::from_seed([0; 32]);

		let original_value = M128::from(rng.random::<u128>());

		let mut buf = BytesMut::new();
		original_value.serialize(&mut buf).unwrap();

		let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

		assert_eq!(original_value, deserialized_value);
	}
}