// binius_field/arch/x86_64/m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{
16	Rng,
17	distr::{Distribution, StandardUniform},
18};
19use seq_macro::seq;
20
21use crate::{
22	BinaryField,
23	arch::portable::packed::PackedPrimitiveType,
24	underlier::{
25		Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, UnderlierWithBitOps,
26		impl_divisible_bitmask, mapget, spread_fallback,
27	},
28};
29
/// 128-bit value that is used for 128-bit SIMD operations
///
/// `#[repr(transparent)]` wrapper around the SSE register type [`__m128i`],
/// so it has the same size, alignment and ABI as the intrinsic type itself.
/// The field is `pub(super)` so sibling arch modules can reach the raw register.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
34
impl M128 {
	/// Builds an `M128` from a `u128` in a `const` context.
	///
	/// `From<u128>` (below) uses an intrinsic load and is not `const`, so
	/// this copies the raw bits instead.
	#[inline(always)]
	pub const fn from_u128(val: u128) -> Self {
		let mut result = Self::ZERO;
		unsafe {
			// SAFETY: `u128` and `__m128i` are both exactly 16 bytes and
			// `__m128i` accepts any bit pattern, so a bitwise copy is valid.
			result.0 = std::mem::transmute_copy(&val);
		}

		result
	}
}
46
47impl From<__m128i> for M128 {
48	#[inline(always)]
49	fn from(value: __m128i) -> Self {
50		Self(value)
51	}
52}
53
impl From<u128> for M128 {
	/// Moves a `u128` into the SSE register via an aligned 128-bit load.
	fn from(value: u128) -> Self {
		Self(unsafe {
			// Safety: u128 is 16-byte aligned
			// `_mm_load_si128` requires a 16-byte-aligned pointer; the
			// assert guards that `u128`'s alignment really is 16 on this
			// target, making the stack slot of `value` a valid source.
			assert_eq!(align_of::<u128>(), 16);
			_mm_load_si128(&raw const value as *const __m128i)
		})
	}
}
63
64impl From<u64> for M128 {
65	fn from(value: u64) -> Self {
66		Self::from(value as u128)
67	}
68}
69
70impl From<u32> for M128 {
71	fn from(value: u32) -> Self {
72		Self::from(value as u128)
73	}
74}
75
76impl From<u16> for M128 {
77	fn from(value: u16) -> Self {
78		Self::from(value as u128)
79	}
80}
81
82impl From<u8> for M128 {
83	fn from(value: u8) -> Self {
84		Self::from(value as u128)
85	}
86}
87
88impl<const N: usize> From<SmallU<N>> for M128 {
89	fn from(value: SmallU<N>) -> Self {
90		Self::from(value.val() as u128)
91	}
92}
93
impl From<M128> for u128 {
	/// Moves the register out to a `u128` via an aligned 128-bit store.
	fn from(value: M128) -> Self {
		let mut result = 0u128;
		unsafe {
			// Safety: u128 is 16-byte aligned
			// `_mm_store_si128` requires a 16-byte-aligned destination; the
			// assert guards that invariant for the stack slot of `result`.
			assert_eq!(align_of::<u128>(), 16);
			_mm_store_si128(&raw mut result as *mut __m128i, value.0)
		};
		result
	}
}
105
106impl From<M128> for __m128i {
107	#[inline(always)]
108	fn from(value: M128) -> Self {
109		value.0
110	}
111}
112
113impl SerializeBytes for M128 {
114	fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
115		assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
116
117		let raw_value: u128 = (*self).into();
118
119		write_buf.put_u128_le(raw_value);
120		Ok(())
121	}
122}
123
124impl DeserializeBytes for M128 {
125	fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
126	where
127		Self: Sized,
128	{
129		assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
130
131		let raw_value = read_buf.get_u128_le();
132
133		Ok(Self::from(raw_value))
134	}
135}
136
// Delegates `Divisible` for the sub-byte widths (1-, 2- and 4-bit chunks)
// to the shared bitmask-based macro implementation in `underlier`.
impl_divisible_bitmask!(M128, 1, 2, 4);
138
// Blanket impl: anything that can be numerically cast from a `u128` can be
// cast from an `M128` by first moving the register out to an integer.
impl<U: NumCast<u128>> NumCast<M128> for U {
	#[inline(always)]
	fn num_cast_from(val: M128) -> Self {
		Self::num_cast_from(u128::from(val))
	}
}
145
146impl Default for M128 {
147	#[inline(always)]
148	fn default() -> Self {
149		Self(unsafe { _mm_setzero_si128() })
150	}
151}
152
153impl BitAnd for M128 {
154	type Output = Self;
155
156	#[inline(always)]
157	fn bitand(self, rhs: Self) -> Self::Output {
158		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
159	}
160}
161
162impl BitAndAssign for M128 {
163	#[inline(always)]
164	fn bitand_assign(&mut self, rhs: Self) {
165		*self = *self & rhs
166	}
167}
168
169impl BitOr for M128 {
170	type Output = Self;
171
172	#[inline(always)]
173	fn bitor(self, rhs: Self) -> Self::Output {
174		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
175	}
176}
177
178impl BitOrAssign for M128 {
179	#[inline(always)]
180	fn bitor_assign(&mut self, rhs: Self) {
181		*self = *self | rhs
182	}
183}
184
185impl BitXor for M128 {
186	type Output = Self;
187
188	#[inline(always)]
189	fn bitxor(self, rhs: Self) -> Self::Output {
190		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
191	}
192}
193
194impl BitXorAssign for M128 {
195	#[inline(always)]
196	fn bitxor_assign(&mut self, rhs: Self) {
197		*self = *self ^ rhs;
198	}
199}
200
201impl Not for M128 {
202	type Output = Self;
203
204	fn not(self) -> Self::Output {
205		const ONES: __m128i = m128_from_u128!(u128::MAX);
206
207		self ^ Self(ONES)
208	}
209}
210
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	if right > left { right } else { left }
}
215
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed
/// as constant and Rust currently doesn't allow passing expressions (`count - 64`) where variable
/// is a generic constant parameter. Source: <https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688>
///
/// The `seq!` blocks unroll every possible shift amount into an `if` chain so
/// that each intrinsic receives a compile-time-constant count; with a
/// compile-time-known shift the optimizer collapses the chain to one branch.
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			// Shift by a whole 64-bit lane (8 bytes) first; `carry` holds the
			// bits that cross the lane boundary.
			let carry = $byte_shift($val, 8);
			// Shifts of 64..=127 bits: only the carried lane survives,
			// shifted by the remaining `N - 64` bits.
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			// Shifts of 0..=63 bits: shift each 64-bit lane, then OR in the
			// `64 - N` boundary-crossing bits taken from the carry register.
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// Shift amounts >= 128 clear the value entirely.
			return Default::default()
		}
	};
}

