binius_field/arch/x86_64/m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use bytemuck::{must_cast, Pod, Zeroable};
10use rand::{Rng, RngCore};
11use seq_macro::seq;
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
13
14use crate::{
15	arch::{
16		binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
17		portable::{
18			packed::{impl_pack_scalar, PackedPrimitiveType},
19			packed_arithmetic::{
20				interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants,
21			},
22		},
23	},
24	arithmetic_traits::Broadcast,
25	underlier::{
26		impl_divisible, impl_iteration, spread_fallback, unpack_hi_128b_fallback,
27		unpack_lo_128b_fallback, NumCast, Random, SmallU, SpreadToByte, UnderlierType,
28		UnderlierWithBitOps, WithUnderlier, U1, U2, U4,
29	},
30	BinaryField,
31};
32
/// 128-bit value that is used for 128-bit SIMD operations
///
/// A `#[repr(transparent)]` wrapper around the raw SSE register type `__m128i`,
/// so trait impls (bit ops, comparisons, `UnderlierType`, …) can be attached to it.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
37
38impl M128 {
39	#[inline(always)]
40	pub const fn from_u128(val: u128) -> Self {
41		let mut result = Self::ZERO;
42		unsafe {
43			result.0 = std::mem::transmute_copy(&val);
44		}
45
46		result
47	}
48}
49
50impl From<__m128i> for M128 {
51	#[inline(always)]
52	fn from(value: __m128i) -> Self {
53		Self(value)
54	}
55}
56
impl From<u128> for M128 {
	fn from(value: u128) -> Self {
		// SAFETY: `value` occupies 16 bytes, so reading a full `__m128i` from its
		// address is in-bounds; `_mm_loadu_si128` has no alignment requirement.
		Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
	}
}
62
63impl From<u64> for M128 {
64	fn from(value: u64) -> Self {
65		Self::from(value as u128)
66	}
67}
68
69impl From<u32> for M128 {
70	fn from(value: u32) -> Self {
71		Self::from(value as u128)
72	}
73}
74
75impl From<u16> for M128 {
76	fn from(value: u16) -> Self {
77		Self::from(value as u128)
78	}
79}
80
81impl From<u8> for M128 {
82	fn from(value: u8) -> Self {
83		Self::from(value as u128)
84	}
85}
86
87impl<const N: usize> From<SmallU<N>> for M128 {
88	fn from(value: SmallU<N>) -> Self {
89		Self::from(value.val() as u128)
90	}
91}
92
impl From<M128> for u128 {
	fn from(value: M128) -> Self {
		let mut result = 0u128;
		// SAFETY: `result` occupies 16 bytes, so the full unaligned 128-bit store
		// stays in-bounds; `_mm_storeu_si128` has no alignment requirement.
		unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };

		result
	}
}
101
102impl From<M128> for __m128i {
103	#[inline(always)]
104	fn from(value: M128) -> Self {
105		value.0
106	}
107}
108
// Let `M128` be split into / reassembled from the smaller unsigned integer types,
// and let it serve as the underlier of packed scalar field types.
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);
111
112impl<U: NumCast<u128>> NumCast<M128> for U {
113	#[inline(always)]
114	fn num_cast_from(val: M128) -> Self {
115		Self::num_cast_from(u128::from(val))
116	}
117}
118
119impl Default for M128 {
120	#[inline(always)]
121	fn default() -> Self {
122		Self(unsafe { _mm_setzero_si128() })
123	}
124}
125
126impl BitAnd for M128 {
127	type Output = Self;
128
129	#[inline(always)]
130	fn bitand(self, rhs: Self) -> Self::Output {
131		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
132	}
133}
134
135impl BitAndAssign for M128 {
136	#[inline(always)]
137	fn bitand_assign(&mut self, rhs: Self) {
138		*self = *self & rhs
139	}
140}
141
142impl BitOr for M128 {
143	type Output = Self;
144
145	#[inline(always)]
146	fn bitor(self, rhs: Self) -> Self::Output {
147		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
148	}
149}
150
151impl BitOrAssign for M128 {
152	#[inline(always)]
153	fn bitor_assign(&mut self, rhs: Self) {
154		*self = *self | rhs
155	}
156}
157
158impl BitXor for M128 {
159	type Output = Self;
160
161	#[inline(always)]
162	fn bitxor(self, rhs: Self) -> Self::Output {
163		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
164	}
165}
166
167impl BitXorAssign for M128 {
168	#[inline(always)]
169	fn bitxor_assign(&mut self, rhs: Self) {
170		*self = *self ^ rhs;
171	}
172}
173
174impl Not for M128 {
175	type Output = Self;
176
177	fn not(self) -> Self::Output {
178		const ONES: __m128i = m128_from_u128!(u128::MAX);
179
180		self ^ Self(ONES)
181	}
182}
183
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	// Equivalent formulation: take `right` only when it is strictly greater.
	if right > left {
		right
	} else {
		left
	}
}
192
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed as constant
/// and Rust currently doesn't allow passing expressions (`count - 64`) where variable is a generic constant parameter.
/// Source: https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688
///
/// How it works: a full 128-bit shift is assembled from an 8-byte (64-bit) byte
/// shift producing a "carry" register, plus per-64-bit-lane bit shifts. For shift
/// amounts >= 64 only the carry contributes; for amounts < 64 the carry supplies
/// the bits crossing the lane boundary and is OR-ed with the lane-shifted value.
/// `seq!` expands a branch per concrete amount 0..128 so each intrinsic receives a
/// compile-time constant; amounts >= 128 fall through to `Default` (zero).
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			let carry = $byte_shift($val, 8);
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			return Default::default()
		}
	};
}

pub(crate) use bitshift_128b;
227
impl Shr<usize> for M128 {
	type Output = Self;

	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		// Logical (zero-filling) right shift of the full 128 bits; `rhs >= 128` yields zero.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
238
239impl Shl<usize> for M128 {
240	type Output = Self;
241
242	#[inline(always)]
243	fn shl(self, rhs: usize) -> Self::Output {
244		// This implementation is effective when `rhs` is known at compile-time.
245		// In our code this is always the case.
246		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
247	}
248}
249
impl PartialEq for M128 {
	fn eq(&self, other: &Self) -> bool {
		unsafe {
			// XOR is all-zero exactly when the two values agree in every bit;
			// `_mm_test_all_zeros` (SSE4.1) reduces that to a single 0/1 flag.
			let neq = _mm_xor_si128(self.0, other.0);
			_mm_test_all_zeros(neq, neq) == 1
		}
	}
}

impl Eq for M128 {}
260
261impl PartialOrd for M128 {
262	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
263		Some(self.cmp(other))
264	}
265}
266
267impl Ord for M128 {
268	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
269		u128::from(*self).cmp(&u128::from(*other))
270	}
271}
272
impl ConstantTimeEq for M128 {
	fn ct_eq(&self, other: &Self) -> Choice {
		unsafe {
			// Same reduction as `PartialEq::eq`: XOR then `_mm_test_all_zeros`
			// (SSE4.1), which yields 0/1 without data-dependent branches, feeding
			// `Choice` directly.
			let neq = _mm_xor_si128(self.0, other.0);
			Choice::from(_mm_test_all_zeros(neq, neq) as u8)
		}
	}
}
281
impl ConditionallySelectable for M128 {
	fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
		// Delegate to the constant-time `u128` implementation and convert back.
		ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
	}
}
287
288impl Random for M128 {
289	fn random(mut rng: impl RngCore) -> Self {
290		let val: u128 = rng.gen();
291		val.into()
292	}
293}
294
295impl std::fmt::Display for M128 {
296	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
297		let data: u128 = (*self).into();
298		write!(f, "{data:02X?}")
299	}
300}
301
302impl std::fmt::Debug for M128 {
303	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
304		write!(f, "M128({})", self)
305	}
306}
307
/// 16-byte-aligned storage used to materialize a `u128` constant as a `__m128i`.
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

