Skip to main content

binius_field/arch/x86_64/
m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4use std::{
5	arch::x86_64::*,
6	array, mem,
7	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
8};
9
10use binius_utils::{
11	DeserializeBytes, SerializationError, SerializeBytes,
12	bytes::{Buf, BufMut},
13	serialization::{assert_enough_data_for, assert_enough_space_for},
14};
15use bytemuck::{Pod, Zeroable};
16use rand::{
17	distr::{Distribution, StandardUniform},
18	prelude::*,
19};
20use seq_macro::seq;
21
22use crate::{
23	BinaryField,
24	arch::portable::packed::PackedPrimitiveType,
25	underlier::{
26		Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, impl_divisible_bitmask,
27		impl_divisible_self, mapget, spread_fallback,
28	},
29};
30
/// Reinterprets the bits of a `u128` as an `__m128i`; usable in `const` contexts.
pub const fn m128i_from_u128(x: u128) -> __m128i {
	// Compile-time guard: the two array types only unify when `u128` and `__m128i`
	// report the same alignment, so a mismatch is a type error.
	let _: [(); mem::align_of::<u128>()] = [(); mem::align_of::<__m128i>()];
	// SAFETY: `transmute` itself rejects a size mismatch at compile time, and the
	// conversion only reinterprets the 128 bits of `x`.
	unsafe { mem::transmute::<u128, __m128i>(x) }
}
36
37/// 128-bit value that is used for 128-bit SIMD operations
38#[derive(Copy, Clone)]
39#[repr(transparent)]
40pub struct M128(pub(super) __m128i);
41
42impl M128 {
43	#[inline(always)]
44	pub const fn from_u128(val: u128) -> Self {
45		Self(m128i_from_u128(val))
46	}
47}
48
49impl From<__m128i> for M128 {
50	#[inline(always)]
51	fn from(value: __m128i) -> Self {
52		Self(value)
53	}
54}
55
56impl From<u128> for M128 {
57	fn from(value: u128) -> Self {
58		Self(m128i_from_u128(value))
59	}
60}
61
62impl From<u64> for M128 {
63	fn from(value: u64) -> Self {
64		Self::from(value as u128)
65	}
66}
67
68impl From<u32> for M128 {
69	fn from(value: u32) -> Self {
70		Self::from(value as u128)
71	}
72}
73
74impl From<u16> for M128 {
75	fn from(value: u16) -> Self {
76		Self::from(value as u128)
77	}
78}
79
80impl From<u8> for M128 {
81	fn from(value: u8) -> Self {
82		Self::from(value as u128)
83	}
84}
85
86impl<const N: usize> From<SmallU<N>> for M128 {
87	fn from(value: SmallU<N>) -> Self {
88		Self::from(value.val() as u128)
89	}
90}
91
92impl From<M128> for u128 {
93	fn from(value: M128) -> Self {
94		let mut result = 0u128;
95		unsafe {
96			// Safety: u128 is 16-byte aligned
97			assert_eq!(align_of::<u128>(), 16);
98			_mm_store_si128(&raw mut result as *mut __m128i, value.0)
99		};
100		result
101	}
102}
103
104impl From<M128> for __m128i {
105	#[inline(always)]
106	fn from(value: M128) -> Self {
107		value.0
108	}
109}
110
111impl SerializeBytes for M128 {
112	fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
113		assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
114
115		let raw_value: u128 = (*self).into();
116
117		write_buf.put_u128_le(raw_value);
118		Ok(())
119	}
120}
121
122impl DeserializeBytes for M128 {
123	fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
124	where
125		Self: Sized,
126	{
127		assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
128
129		let raw_value = read_buf.get_u128_le();
130
131		Ok(Self::from(raw_value))
132	}
133}
134
// NOTE(review): presumably generates the `Divisible` impls for the 1-, 2- and 4-bit
// sub-byte element types via bit masking; confirm against the macro definition in
// `crate::underlier`.
impl_divisible_bitmask!(M128, 1, 2, 4);
136
137impl<U: NumCast<u128>> NumCast<M128> for U {
138	#[inline(always)]
139	fn num_cast_from(val: M128) -> Self {
140		Self::num_cast_from(u128::from(val))
141	}
142}
143
144impl Default for M128 {
145	#[inline(always)]
146	fn default() -> Self {
147		Self(unsafe { _mm_setzero_si128() })
148	}
149}
150
151impl BitAnd for M128 {
152	type Output = Self;
153
154	#[inline(always)]
155	fn bitand(self, rhs: Self) -> Self::Output {
156		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
157	}
158}
159
160impl BitAndAssign for M128 {
161	#[inline(always)]
162	fn bitand_assign(&mut self, rhs: Self) {
163		*self = *self & rhs
164	}
165}
166
167impl BitOr for M128 {
168	type Output = Self;
169
170	#[inline(always)]
171	fn bitor(self, rhs: Self) -> Self::Output {
172		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
173	}
174}
175
176impl BitOrAssign for M128 {
177	#[inline(always)]
178	fn bitor_assign(&mut self, rhs: Self) {
179		*self = *self | rhs
180	}
181}
182
183impl BitXor for M128 {
184	type Output = Self;
185
186	#[inline(always)]
187	fn bitxor(self, rhs: Self) -> Self::Output {
188		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
189	}
190}
191
192impl BitXorAssign for M128 {
193	#[inline(always)]
194	fn bitxor_assign(&mut self, rhs: Self) {
195		*self = *self ^ rhs;
196	}
197}
198
199impl Not for M128 {
200	type Output = Self;
201
202	fn not(self) -> Self::Output {
203		const ONES: M128 = M128::from_u128(u128::MAX);
204
205		self ^ ONES
206	}
207}
208
/// `const` replacement for `std::cmp::max`, which cannot be called in const contexts.
///
/// Returns the larger of `left` and `right`.
const fn max_i32(left: i32, right: i32) -> i32 {
	if right >= left { right } else { left }
}
213
/// Expands to the body of a full 128-bit bit shift built from 64-bit lane shifts.
///
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed
/// as constant and Rust currently doesn't allow passing expressions (`count - 64`) where variable
/// is a generic constant parameter. Source: <https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688>
///
/// Arguments: the raw value, the runtime shift amount, the whole-register byte-shift intrinsic,
/// the per-lane 64-bit shift in the requested direction, the per-lane 64-bit shift in the
/// opposite direction, and the OR intrinsic.
///
/// The expansion `return`s from the enclosing function on every path, so it must be the
/// last thing in a function whose return type implements `From` of the intrinsic result
/// and `Default`. The shift amount is compared against each constant `0..128`; any other
/// amount returns `Default::default()`.
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			// Byte-shift the whole register by 8 bytes: this moves one 64-bit lane into the
			// other and is the source of the bits that cross the lane boundary.
			let carry = $byte_shift($val, 8);
			// Shifts of 64..128: only the bits that crossed the lane boundary survive.
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						// `N - 64` is the remaining in-lane shift; `max_i32` keeps the
						// constant non-negative.
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			// Shifts of 0..64: combine the in-lane shift with the bits carried across lanes.
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// No constant matched (shift amount >= 128).
			return Default::default()
		}
	};
}
246
247impl Shr<usize> for M128 {
248	type Output = Self;
249
250	#[inline(always)]
251	fn shr(self, rhs: usize) -> Self::Output {
252		// This implementation is effective when `rhs` is known at compile-time.
253		// In our code this is always the case.
254		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
255	}
256}
257
258impl Shl<usize> for M128 {
259	type Output = Self;
260
261	#[inline(always)]
262	fn shl(self, rhs: usize) -> Self::Output {
263		// This implementation is effective when `rhs` is known at compile-time.
264		// In our code this is always the case.
265		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
266	}
267}
268
269impl PartialEq for M128 {
270	fn eq(&self, other: &Self) -> bool {
271		unsafe {
272			let neq = _mm_xor_si128(self.0, other.0);
273			_mm_test_all_zeros(neq, neq) == 1
274		}
275	}
276}
277
// `PartialEq` for `M128` compares all 128 bits, which is a total equivalence relation.
impl Eq for M128 {}
279
280impl PartialOrd for M128 {
281	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
282		Some(self.cmp(other))
283	}
284}
285
286impl Ord for M128 {
287	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
288		u128::from(*self).cmp(&u128::from(*other))
289	}
290}
291
292impl std::hash::Hash for M128 {
293	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
294		u128::from(*self).hash(state);
295	}
296}
297
298impl std::fmt::LowerHex for M128 {
299	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300		std::fmt::LowerHex::fmt(&u128::from(*self), f)
301	}
302}
303
304impl Distribution<M128> for StandardUniform {
305	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
306		M128(rng.random())
307	}
308}
309
310impl std::fmt::Display for M128 {
311	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
312		let data: u128 = (*self).into();
313		write!(f, "{data:02X?}")
314	}
315}
316
317impl std::fmt::Debug for M128 {
318	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319		write!(f, "M128({self})")
320	}
321}
322
323impl UnderlierType for M128 {
324	const LOG_BITS: usize = 7;
325	const ZERO: Self = { Self::from_u128(0) };
326	const ONE: Self = { Self::from_u128(1) };
327	const ONES: Self = { Self::from_u128(u128::MAX) };
328
329	#[inline(always)]
330	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
331		unsafe {
332			let (c, d) = interleave_bits(
333				Into::<Self>::into(self).into(),
334				Into::<Self>::into(other).into(),
335				log_block_len,
336			);
337			(Self::from(c), Self::from(d))
338		}
339	}
340
341	#[inline(always)]
342	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
343	where
344		T: UnderlierType,
345		Self: Divisible<T>,
346	{
347		match T::LOG_BITS {
348			0 => match log_block_len {
349				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
350				1 => unsafe {
351					let bits: [u8; 2] =
352						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);
353
354					_mm_set_epi64x(
355						u64::fill_with_bit(bits[1]) as i64,
356						u64::fill_with_bit(bits[0]) as i64,
357					)
358					.into()
359				},
360				2 => unsafe {
361					let bits: [u8; 4] =
362						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);
363
364					_mm_set_epi32(
365						u32::fill_with_bit(bits[3]) as i32,
366						u32::fill_with_bit(bits[2]) as i32,
367						u32::fill_with_bit(bits[1]) as i32,
368						u32::fill_with_bit(bits[0]) as i32,
369					)
370					.into()
371				},
372				3 => unsafe {
373					let bits: [u8; 8] =
374						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);
375
376					_mm_set_epi16(
377						u16::fill_with_bit(bits[7]) as i16,
378						u16::fill_with_bit(bits[6]) as i16,
379						u16::fill_with_bit(bits[5]) as i16,
380						u16::fill_with_bit(bits[4]) as i16,
381						u16::fill_with_bit(bits[3]) as i16,
382						u16::fill_with_bit(bits[2]) as i16,
383						u16::fill_with_bit(bits[1]) as i16,
384						u16::fill_with_bit(bits[0]) as i16,
385					)
386					.into()
387				},
388				4 => unsafe {
389					let bits: [u8; 16] =
390						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);
391
392					_mm_set_epi8(
393						u8::fill_with_bit(bits[15]) as i8,
394						u8::fill_with_bit(bits[14]) as i8,
395						u8::fill_with_bit(bits[13]) as i8,
396						u8::fill_with_bit(bits[12]) as i8,
397						u8::fill_with_bit(bits[11]) as i8,
398						u8::fill_with_bit(bits[10]) as i8,
399						u8::fill_with_bit(bits[9]) as i8,
400						u8::fill_with_bit(bits[8]) as i8,
401						u8::fill_with_bit(bits[7]) as i8,
402						u8::fill_with_bit(bits[6]) as i8,
403						u8::fill_with_bit(bits[5]) as i8,
404						u8::fill_with_bit(bits[4]) as i8,
405						u8::fill_with_bit(bits[3]) as i8,
406						u8::fill_with_bit(bits[2]) as i8,
407						u8::fill_with_bit(bits[1]) as i8,
408						u8::fill_with_bit(bits[0]) as i8,
409					)
410					.into()
411				},
412				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
413			},
414			1 => match log_block_len {
415				0 => unsafe {
416					let value =
417						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();
418
419					_mm_set1_epi8(value as i8).into()
420				},
421				1 => {
422					let bytes: [u8; 2] = array::from_fn(|i| {
423						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
424					});
425
426					Self::from_fn::<u8>(|i| bytes[i / 8])
427				}
428				2 => {
429					let bytes: [u8; 4] = array::from_fn(|i| {
430						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
431					});
432
433					Self::from_fn::<u8>(|i| bytes[i / 4])
434				}
435				3 => {
436					let bytes: [u8; 8] = array::from_fn(|i| {
437						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
438							.spread_to_byte()
439					});
440
441					Self::from_fn::<u8>(|i| bytes[i / 2])
442				}
443				4 => {
444					let bytes: [u8; 16] = array::from_fn(|i| {
445						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
446							.spread_to_byte()
447					});
448
449					Self::from_fn::<u8>(|i| bytes[i])
450				}
451				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
452			},
453			2 => match log_block_len {
454				0 => {
455					let value =
456						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();
457
458					unsafe { _mm_set1_epi8(value as i8).into() }
459				}
460				1 => {
461					let values: [u8; 2] = array::from_fn(|i| {
462						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
463					});
464
465					Self::from_fn::<u8>(|i| values[i / 8])
466				}
467				2 => {
468					let values: [u8; 4] = array::from_fn(|i| {
469						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
470							.spread_to_byte()
471					});
472
473					Self::from_fn::<u8>(|i| values[i / 4])
474				}
475				3 => {
476					let values: [u8; 8] = array::from_fn(|i| {
477						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
478							.spread_to_byte()
479					});
480
481					Self::from_fn::<u8>(|i| values[i / 2])
482				}
483				4 => {
484					let values: [u8; 16] = array::from_fn(|i| {
485						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
486							.spread_to_byte()
487					});
488
489					Self::from_fn::<u8>(|i| values[i])
490				}
491				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
492			},
493			3 => match log_block_len {
494				0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
495				1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
496				2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
497				3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
498				4 => self,
499				_ => panic!("unsupported block length"),
500			},
501			4 => match log_block_len {
502				0 => {
503					let value = (u128::from(self) >> (block_idx * 16)) as u16;
504
505					unsafe { _mm_set1_epi16(value as i16).into() }
506				}
507				1 => {
508					let values: [u16; 2] =
509						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);
510
511					Self::from_fn::<u16>(|i| values[i / 4])
512				}
513				2 => {
514					let values: [u16; 4] =
515						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);
516
517					Self::from_fn::<u16>(|i| values[i / 2])
518				}
519				3 => self,
520				_ => panic!("unsupported block length"),
521			},
522			5 => match log_block_len {
523				0 => unsafe {
524					let value = (u128::from(self) >> (block_idx * 32)) as u32;
525
526					_mm_set1_epi32(value as i32).into()
527				},
528				1 => {
529					let values: [u32; 2] =
530						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);
531
532					Self::from_fn::<u32>(|i| values[i / 2])
533				}
534				2 => self,
535				_ => panic!("unsupported block length"),
536			},
537			6 => match log_block_len {
538				0 => unsafe {
539					let value = (u128::from(self) >> (block_idx * 64)) as u64;
540
541					_mm_set1_epi64x(value as i64).into()
542				},
543				1 => self,
544				_ => panic!("unsupported block length"),
545			},
546			7 => self,
547			_ => panic!("unsupported bit length"),
548		}
549	}
550}
551
// SAFETY: `M128` is a `repr(transparent)` wrapper around `__m128i`, and the all-zero bit
// pattern is a valid value (it is what `Default` produces via `_mm_setzero_si128`).
unsafe impl Zeroable for M128 {}