pub(crate) use bitshift_128b;
250
impl Shr<usize> for M128 {
	type Output = Self;

	/// Logical right shift by `rhs` bits; shifts of 128 or more yield zero.
	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
261
262impl Shl<usize> for M128 {
263	type Output = Self;
264
265	#[inline(always)]
266	fn shl(self, rhs: usize) -> Self::Output {
267		// This implementation is effective when `rhs` is known at compile-time.
268		// In our code this is always the case.
269		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
270	}
271}
272
impl PartialEq for M128 {
	/// Full 128-bit equality.
	fn eq(&self, other: &Self) -> bool {
		unsafe {
			// XOR is zero in every bit position where the operands agree,
			// so the values are equal iff `neq` is all zeros.
			let neq = _mm_xor_si128(self.0, other.0);
			// SSE4.1 `ptest`: returns 1 iff `neq & neq` is all zeros.
			_mm_test_all_zeros(neq, neq) == 1
		}
	}
}

impl Eq for M128 {}
283
284impl PartialOrd for M128 {
285	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
286		Some(self.cmp(other))
287	}
288}
289
290impl Ord for M128 {
291	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
292		u128::from(*self).cmp(&u128::from(*other))
293	}
294}
295
impl Distribution<M128> for StandardUniform {
	/// Samples a uniformly random 128-bit value.
	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
		// Delegates to rand's sampling of the inner register type —
		// NOTE(review): presumably rand's `__m128i` support (simd feature);
		// confirm the enabled rand features fill all 128 bits uniformly.
		M128(rng.random())
	}
}
301
302impl std::fmt::Display for M128 {
303	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304		let data: u128 = (*self).into();
305		write!(f, "{data:02X?}")
306	}
307}
308
309impl std::fmt::Debug for M128 {
310	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311		write!(f, "M128({self})")
312	}
313}
314
/// A `u128` forced to 16-byte alignment so its address can be reinterpreted
/// as a `*const __m128i` (used by `m128_from_u128!` in const contexts).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