/// Builds a `__m128i` from a `u128` expression; works in `const` contexts because
/// it only takes a pointer to properly aligned data and dereferences it.
macro_rules! m128_from_u128 {
	($val:expr) => {{
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}

pub(super) use m128_from_u128;
319
impl UnderlierType for M128 {
	// 2^7 = 128 bits.
	const LOG_BITS: usize = 7;
}
323
impl UnderlierWithBitOps for M128 {
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

	#[inline(always)]
	fn fill_with_bit(val: u8) -> Self {
		assert!(val == 0 || val == 1);
		// 0 -> 0x00 and 1 -> 0xFF per byte: two's-complement negation of 1 sets all bits.
		Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
	}

	/// Builds the vector from `f`, where `f(i)` supplies the `i`-th `T`-element
	/// (element 0 lands in the least significant position).
	#[inline(always)]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		match T::BITS {
			1 | 2 | 4 => {
				// Sub-byte elements: `make_func_to_i8` packs consecutive values into bytes.
				let mut f = make_func_to_i8::<T, Self>(f);

				// `_mm_set_epi8` takes its arguments from the most significant byte down.
				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			8 => {
				let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			16 => {
				let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
				unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
			}
			32 => {
				let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
				unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
			}
			64 => {
				let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
				unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
			}
			128 => Self::from(f(0)),
			_ => panic!("unsupported bit count"),
		}
	}

	/// Extracts the `i`-th `T`-element. The index is used unchecked, so the caller
	/// must guarantee `i` is within the element count for `T::BITS`.
	#[inline(always)]
	unsafe fn get_subvalue<T>(&self, i: usize) -> T
	where
		T: UnderlierType + NumCast<Self>,
	{
		match T::BITS {
			1 | 2 | 4 => {
				// Fetch the byte holding the element, shift it into the low bits;
				// the narrowing `num_cast_from` discards the higher bits.
				let elements_in_8 = 8 / T::BITS;
				let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
					*arr.get_unchecked(i / elements_in_8)
				});

				let shift = (i % elements_in_8) * T::BITS;
				value_u8 >>= shift;

				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			8 => {
				let value_u8 =
					as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			16 => {
				let value_u16 =
					as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u16)))
			}
			32 => {
				let value_u32 =
					as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u32)))
			}
			64 => {
				let value_u64 =
					as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u64)))
			}
			128 => T::from_underlier(T::num_cast_from(*self)),
			_ => panic!("unsupported bit count"),
		}
	}

	/// Overwrites the `i`-th `T`-element with `val`. The index is used unchecked,
	/// so the caller must guarantee `i` is within the element count for `T::BITS`.
	#[inline(always)]
	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
	where
		T: UnderlierWithBitOps,
		Self: From<T>,
	{
		match T::BITS {
			1 | 2 | 4 => {
				// Read-modify-write of the byte containing the sub-byte element.
				let elements_in_8 = 8 / T::BITS;
				let mask = (1u8 << T::BITS) - 1;
				let shift = (i % elements_in_8) * T::BITS;
				let val = u8::num_cast_from(Self::from(val)) << shift;
				let mask = mask << shift;

				as_array_mut::<_, u8, 16>(self, |array| unsafe {
					let element = array.get_unchecked_mut(i / elements_in_8);
					*element &= !mask;
					*element |= val;
				});
			}
			8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
			}),
			16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
			}),
			32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
			}),
			64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
			}),
			128 => {
				*self = Self::from(val);
			}
			_ => panic!("unsupported bit count"),
		}
	}

	/// Spreads the `block_idx`-th block of `2^log_block_len` `T`-elements across
	/// the register, dispatching on the element width `T::LOG_BITS`. Fast paths
	/// cover the widths expressible with SSE set/shuffle intrinsics; the rest go
	/// through `spread_fallback`.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: From<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: each selected bit is expanded to an all-0/all-1 region.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				}
				2 => {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				}
				3 => {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				}
				4 => {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				}
				_ => spread_fallback(self, log_block_len, block_idx),
			},
			// 2-bit elements: `U2::spread_to_byte` replicates each element to a byte,
			// then the bytes are laid out across the register.
			1 => match log_block_len {
				0 => {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				}
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => spread_fallback(self, log_block_len, block_idx),
			},
			// 4-bit elements: same scheme as above via `U4::spread_to_byte`.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => spread_fallback(self, log_block_len, block_idx),
			},
			// Byte elements: a single `pshufb` with a precomputed index mask does the spread.
			3 => match log_block_len {
				0 => _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into(),
				1 => _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into(),
				2 => _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into(),
				3 => _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into(),
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					_mm_set1_epi16(value as i16).into()
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				}
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				}
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// 128-bit element: the whole register is the single block.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}

	// There is only a single 128-bit lane, so the lane-wise shifts are plain shifts.
	#[inline]
	fn shl_128b_lanes(self, shift: usize) -> Self {
		self << shift
	}

	#[inline]
	fn shr_128b_lanes(self, shift: usize) -> Self {
		self >> shift
	}

	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			// Sub-byte blocks have no direct SSE unpack instruction; use the portable fallback.
			0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}

	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			// Sub-byte blocks have no direct SSE unpack instruction; use the portable fallback.
			0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}
}
733
// SAFETY: `M128` is a `#[repr(transparent)]` wrapper over `__m128i`, a plain
// 128-bit vector type with no padding and no invalid bit patterns.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// SAFETY: `M128` is plain data with no interior mutability or thread affinity.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
741
// Precomputed `_mm_shuffle_epi8` index masks for the byte-element (`T::LOG_BITS == 3`)
// cases of `spread`, one table per supported `log_block_len`, indexed by block index.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
746
/// Precomputes `pshufb` byte-index masks for `spread`: for each block index, a
/// 16-byte permutation that repeats every element of the selected block enough
/// times to fill the whole register.
///
/// `t_log_bits` is the log2 bit width of the element type (3 for bytes); only
/// byte-or-wider elements apply here since `pshufb` indexes whole bytes.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	let element_log_width = t_log_bits - 3;

	let element_width = 1 << element_log_width;

	// Bytes per block, and how many times each element repeats to cover 16 bytes.
	let block_size = 1 << (log_block_len + element_log_width);
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	// Manual `while` loops: `for` is not allowed in `const fn`.
	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			// Source byte for destination byte `j`: pick the element (repeated
			// `repeat` times) and the byte offset within that element.
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
782
impl UnderlierWithBitConstants for M128 {
	// Masks selecting the even/odd `2^i`-bit blocks (i = 0..=6), used by the
	// portable interleave machinery.
	const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
		Self::from_u128(interleave_mask_even!(u128, 0)),
		Self::from_u128(interleave_mask_even!(u128, 1)),
		Self::from_u128(interleave_mask_even!(u128, 2)),
		Self::from_u128(interleave_mask_even!(u128, 3)),
		Self::from_u128(interleave_mask_even!(u128, 4)),
		Self::from_u128(interleave_mask_even!(u128, 5)),
		Self::from_u128(interleave_mask_even!(u128, 6)),
	];

	const INTERLEAVE_ODD_MASK: &'static [Self] = &[
		Self::from_u128(interleave_mask_odd!(u128, 0)),
		Self::from_u128(interleave_mask_odd!(u128, 1)),
		Self::from_u128(interleave_mask_odd!(u128, 2)),
		Self::from_u128(interleave_mask_odd!(u128, 3)),
		Self::from_u128(interleave_mask_odd!(u128, 4)),
		Self::from_u128(interleave_mask_odd!(u128, 5)),
		Self::from_u128(interleave_mask_odd!(u128, 6)),
	];

	#[inline(always)]
	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
		unsafe {
			// Unwrap to raw `__m128i`, interleave, and wrap both halves back up.
			let (c, d) = interleave_bits(
				Into::<Self>::into(self).into(),
				Into::<Self>::into(other).into(),
				log_block_len,
			);
			(Self::from(c), Self::from(d))
		}
	}
}
816
// Conversions between raw register / integer values and packed field elements
// whose underlier is `M128`.
impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
	fn from(value: __m128i) -> Self {
		M128::from(value).into()
	}
}

impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
	fn from(value: u128) -> Self {
		M128::from(value).into()
	}
}

impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
		value.to_underlier().into()
	}
}
834
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
	u128: From<Scalar::Underlier>,
{
	/// Replicates `scalar` into every scalar slot of the packed value.
	#[inline(always)]
	fn broadcast(scalar: Scalar) -> Self {
		// log2 of the scalar width in bits (N_BITS is a power of two for binary fields).
		let tower_level = Scalar::N_BITS.ilog2() as usize;
		let mut value = u128::from(scalar.to_underlier());
		// Duplicate sub-byte scalars until the low byte is filled with copies.
		for n in tower_level..3 {
			value |= value << (1 << n);
		}

		let value = must_cast(value);
		// Broadcast the low 8/16/32/64-bit lane to the whole register (AVX2
		// `_mm_broadcast*` intrinsics); a 128-bit scalar is already the register.
		let value = match tower_level {
			0..=3 => unsafe { _mm_broadcastb_epi8(value) },
			4 => unsafe { _mm_broadcastw_epi16(value) },
			5 => unsafe { _mm_broadcastd_epi32(value) },
			6 => unsafe { _mm_broadcastq_epi64(value) },
			7 => value,
			_ => unreachable!(),
		};

		value.into()
	}
}
860
/// Interleaves `2^log_block_len`-bit blocks of `a` and `b`: the first result
/// holds the even-indexed blocks of both inputs interleaved, the second the
/// odd-indexed blocks (see `test_interleave_bits` for the exact layout).
///
/// Sub-byte blocks (1/2/4 bits) use a masked XOR-swap; byte and wider blocks
/// pre-shuffle each input so matching halves are adjacent, then use the SSE
/// unpack instructions.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		0 => {
			// 0b01010101…: selects every other bit.
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		}
		1 => {
			// 0b00110011…: selects every other 2-bit block.
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		}
		2 => {
			// 0b00001111…: selects every other nibble.
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		}
		3 => {
			// Gather even bytes into the low half and odd bytes into the high half,
			// then unpack to interleave across the two inputs.
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		}
		4 => {
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		}
		5 => {
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		}
		6 => {
			// 64-bit blocks interleave directly with unpack, no pre-shuffle needed.
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		}
		_ => panic!("unsupported block length"),
	}
}
908
/// Interleaves `BLOCK_LEN`-bit blocks of `a` and `b` within each 64-bit lane via
/// a masked XOR-swap: `t` captures where the shifted-down blocks of `a` differ
/// from the corresponding blocks of `b` (restricted to `mask` positions), and one
/// XOR per operand then exchanges those blocks in place.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
	let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
	let b_prime = _mm_xor_si128(b, t);
	(a_prime, b_prime)
}
920
// Wire up the scalar-iteration strategies over an `M128`: bit-level iteration for
// `U1`, the generic fallback for `U2`/`U4`, and direct division for wider types.
impl_iteration!(M128,
	@strategy BitIterationStrategy, U1,
	@strategy FallbackStrategy, U2, U4,
	@strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
926
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, proptest};

	use super::*;
	use crate::underlier::single_element_mask_bits;

	/// Asserts that converting an `M128` to `T` and back is lossless.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	/// Extracts the `index`-th `2^log_block_len`-bit block of `value` into the
	/// low bits of the result.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		(value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		// Shifts must agree with the reference `u128` shifts for all in-range amounts.
		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		// Verifies the block layout contract of `interleave_bits`: even blocks of
		// both inputs end up interleaved in `c`, odd blocks in `d`.
		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}

		#[test]
		fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_lo_128b_lanes(b, height);
			for i in 0..128>>(height + 1) {
				assert_eq!(get(result, height, 2*i), get(a, height, i));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i));
			}
		}

		#[test]
		fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_hi_128b_lanes(b, height);
			let half_block_count = 128>>(height + 1);
			for i in 0..half_block_count {
				assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}
}