// SAFETY: NOTE(review): relies on `__m128i` being plain 128-bit data with no padding and no
// invalid bit patterns; `M128` adds nothing on top of it (`repr(transparent)`, `Copy`).
unsafe impl Pod for M128 {}

// SAFETY: NOTE(review): `M128` holds only a by-value `__m128i` with no pointers or interior
// mutability visible in this file, so moving it across threads should be sound. Confirm.
unsafe impl Send for M128 {}

// SAFETY: NOTE(review): same reasoning as `Send`; shared references expose only immutable
// plain data. Confirm.
unsafe impl Sync for M128 {}
559
560static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
561static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
562static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
563static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
564
565const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
566	log_block_len: usize,
567	t_log_bits: usize,
568) -> [M128; BLOCK_IDX_AMOUNT] {
569	let element_log_width = t_log_bits - 3;
570
571	let element_width = 1 << element_log_width;
572
573	let block_size = 1 << (log_block_len + element_log_width);
574	let repeat = 1 << (4 - element_log_width - log_block_len);
575	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];
576
577	let mut block_idx = 0;
578
579	while block_idx < BLOCK_IDX_AMOUNT {
580		let base = block_idx * block_size;
581		let mut j = 0;
582		while j < 16 {
583			masks[block_idx][j] =
584				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
585			j += 1;
586		}
587		block_idx += 1;
588	}
589	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];
590
591	let mut block_idx = 0;
592
593	while block_idx < BLOCK_IDX_AMOUNT {
594		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
595		block_idx += 1;
596	}
597
598	m128_masks
599}
600
601impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
602	fn from(value: __m128i) -> Self {
603		M128::from(value).into()
604	}
605}
606
607impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
608	fn from(value: u128) -> Self {
609		M128::from(value).into()
610	}
611}
612
613impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
614	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
615		value.to_underlier().into()
616	}
617}
618
619#[inline]
620unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
621	match log_block_len {
622		0 => unsafe {
623			let mask = _mm_set1_epi8(0x55i8);
624			interleave_bits_imm::<1>(a, b, mask)
625		},
626		1 => unsafe {
627			let mask = _mm_set1_epi8(0x33i8);
628			interleave_bits_imm::<2>(a, b, mask)
629		},
630		2 => unsafe {
631			let mask = _mm_set1_epi8(0x0fi8);
632			interleave_bits_imm::<4>(a, b, mask)
633		},
634		3 => unsafe {
635			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
636			let a = _mm_shuffle_epi8(a, shuffle);
637			let b = _mm_shuffle_epi8(b, shuffle);
638			let a_prime = _mm_unpacklo_epi8(a, b);
639			let b_prime = _mm_unpackhi_epi8(a, b);
640			(a_prime, b_prime)
641		},
642		4 => unsafe {
643			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
644			let a = _mm_shuffle_epi8(a, shuffle);
645			let b = _mm_shuffle_epi8(b, shuffle);
646			let a_prime = _mm_unpacklo_epi16(a, b);
647			let b_prime = _mm_unpackhi_epi16(a, b);
648			(a_prime, b_prime)
649		},
650		5 => unsafe {
651			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
652			let a = _mm_shuffle_epi8(a, shuffle);
653			let b = _mm_shuffle_epi8(b, shuffle);
654			let a_prime = _mm_unpacklo_epi32(a, b);
655			let b_prime = _mm_unpackhi_epi32(a, b);
656			(a_prime, b_prime)
657		},
658		6 => unsafe {
659			let a_prime = _mm_unpacklo_epi64(a, b);
660			let b_prime = _mm_unpackhi_epi64(a, b);
661			(a_prime, b_prime)
662		},
663		_ => panic!("unsupported block length"),
664	}
665}
666
/// Exchanges sub-byte blocks between `a` and `b` with a masked XOR swap.
///
/// `mask` selects the even-indexed blocks of width `BLOCK_LEN` bits. The odd blocks of
/// `a` are swapped with the even blocks of `b`, so the first result holds the even
/// blocks of both inputs and the second result holds the odd blocks of both.
///
/// # Safety
///
/// Only SSE2 register-to-register intrinsics are used; SSE2 is part of the x86_64
/// baseline.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	unsafe {
		// Align the odd blocks of `a` with the even blocks of `b`.
		let a_down = _mm_srli_epi64::<BLOCK_LEN>(a);
		// Bits that differ between the two block sets; XOR-ing them in swaps the blocks.
		let delta = _mm_and_si128(_mm_xor_si128(a_down, b), mask);
		let delta_up = _mm_slli_epi64::<BLOCK_LEN>(delta);
		(_mm_xor_si128(a, delta_up), _mm_xor_si128(b, delta))
	}
}
680
681// Divisible implementations using SIMD extract/insert intrinsics
682
// Reflexive divisibility, needed when M128 is itself a field underlier (a width-1 packed field).
// NOTE(review): presumably expands to `impl Divisible<M128> for M128`; confirm against the
// macro definition in `crate::underlier`.
impl_divisible_self!(M128);
685
686impl Divisible<u128> for M128 {
687	const LOG_N: usize = 0;
688
689	#[inline]
690	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
691		std::iter::once(u128::from(value))
692	}
693
694	#[inline]
695	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
696		std::iter::once(u128::from(*value))
697	}
698
699	#[inline]
700	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
701		slice.iter().map(|&v| u128::from(v))
702	}
703
704	#[inline]
705	unsafe fn get_unchecked(&self, _index: usize) -> u128 {
706		u128::from(*self)
707	}
708
709	#[inline]
710	unsafe fn set_unchecked(&mut self, _index: usize, val: u128) {
711		*self = Self::from(val);
712	}
713
714	#[inline]
715	fn broadcast(val: u128) -> Self {
716		Self::from(val)
717	}
718
719	#[inline]
720	fn from_iter(mut iter: impl Iterator<Item = u128>) -> Self {
721		iter.next().map(Self::from).unwrap_or(Self::ZERO)
722	}
723}
724
725impl Divisible<u64> for M128 {
726	const LOG_N: usize = 1;
727
728	#[inline]
729	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
730		mapget::value_iter(value)
731	}
732
733	#[inline]
734	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
735		mapget::value_iter(*value)
736	}
737
738	#[inline]
739	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
740		mapget::slice_iter(slice)
741	}
742
743	#[inline]
744	unsafe fn get_unchecked(&self, index: usize) -> u64 {
745		unsafe {
746			match index {
747				0 => _mm_extract_epi64(self.0, 0) as u64,
748				1 => _mm_extract_epi64(self.0, 1) as u64,
749				_ => core::hint::unreachable_unchecked(),
750			}
751		}
752	}
753
754	#[inline]
755	unsafe fn set_unchecked(&mut self, index: usize, val: u64) {
756		*self = unsafe {
757			match index {
758				0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
759				1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
760				_ => core::hint::unreachable_unchecked(),
761			}
762		};
763	}
764
765	#[inline]
766	fn broadcast(val: u64) -> Self {
767		unsafe { Self(_mm_set1_epi64x(val as i64)) }
768	}
769
770	#[inline]
771	fn from_iter(iter: impl Iterator<Item = u64>) -> Self {
772		let mut result = Self::ZERO;
773		let arr: &mut [u64; 2] = bytemuck::cast_mut(&mut result);
774		for (i, val) in iter.take(2).enumerate() {
775			arr[i] = val;
776		}
777		result
778	}
779}
780
781impl Divisible<u32> for M128 {
782	const LOG_N: usize = 2;
783
784	#[inline]
785	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
786		mapget::value_iter(value)
787	}
788
789	#[inline]
790	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
791		mapget::value_iter(*value)
792	}
793
794	#[inline]
795	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
796		mapget::slice_iter(slice)
797	}
798
799	#[inline]
800	unsafe fn get_unchecked(&self, index: usize) -> u32 {
801		unsafe {
802			match index {
803				0 => _mm_extract_epi32(self.0, 0) as u32,
804				1 => _mm_extract_epi32(self.0, 1) as u32,
805				2 => _mm_extract_epi32(self.0, 2) as u32,
806				3 => _mm_extract_epi32(self.0, 3) as u32,
807				_ => core::hint::unreachable_unchecked(),
808			}
809		}
810	}
811
812	#[inline]
813	unsafe fn set_unchecked(&mut self, index: usize, val: u32) {
814		*self = unsafe {
815			match index {
816				0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
817				1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
818				2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
819				3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
820				_ => core::hint::unreachable_unchecked(),
821			}
822		};
823	}
824
825	#[inline]
826	fn broadcast(val: u32) -> Self {
827		unsafe { Self(_mm_set1_epi32(val as i32)) }
828	}
829
830	#[inline]
831	fn from_iter(iter: impl Iterator<Item = u32>) -> Self {
832		let mut result = Self::ZERO;
833		let arr: &mut [u32; 4] = bytemuck::cast_mut(&mut result);
834		for (i, val) in iter.take(4).enumerate() {
835			arr[i] = val;
836		}
837		result
838	}
839}
840
841impl Divisible<u16> for M128 {
842	const LOG_N: usize = 3;
843
844	#[inline]
845	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
846		mapget::value_iter(value)
847	}
848
849	#[inline]
850	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
851		mapget::value_iter(*value)
852	}
853
854	#[inline]
855	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
856		mapget::slice_iter(slice)
857	}
858
859	#[inline]
860	unsafe fn get_unchecked(&self, index: usize) -> u16 {
861		unsafe {
862			match index {
863				0 => _mm_extract_epi16(self.0, 0) as u16,
864				1 => _mm_extract_epi16(self.0, 1) as u16,
865				2 => _mm_extract_epi16(self.0, 2) as u16,
866				3 => _mm_extract_epi16(self.0, 3) as u16,
867				4 => _mm_extract_epi16(self.0, 4) as u16,
868				5 => _mm_extract_epi16(self.0, 5) as u16,
869				6 => _mm_extract_epi16(self.0, 6) as u16,
870				7 => _mm_extract_epi16(self.0, 7) as u16,
871				_ => core::hint::unreachable_unchecked(),
872			}
873		}
874	}
875
876	#[inline]
877	unsafe fn set_unchecked(&mut self, index: usize, val: u16) {
878		*self = unsafe {
879			match index {
880				0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
881				1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
882				2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
883				3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
884				4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
885				5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
886				6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
887				7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
888				_ => core::hint::unreachable_unchecked(),
889			}
890		};
891	}
892
893	#[inline]
894	fn broadcast(val: u16) -> Self {
895		unsafe { Self(_mm_set1_epi16(val as i16)) }
896	}
897
898	#[inline]
899	fn from_iter(iter: impl Iterator<Item = u16>) -> Self {
900		let mut result = Self::ZERO;
901		let arr: &mut [u16; 8] = bytemuck::cast_mut(&mut result);
902		for (i, val) in iter.take(8).enumerate() {
903			arr[i] = val;
904		}
905		result
906	}
907}
908
909impl Divisible<u8> for M128 {
910	const LOG_N: usize = 4;
911
912	#[inline]
913	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
914		mapget::value_iter(value)
915	}
916
917	#[inline]
918	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
919		mapget::value_iter(*value)
920	}
921
922	#[inline]
923	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
924		mapget::slice_iter(slice)
925	}
926
927	#[inline]
928	unsafe fn get_unchecked(&self, index: usize) -> u8 {
929		unsafe {
930			match index {
931				0 => _mm_extract_epi8(self.0, 0) as u8,
932				1 => _mm_extract_epi8(self.0, 1) as u8,
933				2 => _mm_extract_epi8(self.0, 2) as u8,
934				3 => _mm_extract_epi8(self.0, 3) as u8,
935				4 => _mm_extract_epi8(self.0, 4) as u8,
936				5 => _mm_extract_epi8(self.0, 5) as u8,
937				6 => _mm_extract_epi8(self.0, 6) as u8,
938				7 => _mm_extract_epi8(self.0, 7) as u8,
939				8 => _mm_extract_epi8(self.0, 8) as u8,
940				9 => _mm_extract_epi8(self.0, 9) as u8,
941				10 => _mm_extract_epi8(self.0, 10) as u8,
942				11 => _mm_extract_epi8(self.0, 11) as u8,
943				12 => _mm_extract_epi8(self.0, 12) as u8,
944				13 => _mm_extract_epi8(self.0, 13) as u8,
945				14 => _mm_extract_epi8(self.0, 14) as u8,
946				15 => _mm_extract_epi8(self.0, 15) as u8,
947				_ => core::hint::unreachable_unchecked(),
948			}
949		}
950	}
951
952	#[inline]
953	unsafe fn set_unchecked(&mut self, index: usize, val: u8) {
954		*self = unsafe {
955			match index {
956				0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
957				1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
958				2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
959				3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
960				4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
961				5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
962				6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
963				7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
964				8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
965				9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
966				10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
967				11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
968				12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
969				13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
970				14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
971				15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
972				_ => core::hint::unreachable_unchecked(),
973			}
974		};
975	}
976
977	#[inline]
978	fn broadcast(val: u8) -> Self {
979		unsafe { Self(_mm_set1_epi8(val as i8)) }
980	}
981
982	#[inline]
983	fn from_iter(iter: impl Iterator<Item = u8>) -> Self {
984		let mut result = Self::ZERO;
985		let arr: &mut [u8; 16] = bytemuck::cast_mut(&mut result);
986		for (i, val) in iter.take(16).enumerate() {
987			arr[i] = val;
988		}
989		result
990	}
991}
992
#[cfg(test)]
mod tests {
	use binius_utils::bytes::BytesMut;
	use proptest::{arbitrary::any, proptest};
	use rand::prelude::*;

	use super::*;

	/// Asserts that converting `val` to `T` and back yields the same `M128`.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	/// Extracts block `index` of `2^log_block_len` bits from `value`, moved down to bit 0.
	///
	/// The mask keeps the low `2^log_block_len` bits so that the whole block is compared,
	/// not just a single bit of it.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		let block_bits = 1usize << log_block_len;
		// `u128::MAX >> (128 - n)` has exactly the low `n` bits set and stays in range for
		// every `n` in 1..=128 (unlike `(1 << n) - 1`, which overflows at `n == 128`).
		let block_mask = u128::MAX >> (128 - block_bits);

		(value >> (index << log_block_len)) & M128::from(block_mask)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			// For every even block index `i`: `c` holds blocks `i` of `a` and `b`,
			// `d` holds blocks `i + 1` of `a` and `b`.
			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}

	#[test]
	fn test_serialize_and_deserialize_m128() {
		let mut rng = StdRng::from_seed([0; 32]);

		let original_value = M128::from(rng.random::<u128>());

		let mut buf = BytesMut::new();
		original_value.serialize(&mut buf).unwrap();

		let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

		assert_eq!(original_value, deserialized_value);
	}
}