// Builds a `__m128i` value from a `u128` expression without intrinsics, so
// it is usable for `const` items. The inner `unsafe` dereference is sound
// because `AlignedData` guarantees a 16-byte-aligned, 16-byte-sized source.
macro_rules! m128_from_u128 {
	($val:expr) => {{
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}

pub(super) use m128_from_u128;
326
impl UnderlierType for M128 {
	// 2^7 = 128 bits.
	const LOG_BITS: usize = 7;
}
330
impl UnderlierWithBitOps for M128 {
	// Constants go through `m128_from_u128!` because the SSE intrinsics are
	// not usable in const contexts.
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

	// Interleaves blocks of `2^log_block_len` bits between `self` and
	// `other`; the SIMD work is done by `interleave_bits` below.
	#[inline(always)]
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		unsafe {
			let (c, d) = interleave_bits(
				Into::<Self>::into(self).into(),
				Into::<Self>::into(other).into(),
				log_block_len,
			);
			(Self::from(c), Self::from(d))
		}
	}

	// Spreads block `block_idx` (a run of `2^log_block_len` elements of
	// type `T`) across the whole register, repeating each element so the
	// result is completely filled. Dispatches on the element width
	// `T::LOG_BITS`; unsupported combinations panic or use the portable
	// `spread_fallback`.
	// NOTE(review): range requirements on `log_block_len`/`block_idx` come
	// from the trait's `unsafe` contract, which is not visible here —
	// confirm against the trait definition.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps,
		Self: Divisible<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: each selected bit is expanded to an all-zeros
			// or all-ones region via `fill_with_bit`.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => unsafe {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				},
				2 => unsafe {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				},
				3 => unsafe {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				},
				4 => unsafe {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				},
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 2-bit elements: each is widened to a byte with
			// `spread_to_byte`, then the bytes are replicated.
			1 => match log_block_len {
				0 => unsafe {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				},
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 4-bit elements: same strategy as 2-bit, via `U4`.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					unsafe { _mm_set1_epi8(value as i8).into() }
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 8-bit elements: a single byte shuffle with a precomputed
			// index mask (see `LOG_B8_*` statics) does the whole spread.
			3 => match log_block_len {
				0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
				1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
				2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
				3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					unsafe { _mm_set1_epi16(value as i16).into() }
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				},
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				},
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// A 128-bit element is the whole register — nothing to spread.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}
}
558
// SAFETY: `M128` is `#[repr(transparent)]` over `__m128i`, a 16-byte plain
// vector type with no invalid bit patterns, so the all-zero value is valid.
unsafe impl Zeroable for M128 {}

// SAFETY: `Copy`, no padding, and every bit pattern is a valid value.
unsafe impl Pod for M128 {}

// SAFETY: contains no pointers, references or interior mutability.
unsafe impl Send for M128 {}

// SAFETY: plain immutable data is safe to share across threads.
unsafe impl Sync for M128 {}
566
// Precomputed `_mm_shuffle_epi8` index masks used by `spread` for 8-bit
// elements: `LOG_B8_<log_block_len>[block_idx]` selects and replicates the
// bytes of that block (16, 8, 4 or 2 possible block positions respectively).
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
571
/// Builds the byte-shuffle masks used by `spread`.
///
/// For each of the `BLOCK_IDX_AMOUNT` possible block positions, produces a
/// 16-byte index mask that (when fed to `_mm_shuffle_epi8`) picks the bytes
/// of that block and repeats each element enough times to fill the register.
///
/// `const fn`, hence the `while` loops instead of iterators.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	// Element width in bytes (log), e.g. `t_log_bits = 3` -> 1-byte elements.
	let element_log_width = t_log_bits - 3;

	let element_width = 1 << element_log_width;

	// Bytes covered by one block, and how many times each element must be
	// repeated so the 16 output bytes are completely filled.
	let block_size = 1 << (log_block_len + element_log_width);
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		// Source byte index: start of the block, plus the element this
		// output byte repeats, plus the byte offset inside that element.
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	// Pack each 16-byte index table into an `M128` (little-endian order).
	while block_idx < BLOCK_IDX_AMOUNT {
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
607
608impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
609	fn from(value: __m128i) -> Self {
610		M128::from(value).into()
611	}
612}
613
614impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
615	fn from(value: u128) -> Self {
616		M128::from(value).into()
617	}
618}
619
620impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
621	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
622		value.to_underlier().into()
623	}
624}
625
// Interleaves `2^log_block_len`-bit blocks of `a` and `b`: the first output
// gets the even-indexed block pairs (a_i, b_i), the second the odd ones.
// Blocks narrower than a byte use a masked delta-swap
// (`interleave_bits_imm`); byte-and-wider blocks shuffle even/odd blocks
// into register halves and then use the unpack instructions.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		// 1-bit blocks: 0x55 selects the even bit of every pair.
		0 => unsafe {
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		},
		// 2-bit blocks: 0x33 selects the even 2-bit group of every nibble pair.
		1 => unsafe {
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		},
		// 4-bit blocks: 0x0f selects the low nibble of every byte.
		2 => unsafe {
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		},
		// 8-bit blocks: gather even bytes into the low half, odd bytes into
		// the high half, then unpack pairs them across the two inputs.
		3 => unsafe {
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		},
		// 16-bit blocks: same scheme at word granularity.
		4 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		},
		// 32-bit blocks: same scheme at dword granularity.
		5 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		},
		// 64-bit blocks: the unpack instructions interleave lanes directly.
		6 => unsafe {
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		},
		_ => panic!("unsupported block length"),
	}
}
673
// Interleaves adjacent `BLOCK_LEN`-bit blocks (widths below one byte) with
// the classic masked delta-swap:
//   t  = ((a >> BLOCK_LEN) ^ b) & mask
//   a' = a ^ (t << BLOCK_LEN)
//   b' = b ^ t
// `mask` must select the even blocks of each byte (callers pass
// 0x55 / 0x33 / 0x0f); XOR-ing `t` back swaps exactly the masked fields
// between the two registers.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	unsafe {
		let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
		let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
		let b_prime = _mm_xor_si128(b, t);
		(a_prime, b_prime)
	}
}
687
688// Divisible implementations using SIMD extract/insert intrinsics
689
690impl Divisible<u128> for M128 {
691	const LOG_N: usize = 0;
692
693	#[inline]
694	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
695		std::iter::once(u128::from(value))
696	}
697
698	#[inline]
699	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
700		std::iter::once(u128::from(*value))
701	}
702
703	#[inline]
704	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
705		slice.iter().map(|&v| u128::from(v))
706	}
707
708	#[inline]
709	fn get(self, index: usize) -> u128 {
710		assert!(index == 0, "index out of bounds");
711		u128::from(self)
712	}
713
714	#[inline]
715	fn set(self, index: usize, val: u128) -> Self {
716		assert!(index == 0, "index out of bounds");
717		Self::from(val)
718	}
719
720	#[inline]
721	fn broadcast(val: u128) -> Self {
722		Self::from(val)
723	}
724
725	#[inline]
726	fn from_iter(mut iter: impl Iterator<Item = u128>) -> Self {
727		iter.next().map(Self::from).unwrap_or(Self::ZERO)
728	}
729}
730
// Views the register as two 64-bit lanes (`LOG_N = 1` => 2^1 elements).
impl Divisible<u64> for M128 {
	const LOG_N: usize = 1;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// The SSE4.1 extract/insert intrinsics require a compile-time-constant
	// lane index, hence the explicit match ladders below.
	#[inline]
	fn get(self, index: usize) -> u64 {
		unsafe {
			match index {
				0 => _mm_extract_epi64(self.0, 0) as u64,
				1 => _mm_extract_epi64(self.0, 1) as u64,
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn set(self, index: usize, val: u64) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
				1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u64) -> Self {
		unsafe { Self(_mm_set1_epi64x(val as i64)) }
	}

	// Reinterprets the register as `[u64; 2]` in memory order (lane 0 is
	// the low half on this little-endian target) and fills from the
	// iterator; lanes the iterator doesn't cover stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u64>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u64; 2] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(2).enumerate() {
			arr[i] = val;
		}
		result
	}
}
786
// Views the register as four 32-bit lanes (`LOG_N = 2` => 2^2 elements).
impl Divisible<u32> for M128 {
	const LOG_N: usize = 2;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// Extract/insert need a compile-time lane index, hence the ladders.
	#[inline]
	fn get(self, index: usize) -> u32 {
		unsafe {
			match index {
				0 => _mm_extract_epi32(self.0, 0) as u32,
				1 => _mm_extract_epi32(self.0, 1) as u32,
				2 => _mm_extract_epi32(self.0, 2) as u32,
				3 => _mm_extract_epi32(self.0, 3) as u32,
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn set(self, index: usize, val: u32) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u32) -> Self {
		unsafe { Self(_mm_set1_epi32(val as i32)) }
	}

	// Fill lanes in memory order from the iterator; uncovered lanes stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u32>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u32; 4] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(4).enumerate() {
			arr[i] = val;
		}
		result
	}
}
846
// Views the register as eight 16-bit lanes (`LOG_N = 3` => 2^3 elements).
impl Divisible<u16> for M128 {
	const LOG_N: usize = 3;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// Extract/insert need a compile-time lane index, hence the ladders.
	#[inline]
	fn get(self, index: usize) -> u16 {
		unsafe {
			match index {
				0 => _mm_extract_epi16(self.0, 0) as u16,
				1 => _mm_extract_epi16(self.0, 1) as u16,
				2 => _mm_extract_epi16(self.0, 2) as u16,
				3 => _mm_extract_epi16(self.0, 3) as u16,
				4 => _mm_extract_epi16(self.0, 4) as u16,
				5 => _mm_extract_epi16(self.0, 5) as u16,
				6 => _mm_extract_epi16(self.0, 6) as u16,
				7 => _mm_extract_epi16(self.0, 7) as u16,
				_ => panic!("index out of bounds"),
			}
		}
	}

