// binius_field/arch/x86_64/m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{
16	Rng,
17	distr::{Distribution, StandardUniform},
18};
19use seq_macro::seq;
20
21use crate::{
22	BinaryField,
23	arch::{
24		binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
25		portable::{
26			packed::{PackedPrimitiveType, impl_pack_scalar},
27			packed_arithmetic::{
28				UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
29			},
30		},
31	},
32	arithmetic_traits::Broadcast,
33	underlier::{
34		NumCast, SmallU, SpreadToByte, U1, U2, U4, UnderlierType, UnderlierWithBitOps,
35		WithUnderlier, impl_divisible, impl_iteration, spread_fallback, unpack_hi_128b_fallback,
36		unpack_lo_128b_fallback,
37	},
38};
39
/// 128-bit value that is used for 128-bit SIMD operations
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);

impl M128 {
	/// Const-context constructor from a `u128`.
	///
	/// `From<u128>` is not callable in const contexts, so the bit pattern is
	/// copied with `transmute_copy` instead.
	#[inline(always)]
	pub const fn from_u128(val: u128) -> Self {
		let mut result = Self::ZERO;
		// SAFETY: `u128` and `__m128i` are both plain 16-byte values without
		// padding, so copying the bit pattern between them is sound.
		unsafe {
			result.0 = std::mem::transmute_copy(&val);
		}

		result
	}
}
56
57impl From<__m128i> for M128 {
58	#[inline(always)]
59	fn from(value: __m128i) -> Self {
60		Self(value)
61	}
62}
63
64impl From<u128> for M128 {
65	fn from(value: u128) -> Self {
66		Self(unsafe {
67			// Safety: u128 is 16-byte aligned
68			assert_eq!(align_of::<u128>(), 16);
69			_mm_load_si128(&raw const value as *const __m128i)
70		})
71	}
72}
73
74impl From<u64> for M128 {
75	fn from(value: u64) -> Self {
76		Self::from(value as u128)
77	}
78}
79
80impl From<u32> for M128 {
81	fn from(value: u32) -> Self {
82		Self::from(value as u128)
83	}
84}
85
86impl From<u16> for M128 {
87	fn from(value: u16) -> Self {
88		Self::from(value as u128)
89	}
90}
91
92impl From<u8> for M128 {
93	fn from(value: u8) -> Self {
94		Self::from(value as u128)
95	}
96}
97
98impl<const N: usize> From<SmallU<N>> for M128 {
99	fn from(value: SmallU<N>) -> Self {
100		Self::from(value.val() as u128)
101	}
102}
103
104impl From<M128> for u128 {
105	fn from(value: M128) -> Self {
106		let mut result = 0u128;
107		unsafe {
108			// Safety: u128 is 16-byte aligned
109			assert_eq!(align_of::<u128>(), 16);
110			_mm_store_si128(&raw mut result as *mut __m128i, value.0)
111		};
112		result
113	}
114}
115
116impl From<M128> for __m128i {
117	#[inline(always)]
118	fn from(value: M128) -> Self {
119		value.0
120	}
121}
122
123impl SerializeBytes for M128 {
124	fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
125		assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
126
127		let raw_value: u128 = (*self).into();
128
129		write_buf.put_u128_le(raw_value);
130		Ok(())
131	}
132}
133
134impl DeserializeBytes for M128 {
135	fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
136	where
137		Self: Sized,
138	{
139		assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
140
141		let raw_value = read_buf.get_u128_le();
142
143		Ok(Self::from(raw_value))
144	}
145}
146
// NOTE(review): these macros are defined in `underlier` / `portable::packed`;
// presumably they generate the impls for viewing `M128` as arrays of narrower
// unsigned integers and the pack-scalar plumbing — confirm against their
// definitions.
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);
149
150impl<U: NumCast<u128>> NumCast<M128> for U {
151	#[inline(always)]
152	fn num_cast_from(val: M128) -> Self {
153		Self::num_cast_from(u128::from(val))
154	}
155}
156
157impl Default for M128 {
158	#[inline(always)]
159	fn default() -> Self {
160		Self(unsafe { _mm_setzero_si128() })
161	}
162}
163
164impl BitAnd for M128 {
165	type Output = Self;
166
167	#[inline(always)]
168	fn bitand(self, rhs: Self) -> Self::Output {
169		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
170	}
171}
172
173impl BitAndAssign for M128 {
174	#[inline(always)]
175	fn bitand_assign(&mut self, rhs: Self) {
176		*self = *self & rhs
177	}
178}
179
180impl BitOr for M128 {
181	type Output = Self;
182
183	#[inline(always)]
184	fn bitor(self, rhs: Self) -> Self::Output {
185		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
186	}
187}
188
189impl BitOrAssign for M128 {
190	#[inline(always)]
191	fn bitor_assign(&mut self, rhs: Self) {
192		*self = *self | rhs
193	}
194}
195
196impl BitXor for M128 {
197	type Output = Self;
198
199	#[inline(always)]
200	fn bitxor(self, rhs: Self) -> Self::Output {
201		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
202	}
203}
204
205impl BitXorAssign for M128 {
206	#[inline(always)]
207	fn bitxor_assign(&mut self, rhs: Self) {
208		*self = *self ^ rhs;
209	}
210}
211
212impl Not for M128 {
213	type Output = Self;
214
215	fn not(self) -> Self::Output {
216		const ONES: __m128i = m128_from_u128!(u128::MAX);
217
218		self ^ Self(ONES)
219	}
220}
221
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	if right > left { right } else { left }
}
226
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed
/// as constant and Rust currently doesn't allow passing expressions (`count - 64`) where variable
/// is a generic constant parameter. Source: <https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688>
///
/// Expands to a full 128-bit shift built from per-64-bit-lane shifts plus a
/// cross-lane byte shift: `$byte_shift` moves the carry 8 bytes across the lane
/// boundary, `$bit_shift_64` is the main per-lane shift, and
/// `$bit_shift_64_opposite` aligns the carried bits before `$or` merges them.
/// `seq!` unrolls every shift amount so the intrinsics get compile-time counts.
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			let carry = $byte_shift($val, 8);
			// Shift amounts 64..=127: the result comes entirely from the
			// carried (byte-shifted) half.
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			// Shift amounts 0..=63: combine the per-lane shift with the
			// oppositely-shifted carry.
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// Shift amounts >= 128 fall through to the zero default.
			return Default::default()
		}
	};
}

