binius_field/arch/x86_64/
m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use bytemuck::{must_cast, Pod, Zeroable};
10use rand::{Rng, RngCore};
11use seq_macro::seq;
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
13
14use crate::{
15	arch::{
16		binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
17		portable::{
18			packed::{impl_pack_scalar, PackedPrimitiveType},
19			packed_arithmetic::{
20				interleave_mask_even, interleave_mask_odd, UnderlierWithBitConstants,
21			},
22		},
23	},
24	arithmetic_traits::Broadcast,
25	tower_levels::TowerLevel,
26	underlier::{
27		impl_divisible, impl_iteration, spread_fallback, transpose_128b_values,
28		unpack_hi_128b_fallback, unpack_lo_128b_fallback, NumCast, Random, SmallU, SpreadToByte,
29		UnderlierType, UnderlierWithBitOps, WithUnderlier, U1, U2, U4,
30	},
31	BinaryField,
32};
33
/// 128-bit value that is used for 128-bit SIMD operations
///
/// `repr(transparent)` wrapper over the raw SSE register type `__m128i`;
/// the field is `pub(super)` so sibling arch modules can reach the raw
/// register directly.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
38
impl M128 {
	/// Constructs an `M128` from a `u128` in a `const` context.
	///
	/// The intrinsic-based `From<u128>` impl is not `const`, so this uses
	/// `transmute_copy` instead; both `u128` and `__m128i` are 16 bytes, so
	/// the bitwise copy is valid for every input value.
	#[inline(always)]
	pub const fn from_u128(val: u128) -> Self {
		let mut result = Self::ZERO;
		// SAFETY: source and destination are both 16-byte plain-data types;
		// every bit pattern is a valid `__m128i`.
		unsafe {
			result.0 = std::mem::transmute_copy(&val);
		}

		result
	}
}
50
51impl From<__m128i> for M128 {
52	#[inline(always)]
53	fn from(value: __m128i) -> Self {
54		Self(value)
55	}
56}
57
58impl From<u128> for M128 {
59	fn from(value: u128) -> Self {
60		Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
61	}
62}
63
64impl From<u64> for M128 {
65	fn from(value: u64) -> Self {
66		Self::from(value as u128)
67	}
68}
69
70impl From<u32> for M128 {
71	fn from(value: u32) -> Self {
72		Self::from(value as u128)
73	}
74}
75
76impl From<u16> for M128 {
77	fn from(value: u16) -> Self {
78		Self::from(value as u128)
79	}
80}
81
82impl From<u8> for M128 {
83	fn from(value: u8) -> Self {
84		Self::from(value as u128)
85	}
86}
87
88impl<const N: usize> From<SmallU<N>> for M128 {
89	fn from(value: SmallU<N>) -> Self {
90		Self::from(value.val() as u128)
91	}
92}
93
94impl From<M128> for u128 {
95	fn from(value: M128) -> Self {
96		let mut result = 0u128;
97		unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };
98
99		result
100	}
101}
102
103impl From<M128> for __m128i {
104	#[inline(always)]
105	fn from(value: M128) -> Self {
106		value.0
107	}
108}
109
// Project macros: allow `M128` to be reinterpreted as arrays of each narrower
// unsigned type, and register it as a pack-scalar underlier (see the macro
// definitions in `underlier` / `portable::packed` for the generated impls).
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);
112
113impl<U: NumCast<u128>> NumCast<M128> for U {
114	#[inline(always)]
115	fn num_cast_from(val: M128) -> Self {
116		Self::num_cast_from(u128::from(val))
117	}
118}
119
120impl Default for M128 {
121	#[inline(always)]
122	fn default() -> Self {
123		Self(unsafe { _mm_setzero_si128() })
124	}
125}
126
127impl BitAnd for M128 {
128	type Output = Self;
129
130	#[inline(always)]
131	fn bitand(self, rhs: Self) -> Self::Output {
132		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
133	}
134}
135
136impl BitAndAssign for M128 {
137	#[inline(always)]
138	fn bitand_assign(&mut self, rhs: Self) {
139		*self = *self & rhs
140	}
141}
142
143impl BitOr for M128 {
144	type Output = Self;
145
146	#[inline(always)]
147	fn bitor(self, rhs: Self) -> Self::Output {
148		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
149	}
150}
151
152impl BitOrAssign for M128 {
153	#[inline(always)]
154	fn bitor_assign(&mut self, rhs: Self) {
155		*self = *self | rhs
156	}
157}
158
159impl BitXor for M128 {
160	type Output = Self;
161
162	#[inline(always)]
163	fn bitxor(self, rhs: Self) -> Self::Output {
164		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
165	}
166}
167
168impl BitXorAssign for M128 {
169	#[inline(always)]
170	fn bitxor_assign(&mut self, rhs: Self) {
171		*self = *self ^ rhs;
172	}
173}
174
175impl Not for M128 {
176	type Output = Self;
177
178	fn not(self) -> Self::Output {
179		const ONES: __m128i = m128_from_u128!(u128::MAX);
180
181		self ^ Self(ONES)
182	}
183}
184
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	if right > left { right } else { left }
}
193
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed as constant
/// and Rust currently doesn't allow passing expressions (`count - 64`) where variable is a generic constant parameter.
/// Source: https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688
///
/// Strategy: a full 128-bit shift is decomposed into a one-lane (8-byte) byte
/// shift plus per-64-bit-lane bit shifts; the bits that cross the lane
/// boundary are recovered from the byte-shifted copy. The `seq!` expansions
/// turn the runtime `$shift` into a chain of constant-count intrinsic calls.
/// Any shift >= 128 falls through to `Default::default()` (zero).
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			// Copy of `$val` moved one 64-bit lane over; supplies the carry bits.
			let carry = $byte_shift($val, 8);
			seq!(N in 64..128 {
				if $shift == N {
					// Shift >= 64: only the carried lane contributes; finish
					// with the remaining (N - 64)-bit shift inside the lane.
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			seq!(N in 0..64 {
				if $shift == N {
					// Shift < 64: align the carried bits by shifting them the
					// opposite way by (64 - N), then OR with the in-lane shift.
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// $shift >= 128: all bits are shifted out.
			return Default::default()
		}
	};
}

pub(crate) use bitshift_128b;
228
impl Shr<usize> for M128 {
	type Output = Self;

	/// Logical right shift of the full 128-bit value by `rhs` bits.
	///
	/// Unlike `u128 >> rhs`, shifts of 128 or more yield zero (the macro's
	/// fallthrough returns `Default::default()`).
	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
239
240impl Shl<usize> for M128 {
241	type Output = Self;
242
243	#[inline(always)]
244	fn shl(self, rhs: usize) -> Self::Output {
245		// This implementation is effective when `rhs` is known at compile-time.
246		// In our code this is always the case.
247		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
248	}
249}
250
251impl PartialEq for M128 {
252	fn eq(&self, other: &Self) -> bool {
253		unsafe {
254			let neq = _mm_xor_si128(self.0, other.0);
255			_mm_test_all_zeros(neq, neq) == 1
256		}
257	}
258}
259
260impl Eq for M128 {}
261
262impl PartialOrd for M128 {
263	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
264		Some(self.cmp(other))
265	}
266}
267
268impl Ord for M128 {
269	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
270		u128::from(*self).cmp(&u128::from(*other))
271	}
272}
273
impl ConstantTimeEq for M128 {
	/// Equality as a `Choice`: same XOR-then-`ptest` check as `PartialEq`;
	/// `_mm_test_all_zeros` yields 1 on equality, 0 otherwise, matching
	/// `Choice`'s 0/1 encoding.
	///
	/// NOTE(review): the intrinsic sequence is branchless, but constant-time
	/// behavior ultimately depends on PTEST having data-independent latency
	/// on the target CPU — assumed here, confirm if this guards secrets.
	fn ct_eq(&self, other: &Self) -> Choice {
		unsafe {
			let neq = _mm_xor_si128(self.0, other.0);
			Choice::from(_mm_test_all_zeros(neq, neq) as u8)
		}
	}
}
282
283impl ConditionallySelectable for M128 {
284	fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
285		ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
286	}
287}
288
289impl Random for M128 {
290	fn random(mut rng: impl RngCore) -> Self {
291		let val: u128 = rng.gen();
292		val.into()
293	}
294}
295
296impl std::fmt::Display for M128 {
297	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298		let data: u128 = (*self).into();
299		write!(f, "{data:02X?}")
300	}
301}
302
303impl std::fmt::Debug for M128 {
304	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305		write!(f, "M128({})", self)
306	}
307}
308
/// 16-byte-aligned wrapper around a single `u128`, so that its address may be
/// reinterpreted as a properly aligned `*const __m128i` (see `m128_from_u128!`).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
311
/// Builds a `__m128i` value from a `u128` expression, usable in `const` items.
///
/// The value is first placed in 16-byte-aligned storage (`AlignedData`) so the
/// raw pointer cast below performs a valid aligned read.
macro_rules! m128_from_u128 {
	($val:expr) => {{
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}

pub(super) use m128_from_u128;
320
impl UnderlierType for M128 {
	// 2^7 = 128 bits.
	const LOG_BITS: usize = 7;
}
324
impl UnderlierWithBitOps for M128 {
	// Constants built through `m128_from_u128!` so they work in `const` items.
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

	/// Returns all-zeros for `val == 0` and all-ones for `val == 1`.
	#[inline(always)]
	fn fill_with_bit(val: u8) -> Self {
		assert!(val == 0 || val == 1);
		// 0u8.wrapping_neg() == 0x00, 1u8.wrapping_neg() == 0xFF.
		Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
	}

	/// Builds the register element-wise: element `i` is `f(i)`, with element 0
	/// occupying the lowest bits (note `_mm_set_*` takes arguments
	/// high-to-low, hence the reversed call order).
	#[inline(always)]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		match T::BITS {
			// Sub-byte elements: `make_func_to_i8` packs several `T` values
			// into each byte before the byte-wise set.
			1 | 2 | 4 => {
				let mut f = make_func_to_i8::<T, Self>(f);

				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			8 => {
				let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			16 => {
				let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
				unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
			}
			32 => {
				let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
				unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
			}
			64 => {
				let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
				unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
			}
			128 => Self::from(f(0)),
			_ => panic!("unsupported bit count"),
		}
	}

	/// Reads the `i`-th `T`-sized element (little-endian element order).
	///
	/// # Safety
	/// `i` must be a valid element index for `T::BITS` within 128 bits; the
	/// array accesses below are unchecked.
	#[inline(always)]
	unsafe fn get_subvalue<T>(&self, i: usize) -> T
	where
		T: UnderlierType + NumCast<Self>,
	{
		match T::BITS {
			// Sub-byte elements: fetch the containing byte, then shift the
			// wanted element down; `num_cast_from` truncates to T's width.
			1 | 2 | 4 => {
				let elements_in_8 = 8 / T::BITS;
				let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
					*arr.get_unchecked(i / elements_in_8)
				});

				let shift = (i % elements_in_8) * T::BITS;
				value_u8 >>= shift;

				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			8 => {
				let value_u8 =
					as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			16 => {
				let value_u16 =
					as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u16)))
			}
			32 => {
				let value_u32 =
					as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u32)))
			}
			64 => {
				let value_u64 =
					as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u64)))
			}
			128 => T::from_underlier(T::num_cast_from(*self)),
			_ => panic!("unsupported bit count"),
		}
	}

	/// Writes `val` into the `i`-th `T`-sized element, leaving the rest intact.
	///
	/// # Safety
	/// `i` must be a valid element index for `T::BITS` within 128 bits; the
	/// array accesses below are unchecked.
	#[inline(always)]
	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
	where
		T: UnderlierWithBitOps,
		Self: From<T>,
	{
		match T::BITS {
			// Sub-byte elements: read-modify-write the containing byte using
			// a mask positioned at the element's bit offset.
			1 | 2 | 4 => {
				let elements_in_8 = 8 / T::BITS;
				let mask = (1u8 << T::BITS) - 1;
				let shift = (i % elements_in_8) * T::BITS;
				let val = u8::num_cast_from(Self::from(val)) << shift;
				let mask = mask << shift;

				as_array_mut::<_, u8, 16>(self, |array| unsafe {
					let element = array.get_unchecked_mut(i / elements_in_8);
					*element &= !mask;
					*element |= val;
				});
			}
			8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
			}),
			16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
			}),
			32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
			}),
			64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
			}),
			128 => {
				*self = Self::from(val);
			}
			_ => panic!("unsupported bit count"),
		}
	}

	/// Takes the `block_idx`-th block of `2^log_block_len` `T`-elements and
	/// spreads it across the register: each element of the block is replicated
	/// to fill an equal share of the 128 bits, in ascending order.
	///
	/// Dispatch is on `T::LOG_BITS` (element width), then `log_block_len`;
	/// cases not handled by intrinsics delegate to `spread_fallback`.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: From<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: extract the bits, then fill lanes with them.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				}
				2 => {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				}
				3 => {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				}
				4 => {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				}
				_ => spread_fallback(self, log_block_len, block_idx),
			},
			// 2-bit elements: expand each U2 to a repeated byte pattern, then
			// distribute the bytes across the register.
			1 => match log_block_len {
				0 => {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				}
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => spread_fallback(self, log_block_len, block_idx),
			},
			// 4-bit elements: same approach via U4::spread_to_byte.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => spread_fallback(self, log_block_len, block_idx),
			},
			// 8-bit elements: a single byte shuffle with a precomputed mask
			// (see the LOG_B8_* statics) does the whole spread.
			3 => match log_block_len {
				0 => _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into(),
				1 => _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into(),
				2 => _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into(),
				3 => _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into(),
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					_mm_set1_epi16(value as i16).into()
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				}
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				}
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// 128-bit element: spreading a full-width value is the identity.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}

	// A single M128 is exactly one 128-bit lane, so lane-wise shifts are just
	// the plain shifts.
	#[inline]
	fn shl_128b_lanes(self, shift: usize) -> Self {
		self << shift
	}

	#[inline]
	fn shr_128b_lanes(self, shift: usize) -> Self {
		self >> shift
	}

	/// Interleaves the low halves of `self` and `other` at block granularity
	/// `2^log_block_len` bits; sub-byte blocks go through the portable fallback.
	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}

	/// Interleaves the high halves of `self` and `other`; see
	/// `unpack_lo_128b_lanes`.
	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}

	#[inline]
	fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		transpose_128b_values::<Self, TL>(values, 0);
	}

	/// Converts to byte-sliced layout: first reorders bytes inside each value
	/// with a width-dependent shuffle, then transposes across the values.
	#[inline]
	fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		if TL::LOG_WIDTH == 0 {
			return;
		}

		match TL::LOG_WIDTH {
			1 => unsafe {
				let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			2 => unsafe {
				let shuffle = _mm_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			3 => unsafe {
				let shuffle = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			// Width 16: bytes are already in the right within-value order.
			4 => {}
			_ => unreachable!("Log width must be less than 5"),
		}

		transpose_128b_values::<_, TL>(values, 4 - TL::LOG_WIDTH);
	}
}
779
// SAFETY: `M128` is a `repr(transparent)` wrapper over `__m128i`; every bit
// pattern (including all-zeros) is a valid value, so `Zeroable` and `Pod` hold.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// SAFETY: `M128` is plain `Copy` data with no interior mutability or thread
// affinity.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
787
// Precomputed `_mm_shuffle_epi8` masks used by `spread` for 8-bit elements
// (`T::LOG_BITS == 3`) — one table per supported `log_block_len`, indexed by
// `block_idx`.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
792
/// Builds, at compile time, the `_mm_shuffle_epi8` source-index masks used by
/// `spread`: for each `block_idx`, byte `j` of the mask selects the source
/// byte so that every element of the chosen block is replicated `repeat`
/// times across the 16 output bytes, in ascending element order.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	// Element width in bytes (t_log_bits is in bits, so subtract log2(8)).
	let element_log_width = t_log_bits - 3;

	let element_width = 1 << element_log_width;

	// Bytes per block, and how many times each element repeats in the output.
	let block_size = 1 << (log_block_len + element_log_width);
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	// `while` loops because `for` is not allowed in `const fn`.
	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			// Source byte: block base + (element index within the block,
			// advancing once per `repeat` outputs) + byte offset in element.
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		// Pack the 16 mask bytes little-endian into an M128.
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
828
829impl UnderlierWithBitConstants for M128 {
830	const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
831		Self::from_u128(interleave_mask_even!(u128, 0)),
832		Self::from_u128(interleave_mask_even!(u128, 1)),
833		Self::from_u128(interleave_mask_even!(u128, 2)),
834		Self::from_u128(interleave_mask_even!(u128, 3)),
835		Self::from_u128(interleave_mask_even!(u128, 4)),
836		Self::from_u128(interleave_mask_even!(u128, 5)),
837		Self::from_u128(interleave_mask_even!(u128, 6)),
838	];
839
840	const INTERLEAVE_ODD_MASK: &'static [Self] = &[
841		Self::from_u128(interleave_mask_odd!(u128, 0)),
842		Self::from_u128(interleave_mask_odd!(u128, 1)),
843		Self::from_u128(interleave_mask_odd!(u128, 2)),
844		Self::from_u128(interleave_mask_odd!(u128, 3)),
845		Self::from_u128(interleave_mask_odd!(u128, 4)),
846		Self::from_u128(interleave_mask_odd!(u128, 5)),
847		Self::from_u128(interleave_mask_odd!(u128, 6)),
848	];
849
850	#[inline(always)]
851	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
852		unsafe {
853			let (c, d) = interleave_bits(
854				Into::<Self>::into(self).into(),
855				Into::<Self>::into(other).into(),
856				log_block_len,
857			);
858			(Self::from(c), Self::from(d))
859		}
860	}
861}
862
863impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
864	fn from(value: __m128i) -> Self {
865		M128::from(value).into()
866	}
867}
868
869impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
870	fn from(value: u128) -> Self {
871		M128::from(value).into()
872	}
873}
874
875impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
876	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
877		value.to_underlier().into()
878	}
879}
880
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
	u128: From<Scalar::Underlier>,
{
	/// Replicates `scalar` into every scalar slot of the 128-bit register.
	#[inline(always)]
	fn broadcast(scalar: Scalar) -> Self {
		// log2 of the scalar width in bits.
		let tower_level = Scalar::N_BITS.ilog2() as usize;
		let mut value = u128::from(scalar.to_underlier());
		// For sub-byte scalars, duplicate the value until it fills one byte
		// (doubling the populated width each iteration), so the byte
		// broadcast below covers them too.
		for n in tower_level..3 {
			value |= value << (1 << n);
		}

		// Reinterpret the u128 as a register, then broadcast the lowest
		// element of the appropriate width.
		let value = must_cast(value);
		let value = match tower_level {
			0..=3 => unsafe { _mm_broadcastb_epi8(value) },
			4 => unsafe { _mm_broadcastw_epi16(value) },
			5 => unsafe { _mm_broadcastd_epi32(value) },
			6 => unsafe { _mm_broadcastq_epi64(value) },
			// 128-bit scalar already fills the register.
			7 => value,
			_ => unreachable!(),
		};

		value.into()
	}
}
906
/// Interleaves blocks of `2^log_block_len` bits from `a` and `b`:
/// output `a'` holds the even-indexed blocks of both inputs alternating,
/// `b'` the odd-indexed ones (verified by `tests::test_interleave_bits`).
///
/// # Safety
/// Caller must ensure the required SSE intrinsics are available;
/// `log_block_len` must be in `0..=6` or this panics.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		// Sub-byte blocks: masked bit-swap (delta swap) via interleave_bits_imm.
		0 => {
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		}
		1 => {
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		}
		2 => {
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		}
		// Byte and larger blocks: pre-shuffle even/odd blocks into the
		// low/high halves, then unpack to interleave across the two inputs.
		3 => {
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		}
		4 => {
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		}
		5 => {
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		}
		// 64-bit blocks need no pre-shuffle: unpack alone interleaves them.
		6 => {
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		}
		_ => panic!("unsupported block length"),
	}
}
954
/// Interleaves `BLOCK_LEN`-bit blocks (1, 2, or 4) within each byte using a
/// masked swap ("delta swap"): `t` captures where a's odd blocks and b's even
/// blocks differ, then a single XOR on each input exchanges exactly those
/// block positions.
///
/// # Safety
/// Caller must ensure SSE2 is available; `mask` must select the even blocks
/// of width `BLOCK_LEN` (e.g. 0x55 for 1-bit blocks).
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
	let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
	let b_prime = _mm_xor_si128(b, t);
	(a_prime, b_prime)
}
966
// Project macro: generates per-element iteration impls for M128, choosing an
// iteration strategy per element type (bit-level for U1, fallback for U2/U4,
// divisibility-based for byte-and-larger types).
impl_iteration!(M128,
	@strategy BitIterationStrategy, U1,
	@strategy FallbackStrategy, U2, U4,
	@strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
972
#[cfg(test)]
mod tests {
	use proptest::{arbitrary::any, proptest};

	use super::*;
	use crate::underlier::single_element_mask_bits;

	// Asserts that converting M128 -> T -> M128 is lossless.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	// Extracts the `index`-th block of `2^log_block_len` bits as an M128.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		(value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			// SIMD ops must agree with the scalar u128 reference.
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			// c must alternate even blocks of a and b; d the odd blocks.
			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}

		#[test]
		fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_lo_128b_lanes(b, height);
			// Low halves of a and b interleaved block-by-block.
			for i in 0..128>>(height + 1) {
				assert_eq!(get(result, height, 2*i), get(a, height, i));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i));
			}
		}

		#[test]
		fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_hi_128b_lanes(b, height);
			// High halves of a and b interleaved block-by-block.
			let half_block_count = 128>>(height + 1);
			for i in 0..half_block_count {
				assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}
}