	// Note: `_mm_insert_epi16` takes the new lane value as an `i32`.
	#[inline]
	fn set(self, index: usize, val: u16) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
				4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
				5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
				6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
				7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u16) -> Self {
		unsafe { Self(_mm_set1_epi16(val as i16)) }
	}

	// Fill lanes in memory order from the iterator; uncovered lanes stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u16>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u16; 8] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(8).enumerate() {
			arr[i] = val;
		}
		result
	}
}
914
// Views the register as sixteen 8-bit lanes (`LOG_N = 4` => 2^4 elements).
impl Divisible<u8> for M128 {
	const LOG_N: usize = 4;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	// Extract/insert need a compile-time lane index, hence the ladders.
	#[inline]
	fn get(self, index: usize) -> u8 {
		unsafe {
			match index {
				0 => _mm_extract_epi8(self.0, 0) as u8,
				1 => _mm_extract_epi8(self.0, 1) as u8,
				2 => _mm_extract_epi8(self.0, 2) as u8,
				3 => _mm_extract_epi8(self.0, 3) as u8,
				4 => _mm_extract_epi8(self.0, 4) as u8,
				5 => _mm_extract_epi8(self.0, 5) as u8,
				6 => _mm_extract_epi8(self.0, 6) as u8,
				7 => _mm_extract_epi8(self.0, 7) as u8,
				8 => _mm_extract_epi8(self.0, 8) as u8,
				9 => _mm_extract_epi8(self.0, 9) as u8,
				10 => _mm_extract_epi8(self.0, 10) as u8,
				11 => _mm_extract_epi8(self.0, 11) as u8,
				12 => _mm_extract_epi8(self.0, 12) as u8,
				13 => _mm_extract_epi8(self.0, 13) as u8,
				14 => _mm_extract_epi8(self.0, 14) as u8,
				15 => _mm_extract_epi8(self.0, 15) as u8,
				_ => panic!("index out of bounds"),
			}
		}
	}