pub(crate) use bitshift_128b;
261
impl Shr<usize> for M128 {
	type Output = Self;

	/// Logical right shift of the full 128-bit value; `rhs >= 128` yields zero.
	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
272
impl Shl<usize> for M128 {
	type Output = Self;

	/// Logical left shift of the full 128-bit value; `rhs >= 128` yields zero.
	#[inline(always)]
	fn shl(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
	}
}
283
impl PartialEq for M128 {
	fn eq(&self, other: &Self) -> bool {
		unsafe {
			// XOR is non-zero iff some bit differs; `_mm_test_all_zeros`
			// (SSE4.1 `ptest`) then checks all 128 bits in one instruction.
			// NOTE(review): this requires SSE4.1 — presumably guaranteed by the
			// crate's target-feature configuration; confirm.
			let neq = _mm_xor_si128(self.0, other.0);
			_mm_test_all_zeros(neq, neq) == 1
		}
	}
}

impl Eq for M128 {}
294
295impl PartialOrd for M128 {
296	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
297		Some(self.cmp(other))
298	}
299}
300
301impl Ord for M128 {
302	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
303		u128::from(*self).cmp(&u128::from(*other))
304	}
305}
306
impl Distribution<M128> for StandardUniform {
	// Samples a uniformly random 128-bit vector by sampling the inner register.
	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
		M128(rng.random())
	}
}
312
313impl std::fmt::Display for M128 {
314	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315		let data: u128 = (*self).into();
316		write!(f, "{data:02X?}")
317	}
318}
319
320impl std::fmt::Debug for M128 {
321	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322		write!(f, "M128({self})")
323	}
324}
325
/// A single `u128` forced to 16-byte alignment so it can be read as an
/// aligned `__m128i` (see [`m128_from_u128`]).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