	// Note: `_mm_insert_epi8` takes the new lane value as an `i32`.
	#[inline]
	fn set(self, index: usize, val: u8) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
				4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
				5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
				6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
				7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
				8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
				9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
				10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
				11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
				12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
				13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
				14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
				15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
				_ => panic!("index out of bounds"),
			}
		}
	}

	#[inline]
	fn broadcast(val: u8) -> Self {
		unsafe { Self(_mm_set1_epi8(val as i8)) }
	}

	// Fill lanes in memory order from the iterator; uncovered lanes stay zero.
	#[inline]
	fn from_iter(iter: impl Iterator<Item = u8>) -> Self {
		let mut result = Self::ZERO;
		let arr: &mut [u8; 16] = bytemuck::cast_mut(&mut result);
		for (i, val) in iter.take(16).enumerate() {
			arr[i] = val;
		}
		result
	}
}
998
#[cfg(test)]
mod tests {
	use binius_utils::bytes::BytesMut;
	use proptest::{arbitrary::any, proptest};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	/// Checks that converting `val` to `T` and back is lossless.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	/// Extracts block `index` of width `2^log_block_len` bits from `value`.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		// Fix: mask the entire `2^log_block_len`-bit block. The previous
		// mask (`1 << log_block_len`) selected only a single bit of each
		// block, so `test_interleave_bits` was only verifying one bit per
		// block instead of the whole block.
		let block_mask = (1u128 << (1 << log_block_len)) - 1;
		(value >> (index << log_block_len)) & M128::from(block_mask)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			// Every even/odd block pair of the outputs must come from the
			// corresponding blocks of the inputs.
			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}

	#[test]
	fn test_serialize_and_deserialize_m128() {
		let mut rng = StdRng::from_seed([0; 32]);

		let original_value = M128::from(rng.random::<u128>());

		let mut buf = BytesMut::new();
		original_value.serialize(&mut buf).unwrap();

		let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

		assert_eq!(original_value, deserialized_value);
	}
}