/// Builds an `__m128i` from a `u128` expression in a const-friendly way by
/// materializing the value at a 16-byte-aligned address and reading it back.
macro_rules! m128_from_u128 {
	($val:expr) => {{
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		// SAFETY: `AlignedData` is `#[repr(align(16))]`, so the pointer is
		// valid and sufficiently aligned for an aligned 16-byte read.
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}

pub(super) use m128_from_u128;
337
impl UnderlierType for M128 {
	// 2^7 = 128 bits.
	const LOG_BITS: usize = 7;
}
341
impl UnderlierWithBitOps for M128 {
	// Built via `m128_from_u128!` so the constants are usable in const contexts.
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };
346
	/// Returns all-zeros for `val == 0` and all-ones for `val == 1`.
	#[inline(always)]
	fn fill_with_bit(val: u8) -> Self {
		assert!(val == 0 || val == 1);
		// 0u8.wrapping_neg() == 0x00 and 1u8.wrapping_neg() == 0xFF, so every
		// byte (hence every bit) of the vector becomes a copy of `val`.
		Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
	}
352
	/// Builds the vector element-wise from `f`, where element `i` (little-endian
	/// order) is `f(i)` and the element width is `T::BITS`.
	#[inline(always)]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		match T::BITS {
			// Sub-byte elements: `make_func_to_i8` packs several `T` values
			// into each produced byte.
			1 | 2 | 4 => {
				let mut f = make_func_to_i8::<T, Self>(f);

				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			// Note: `_mm_set_epi*` take arguments from most- to least-significant
			// lane, hence the descending call order.
			8 => {
				let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			16 => {
				let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
				unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
			}
			32 => {
				let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
				unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
			}
			64 => {
				let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
				unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
			}
			128 => Self::from(f(0)),
			_ => panic!("unsupported bit count"),
		}
	}
425
	/// Extracts the `i`-th `T`-wide element (little-endian element order).
	///
	/// Indexing below is unchecked, so callers must keep `i` within
	/// `128 / T::BITS` — hence the `unsafe fn` contract.
	#[inline(always)]
	unsafe fn get_subvalue<T>(&self, i: usize) -> T
	where
		T: UnderlierType + NumCast<Self>,
	{
		match T::BITS {
			// Sub-byte elements: fetch the containing byte, then shift the
			// wanted element down to the low bits.
			1 | 2 | 4 => {
				let elements_in_8 = 8 / T::BITS;
				let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
					*arr.get_unchecked(i / elements_in_8)
				});

				let shift = (i % elements_in_8) * T::BITS;
				value_u8 >>= shift;

				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			8 => {
				let value_u8 =
					as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			16 => {
				let value_u16 =
					as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u16)))
			}
			32 => {
				let value_u32 =
					as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u32)))
			}
			64 => {
				let value_u64 =
					as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u64)))
			}
			128 => T::from_underlier(T::num_cast_from(*self)),
			_ => panic!("unsupported bit count"),
		}
	}
467
	/// Overwrites the `i`-th `T`-wide element (little-endian element order)
	/// with `val`, leaving all other elements unchanged.
	///
	/// Indexing below is unchecked, so callers must keep `i` within
	/// `128 / T::BITS` — hence the `unsafe fn` contract.
	#[inline(always)]
	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
	where
		T: UnderlierWithBitOps,
		Self: From<T>,
	{
		match T::BITS {
			// Sub-byte elements: clear the target bit range inside its byte
			// with a mask, then OR the shifted new value in.
			1 | 2 | 4 => {
				let elements_in_8 = 8 / T::BITS;
				let mask = (1u8 << T::BITS) - 1;
				let shift = (i % elements_in_8) * T::BITS;
				let val = u8::num_cast_from(Self::from(val)) << shift;
				let mask = mask << shift;

				as_array_mut::<_, u8, 16>(self, |array| unsafe {
					let element = array.get_unchecked_mut(i / elements_in_8);
					*element &= !mask;
					*element |= val;
				});
			}
			8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
			}),
			16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
			}),
			32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
			}),
			64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
			}),
			128 => {
				*self = Self::from(val);
			}
			_ => panic!("unsupported bit count"),
		}
	}
506
	/// Spreads the `block_idx`-th block of `2^log_block_len` `T`-wide elements
	/// so that each element of the block is repeated to fill the whole vector.
	/// Dispatches on `T::LOG_BITS` (element width) and `log_block_len`.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: From<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: each selected bit is smeared across a 128/2^k-bit lane.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => unsafe {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				},
				2 => unsafe {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				},
				3 => unsafe {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				},
				4 => unsafe {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				},
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 2-bit elements (`U2`): expand each element to a full byte, then
			// repeat the bytes across the vector.
			1 => match log_block_len {
				0 => unsafe {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				},
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 4-bit elements (`U4`): same strategy as `U2`.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					unsafe { _mm_set1_epi8(value as i8).into() }
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 8-bit elements: a single byte shuffle with a precomputed mask.
			3 => match log_block_len {
				0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
				1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
				2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
				3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					unsafe { _mm_set1_epi16(value as i16).into() }
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				},
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				},
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// A 128-bit element already fills the vector.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}
716
	/// Left shift within the single 128-bit lane; delegates to `Shl`.
	#[inline]
	fn shl_128b_lanes(self, shift: usize) -> Self {
		self << shift
	}

	/// Right shift within the single 128-bit lane; delegates to `Shr`.
	#[inline]
	fn shr_128b_lanes(self, shift: usize) -> Self {
		self >> shift
	}
726
	/// Interleaves the low-half blocks of `self` and `other`; native unpack
	/// intrinsics handle byte-and-wider blocks, sub-byte blocks use the fallback.
	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}

	/// Interleaves the high-half blocks of `self` and `other`; mirror of
	/// [`Self::unpack_lo_128b_lanes`].
	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}
750}
751
// SAFETY: `M128` is `#[repr(transparent)]` over `__m128i`; the all-zero bit
// pattern is a valid value.
unsafe impl Zeroable for M128 {}

// SAFETY: `M128` is a `Copy` 16-byte value with no padding or pointer data.
unsafe impl Pod for M128 {}

// SAFETY: `M128` holds no references or interior mutability.
// NOTE(review): `__m128i` already auto-derives `Send`/`Sync`, so these impls
// are likely redundant — confirm before removing.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
759
// Precomputed `_mm_shuffle_epi8` control masks used by `spread` for 8-bit
// elements (`T::LOG_BITS == 3`): one table per supported `log_block_len`,
// indexed by `block_idx`.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
764
/// Builds, at compile time, the byte-permutation masks used with
/// `_mm_shuffle_epi8` by `spread`: for each `block_idx`, a mask that repeats
/// every element of the selected block across the whole 16-byte vector.
///
/// `t_log_bits` is the log2 bit-width of the element type; only byte-and-wider
/// elements are meaningful here since the shuffle operates on whole bytes.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	let element_log_width = t_log_bits - 3;

	let element_width = 1 << element_log_width;

	let block_size = 1 << (log_block_len + element_log_width);
	// How many times each element is repeated in the output.
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	// `while` loops because `for` is not available in const fns.
	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
800
801impl UnderlierWithBitConstants for M128 {
802	const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
803		Self::from_u128(interleave_mask_even!(u128, 0)),
804		Self::from_u128(interleave_mask_even!(u128, 1)),
805		Self::from_u128(interleave_mask_even!(u128, 2)),
806		Self::from_u128(interleave_mask_even!(u128, 3)),
807		Self::from_u128(interleave_mask_even!(u128, 4)),
808		Self::from_u128(interleave_mask_even!(u128, 5)),
809		Self::from_u128(interleave_mask_even!(u128, 6)),
810	];
811
812	const INTERLEAVE_ODD_MASK: &'static [Self] = &[
813		Self::from_u128(interleave_mask_odd!(u128, 0)),
814		Self::from_u128(interleave_mask_odd!(u128, 1)),
815		Self::from_u128(interleave_mask_odd!(u128, 2)),
816		Self::from_u128(interleave_mask_odd!(u128, 3)),
817		Self::from_u128(interleave_mask_odd!(u128, 4)),
818		Self::from_u128(interleave_mask_odd!(u128, 5)),
819		Self::from_u128(interleave_mask_odd!(u128, 6)),
820	];
821
822	#[inline(always)]
823	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
824		unsafe {
825			let (c, d) = interleave_bits(
826				Into::<Self>::into(self).into(),
827				Into::<Self>::into(other).into(),
828				log_block_len,
829			);
830			(Self::from(c), Self::from(d))
831		}
832	}
833}
834
835impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
836	fn from(value: __m128i) -> Self {
837		M128::from(value).into()
838	}
839}
840
841impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
842	fn from(value: u128) -> Self {
843		M128::from(value).into()
844	}
845}
846
847impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
848	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
849		value.to_underlier().into()
850	}
851}
852
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
	u128: From<Scalar::Underlier>,
{
	/// Repeats `scalar` across every lane of the 128-bit vector. The tower
	/// level (log2 of the scalar bit-width) selects the broadcast intrinsic.
	#[inline(always)]
	fn broadcast(scalar: Scalar) -> Self {
		let tower_level = Scalar::N_BITS.ilog2() as usize;
		match tower_level {
			// 1/2/4/8-bit scalars: double the bit pattern's width until it
			// fills a byte, then splat that byte into all 16 lanes.
			0..=3 => {
				let mut value = u128::from(scalar.to_underlier()) as u8;
				for n in tower_level..3 {
					value |= value << (1 << n);
				}

				unsafe { _mm_set1_epi8(value as i8) }.into()
			}
			4 => {
				let value = u128::from(scalar.to_underlier()) as u16;
				unsafe { _mm_set1_epi16(value as i16) }.into()
			}
			5 => {
				let value = u128::from(scalar.to_underlier()) as u32;
				unsafe { _mm_set1_epi32(value as i32) }.into()
			}
			6 => {
				let value = u128::from(scalar.to_underlier()) as u64;
				unsafe { _mm_set1_epi64x(value as i64) }.into()
			}
			// A 128-bit scalar already fills the vector.
			7 => {
				let value = u128::from(scalar.to_underlier());
				value.into()
			}
			_ => {
				unreachable!("invalid tower level")
			}
		}
	}
}
891
/// Interleaves `2^log_block_len`-bit blocks of `a` and `b`: the first output
/// holds the even-indexed blocks of both inputs alternating, the second output
/// the odd-indexed blocks. Panics for `log_block_len > 6`.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		// Blocks of 1, 2, 4 bits: masked swap within 64-bit lanes.
		0 => unsafe {
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		},
		1 => unsafe {
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		},
		2 => unsafe {
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		},
		// Blocks of 8/16/32 bits: gather even blocks into the low half and odd
		// blocks into the high half with a byte shuffle, then unpack.
		3 => unsafe {
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		},
		4 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		},
		5 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		},
		// 64-bit blocks: a single unpack suffices.
		6 => unsafe {
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		},
		_ => panic!("unsupported block length"),
	}
}
939
/// Masked swap of `BLOCK_LEN`-bit blocks within 64-bit lanes:
/// `t` marks positions where the high blocks of `a` differ from the low blocks
/// of `b` (under `mask`), and XOR-ing `t` back into both operands exchanges
/// exactly those blocks.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	unsafe {
		let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
		let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
		let b_prime = _mm_xor_si128(b, t);
		(a_prime, b_prime)
	}
}
953
// Wires up scalar iteration over `M128`: bit-level iteration for `U1`, a
// generic fallback for `U2`/`U4`, and equal-width chunking for byte-and-wider
// types. NOTE(review): strategy semantics come from `impl_iteration!` in
// `underlier` — confirm against its definition.
impl_iteration!(M128,
	@strategy BitIterationStrategy, U1,
	@strategy FallbackStrategy, U2, U4,
	@strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
959
960#[cfg(test)]
961mod tests {
962	use binius_utils::bytes::BytesMut;
963	use proptest::{arbitrary::any, proptest};
964	use rand::{SeedableRng, rngs::StdRng};
965
966	use super::*;
967	use crate::underlier::single_element_mask_bits;
968
969	fn check_roundtrip<T>(val: M128)
970	where
971		T: From<M128>,
972		M128: From<T>,
973	{
974		assert_eq!(M128::from(T::from(val)), val);
975	}
976
977	#[test]
978	fn test_constants() {
979		assert_eq!(M128::default(), M128::ZERO);
980		assert_eq!(M128::from(0u128), M128::ZERO);
981		assert_eq!(M128::from(1u128), M128::ONE);
982	}
983
984	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
985		(value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
986	}
987
	// Property-based tests checking the SIMD implementation against plain
	// `u128` arithmetic as the reference model.
	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		// Even-indexed blocks must land in `c`, odd-indexed blocks in `d`.
		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}

		#[test]
		fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_lo_128b_lanes(b, height);
			for i in 0..128>>(height + 1) {
				assert_eq!(get(result, height, 2*i), get(a, height, i));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i));
			}
		}

		#[test]
		fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_hi_128b_lanes(b, height);
			let half_block_count = 128>>(height + 1);
			for i in 0..half_block_count {
				assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
			}
		}
	}
1054
1055	#[test]
1056	fn test_fill_with_bit() {
1057		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
1058		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
1059	}
1060
1061	#[test]
1062	fn test_eq() {
1063		let a = M128::from(0u128);
1064		let b = M128::from(42u128);
1065		let c = M128::from(u128::MAX);
1066
1067		assert_eq!(a, a);
1068		assert_eq!(b, b);
1069		assert_eq!(c, c);
1070
1071		assert_ne!(a, b);
1072		assert_ne!(a, c);
1073		assert_ne!(b, c);
1074	}
1075
1076	#[test]
1077	fn test_serialize_and_deserialize_m128() {
1078		let mut rng = StdRng::from_seed([0; 32]);
1079
1080		let original_value = M128::from(rng.random::<u128>());
1081
1082		let mut buf = BytesMut::new();
1083		original_value.serialize(&mut buf).unwrap();
1084
1085		let deserialized_value = M128::deserialize(buf.freeze()).unwrap();
1086
1087		assert_eq!(original_value, deserialized_value);
1088	}
1089}