binius_field/arch/x86_64/
m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes,
11	bytes::{Buf, BufMut},
12	serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{Rng, RngCore};
16use seq_macro::seq;
17use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
18
19use crate::{
20	BinaryField,
21	arch::{
22		binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
23		portable::{
24			packed::{PackedPrimitiveType, impl_pack_scalar},
25			packed_arithmetic::{
26				UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
27			},
28		},
29	},
30	arithmetic_traits::Broadcast,
31	tower_levels::TowerLevel,
32	underlier::{
33		NumCast, Random, SmallU, SpreadToByte, U1, U2, U4, UnderlierType, UnderlierWithBitOps,
34		WithUnderlier, impl_divisible, impl_iteration, spread_fallback, transpose_128b_values,
35		unpack_hi_128b_fallback, unpack_lo_128b_fallback,
36	},
37};
38
/// 128-bit value that is used for 128-bit SIMD operations
///
/// `repr(transparent)` guarantees the same layout as the wrapped `__m128i`,
/// so the two types can be freely reinterpreted.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
43
impl M128 {
	/// Const-friendly conversion from `u128`.
	///
	/// `From<u128>` uses an intrinsic and cannot be called in const contexts,
	/// so this copies the raw bits instead.
	#[inline(always)]
	pub const fn from_u128(val: u128) -> Self {
		let mut result = Self::ZERO;
		unsafe {
			// SAFETY: `u128` and `__m128i` are both 16 bytes and any bit
			// pattern is a valid `__m128i`.
			result.0 = std::mem::transmute_copy(&val);
		}

		result
	}
}
55
impl From<__m128i> for M128 {
	#[inline(always)]
	fn from(value: __m128i) -> Self {
		Self(value)
	}
}

impl From<u128> for M128 {
	fn from(value: u128) -> Self {
		// SAFETY: reads exactly 16 bytes from the local `value`;
		// `_mm_loadu_si128` has no alignment requirement.
		Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
	}
}
68
69impl From<u64> for M128 {
70	fn from(value: u64) -> Self {
71		Self::from(value as u128)
72	}
73}
74
75impl From<u32> for M128 {
76	fn from(value: u32) -> Self {
77		Self::from(value as u128)
78	}
79}
80
81impl From<u16> for M128 {
82	fn from(value: u16) -> Self {
83		Self::from(value as u128)
84	}
85}
86
87impl From<u8> for M128 {
88	fn from(value: u8) -> Self {
89		Self::from(value as u128)
90	}
91}
92
93impl<const N: usize> From<SmallU<N>> for M128 {
94	fn from(value: SmallU<N>) -> Self {
95		Self::from(value.val() as u128)
96	}
97}
98
impl From<M128> for u128 {
	fn from(value: M128) -> Self {
		let mut result = 0u128;
		// SAFETY: `result` is 16 bytes, matching the store width;
		// `_mm_storeu_si128` has no alignment requirement.
		unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };

		result
	}
}

impl From<M128> for __m128i {
	#[inline(always)]
	fn from(value: M128) -> Self {
		value.0
	}
}
114
115impl SerializeBytes for M128 {
116	fn serialize(
117		&self,
118		mut write_buf: impl BufMut,
119		_mode: SerializationMode,
120	) -> Result<(), SerializationError> {
121		assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
122
123		let raw_value: u128 = (*self).into();
124
125		write_buf.put_u128_le(raw_value);
126		Ok(())
127	}
128}
129
130impl DeserializeBytes for M128 {
131	fn deserialize(
132		mut read_buf: impl Buf,
133		_mode: SerializationMode,
134	) -> Result<Self, SerializationError>
135	where
136		Self: Sized,
137	{
138		assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
139
140		let raw_value = read_buf.get_u128_le();
141
142		Ok(Self::from(raw_value))
143	}
144}
145
// Allow M128 to be viewed as arrays of each narrower unsigned integer type.
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
// Wire M128 into the portable packed-field machinery as an underlier.
impl_pack_scalar!(M128);
148
149impl<U: NumCast<u128>> NumCast<M128> for U {
150	#[inline(always)]
151	fn num_cast_from(val: M128) -> Self {
152		Self::num_cast_from(u128::from(val))
153	}
154}
155
156impl Default for M128 {
157	#[inline(always)]
158	fn default() -> Self {
159		Self(unsafe { _mm_setzero_si128() })
160	}
161}
162
163impl BitAnd for M128 {
164	type Output = Self;
165
166	#[inline(always)]
167	fn bitand(self, rhs: Self) -> Self::Output {
168		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
169	}
170}
171
172impl BitAndAssign for M128 {
173	#[inline(always)]
174	fn bitand_assign(&mut self, rhs: Self) {
175		*self = *self & rhs
176	}
177}
178
179impl BitOr for M128 {
180	type Output = Self;
181
182	#[inline(always)]
183	fn bitor(self, rhs: Self) -> Self::Output {
184		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
185	}
186}
187
188impl BitOrAssign for M128 {
189	#[inline(always)]
190	fn bitor_assign(&mut self, rhs: Self) {
191		*self = *self | rhs
192	}
193}
194
195impl BitXor for M128 {
196	type Output = Self;
197
198	#[inline(always)]
199	fn bitxor(self, rhs: Self) -> Self::Output {
200		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
201	}
202}
203
204impl BitXorAssign for M128 {
205	#[inline(always)]
206	fn bitxor_assign(&mut self, rhs: Self) {
207		*self = *self ^ rhs;
208	}
209}
210
211impl Not for M128 {
212	type Output = Self;
213
214	fn not(self) -> Self::Output {
215		const ONES: __m128i = m128_from_u128!(u128::MAX);
216
217		self ^ Self(ONES)
218	}
219}
220
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	if right > left { right } else { left }
}
225
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed
/// as constant and Rust currently doesn't allow passing expressions (`count - 64`) where variable
/// is a generic constant parameter. Source: https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			// Move the other 64-bit half into position so bits that cross the
			// 64-bit lane boundary can be recovered from it.
			let carry = $byte_shift($val, 8);
			// Shifts of 64..128: the result comes entirely from the carried half.
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			// Shifts of 0..64: combine the per-lane shifted value with the
			// carry bits shifted the opposite way.
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// Shift amounts >= 128 fall through to zero.
			return Default::default()
		}
	};
}

pub(crate) use bitshift_128b;
260
impl Shr<usize> for M128 {
	type Output = Self;

	/// Logical right shift of the full 128-bit value.
	/// Shift amounts >= 128 yield zero (the macro's fall-through default).
	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
271
272impl Shl<usize> for M128 {
273	type Output = Self;
274
275	#[inline(always)]
276	fn shl(self, rhs: usize) -> Self::Output {
277		// This implementation is effective when `rhs` is known at compile-time.
278		// In our code this is always the case.
279		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
280	}
281}
282
impl PartialEq for M128 {
	fn eq(&self, other: &Self) -> bool {
		unsafe {
			// XOR is all-zero iff the operands are bitwise identical;
			// `_mm_test_all_zeros` (PTEST) returns 1 in exactly that case.
			let neq = _mm_xor_si128(self.0, other.0);
			_mm_test_all_zeros(neq, neq) == 1
		}
	}
}

impl Eq for M128 {}

impl PartialOrd for M128 {
	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
		Some(self.cmp(other))
	}
}

impl Ord for M128 {
	// Ordering follows the `u128` interpretation of the register.
	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
		u128::from(*self).cmp(&u128::from(*other))
	}
}
305
impl ConstantTimeEq for M128 {
	fn ct_eq(&self, other: &Self) -> Choice {
		unsafe {
			// Same PTEST-based check as `PartialEq::eq`, branch-free at the
			// Rust level. NOTE(review): this relies on the intrinsic being
			// data-independent in time on target CPUs — confirm if strict
			// constant-time guarantees are required here.
			let neq = _mm_xor_si128(self.0, other.0);
			Choice::from(_mm_test_all_zeros(neq, neq) as u8)
		}
	}
}
314
315impl ConditionallySelectable for M128 {
316	fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
317		ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
318	}
319}
320
321impl Random for M128 {
322	fn random(mut rng: impl RngCore) -> Self {
323		let val: u128 = rng.random();
324		val.into()
325	}
326}
327
328impl std::fmt::Display for M128 {
329	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330		let data: u128 = (*self).into();
331		write!(f, "{data:02X?}")
332	}
333}
334
335impl std::fmt::Debug for M128 {
336	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337		write!(f, "M128({self})")
338	}
339}
340
/// 16-byte-aligned wrapper so a `u128` can be reinterpreted as `__m128i`
/// via an aligned pointer read (used by `m128_from_u128!` in const contexts).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);

/// Builds a `__m128i` constant from a `u128` expression. Usable in `const`
/// initializers, where the load/set intrinsics are not available.
macro_rules! m128_from_u128 {
	($val:expr) => {{
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}

pub(super) use m128_from_u128;
352
impl UnderlierType for M128 {
	// log2(128): this underlier is 128 bits wide.
	const LOG_BITS: usize = 7;
}
356
357impl UnderlierWithBitOps for M128 {
	// Built via the const-friendly `m128_from_u128!` macro rather than the
	// runtime set/load intrinsics.
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };
361
	/// Broadcasts a single bit to all 128 bits: 1 -> all ones, 0 -> all zeros.
	///
	/// Panics if `val` is not 0 or 1.
	#[inline(always)]
	fn fill_with_bit(val: u8) -> Self {
		assert!(val == 0 || val == 1);
		// wrapping_neg maps 1 -> 0xFF and 0 -> 0x00, replicated in each byte.
		Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
	}
367
	/// Builds an `M128` from a per-element generator `f`, dispatching on the
	/// element bit width `T::BITS`. Element index 0 lands in the least
	/// significant lane (the `_mm_set_*` intrinsics take arguments
	/// highest-lane first).
	#[inline(always)]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		match T::BITS {
			// Sub-byte elements: pack several generated values per byte.
			1 | 2 | 4 => {
				let mut f = make_func_to_i8::<T, Self>(f);

				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			8 => {
				let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			16 => {
				let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
				unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
			}
			32 => {
				let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
				unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
			}
			64 => {
				let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
				unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
			}
			// A single 128-bit element is the whole value.
			128 => Self::from(f(0)),
			_ => panic!("unsupported bit count"),
		}
	}
440
	/// Extracts the `i`-th `T`-sized element.
	///
	/// # Safety
	/// The caller must ensure `i` is within bounds for `T::BITS`-sized
	/// elements (the array accesses below are unchecked).
	#[inline(always)]
	unsafe fn get_subvalue<T>(&self, i: usize) -> T
	where
		T: UnderlierType + NumCast<Self>,
	{
		match T::BITS {
			// Sub-byte elements: read the containing byte and shift the
			// element down; `num_cast_from` truncates to T's width.
			1 | 2 | 4 => {
				let elements_in_8 = 8 / T::BITS;
				let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
					*arr.get_unchecked(i / elements_in_8)
				});

				let shift = (i % elements_in_8) * T::BITS;
				value_u8 >>= shift;

				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			8 => {
				let value_u8 =
					as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			16 => {
				let value_u16 =
					as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u16)))
			}
			32 => {
				let value_u32 =
					as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u32)))
			}
			64 => {
				let value_u64 =
					as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u64)))
			}
			128 => T::from_underlier(T::num_cast_from(*self)),
			_ => panic!("unsupported bit count"),
		}
	}
482
	/// Overwrites the `i`-th `T`-sized element with `val`.
	///
	/// # Safety
	/// The caller must ensure `i` is within bounds for `T::BITS`-sized
	/// elements (the array accesses below are unchecked).
	#[inline(always)]
	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
	where
		T: UnderlierWithBitOps,
		Self: From<T>,
	{
		match T::BITS {
			// Sub-byte elements: clear the element's bits in the containing
			// byte with a mask, then OR the shifted value in.
			1 | 2 | 4 => {
				let elements_in_8 = 8 / T::BITS;
				let mask = (1u8 << T::BITS) - 1;
				let shift = (i % elements_in_8) * T::BITS;
				let val = u8::num_cast_from(Self::from(val)) << shift;
				let mask = mask << shift;

				as_array_mut::<_, u8, 16>(self, |array| unsafe {
					let element = array.get_unchecked_mut(i / elements_in_8);
					*element &= !mask;
					*element |= val;
				});
			}
			8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
			}),
			16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
			}),
			32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
			}),
			64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
			}),
			128 => {
				*self = Self::from(val);
			}
			_ => panic!("unsupported bit count"),
		}
	}
521
	/// Spreads block `block_idx` of `2^log_block_len` `T`-sized elements
	/// across the whole register, repeating each element to fill 128 bits.
	/// Dispatches on the element width (`T::LOG_BITS`), with specialized
	/// intrinsic paths and `spread_fallback` for the remaining cases.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: From<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: each bit expands to an all-0/all-1 lane.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => unsafe {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				},
				2 => unsafe {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				},
				3 => unsafe {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				},
				4 => unsafe {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				},
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 2-bit elements (U2): expand each to a byte, then replicate.
			1 => match log_block_len {
				0 => unsafe {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				},
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 4-bit elements (U4): same scheme as U2.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					unsafe { _mm_set1_epi8(value as i8).into() }
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 8-bit elements: one byte shuffle using the precomputed tables.
			3 => match log_block_len {
				0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
				1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
				2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
				3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					unsafe { _mm_set1_epi16(value as i16).into() }
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				},
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				},
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// 128-bit element: the whole value, nothing to spread.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}
731
	/// Per-128-bit-lane left shift; with a single lane this is the plain shift.
	#[inline]
	fn shl_128b_lanes(self, shift: usize) -> Self {
		self << shift
	}

	/// Per-128-bit-lane right shift; with a single lane this is the plain shift.
	#[inline]
	fn shr_128b_lanes(self, shift: usize) -> Self {
		self >> shift
	}
741
	/// Interleaves the low halves of `self` and `other` in blocks of
	/// `2^log_block_len` bits.
	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			// Sub-byte blocks have no native unpack instruction.
			0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}

	/// Interleaves the high halves of `self` and `other` in blocks of
	/// `2^log_block_len` bits.
	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}
765
	/// Converts byte-sliced layout back to row-major byte order across the
	/// `TL` values.
	#[inline]
	fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		transpose_128b_values::<Self, TL>(values, 0);
	}

	/// Converts row-major byte order to byte-sliced layout across the
	/// `TL` values.
	#[inline]
	fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		// A single value needs no rearrangement.
		if TL::LOG_WIDTH == 0 {
			return;
		}

		// Pre-shuffle bytes within each register; the shuffle pattern depends
		// on the width so that the generic 128-bit transpose below completes
		// the byte-slicing.
		match TL::LOG_WIDTH {
			1 => unsafe {
				let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			2 => unsafe {
				let shuffle = _mm_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			3 => unsafe {
				let shuffle = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			4 => {}
			_ => unreachable!("Log width must be less than 5"),
		}

		transpose_128b_values::<_, TL>(values, 4 - TL::LOG_WIDTH);
	}
810}
811
// SAFETY: `M128` is `repr(transparent)` over `__m128i`; the all-zero bit
// pattern is a valid value.
unsafe impl Zeroable for M128 {}

// SAFETY: 16-byte `Copy` value with no padding; every bit pattern is valid.
unsafe impl Pod for M128 {}

// SAFETY: plain value type with no interior pointers or thread-affine state.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
819
// Precomputed `_mm_shuffle_epi8` index tables used by `spread` for byte-sized
// elements (T::LOG_BITS == 3); one table per `log_block_len` in 0..=3.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
824
/// Precomputes byte-shuffle index masks for `spread`: for each `block_idx`,
/// byte `j` of the mask selects source byte
/// `base + ((j / element_width) / repeat) * element_width + j % element_width`,
/// i.e. each element of the chosen block is repeated `repeat` times to fill
/// the 16-byte register.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	// Element width in bytes (t_log_bits is the element width in bits).
	let element_log_width = t_log_bits - 3;

	let element_width = 1 << element_log_width;

	// Block size in bytes, and how many times each element is repeated.
	let block_size = 1 << (log_block_len + element_log_width);
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	// `while` loops because `for` is not allowed in const fns.
	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
860
861impl UnderlierWithBitConstants for M128 {
862	const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
863		Self::from_u128(interleave_mask_even!(u128, 0)),
864		Self::from_u128(interleave_mask_even!(u128, 1)),
865		Self::from_u128(interleave_mask_even!(u128, 2)),
866		Self::from_u128(interleave_mask_even!(u128, 3)),
867		Self::from_u128(interleave_mask_even!(u128, 4)),
868		Self::from_u128(interleave_mask_even!(u128, 5)),
869		Self::from_u128(interleave_mask_even!(u128, 6)),
870	];
871
872	const INTERLEAVE_ODD_MASK: &'static [Self] = &[
873		Self::from_u128(interleave_mask_odd!(u128, 0)),
874		Self::from_u128(interleave_mask_odd!(u128, 1)),
875		Self::from_u128(interleave_mask_odd!(u128, 2)),
876		Self::from_u128(interleave_mask_odd!(u128, 3)),
877		Self::from_u128(interleave_mask_odd!(u128, 4)),
878		Self::from_u128(interleave_mask_odd!(u128, 5)),
879		Self::from_u128(interleave_mask_odd!(u128, 6)),
880	];
881
882	#[inline(always)]
883	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
884		unsafe {
885			let (c, d) = interleave_bits(
886				Into::<Self>::into(self).into(),
887				Into::<Self>::into(other).into(),
888				log_block_len,
889			);
890			(Self::from(c), Self::from(d))
891		}
892	}
893}
894
895impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
896	fn from(value: __m128i) -> Self {
897		M128::from(value).into()
898	}
899}
900
901impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
902	fn from(value: u128) -> Self {
903		M128::from(value).into()
904	}
905}
906
907impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
908	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
909		value.to_underlier().into()
910	}
911}
912
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
	u128: From<Scalar::Underlier>,
{
	/// Replicates `scalar` into every scalar slot of the 128-bit register.
	#[inline(always)]
	fn broadcast(scalar: Scalar) -> Self {
		// log2 of the scalar bit width.
		let tower_level = Scalar::N_BITS.ilog2() as usize;
		match tower_level {
			// Sub-byte scalars: smear the value across a full byte by
			// repeated doubling of the occupied width, then broadcast it.
			0..=3 => {
				let mut value = u128::from(scalar.to_underlier()) as u8;
				for n in tower_level..3 {
					value |= value << (1 << n);
				}

				unsafe { _mm_set1_epi8(value as i8) }.into()
			}
			4 => {
				let value = u128::from(scalar.to_underlier()) as u16;
				unsafe { _mm_set1_epi16(value as i16) }.into()
			}
			5 => {
				let value = u128::from(scalar.to_underlier()) as u32;
				unsafe { _mm_set1_epi32(value as i32) }.into()
			}
			6 => {
				let value = u128::from(scalar.to_underlier()) as u64;
				unsafe { _mm_set1_epi64x(value as i64) }.into()
			}
			// 128-bit scalar fills the register exactly.
			7 => {
				let value = u128::from(scalar.to_underlier());
				value.into()
			}
			_ => {
				unreachable!("invalid tower level")
			}
		}
	}
}
951
/// Interleaves blocks of `2^log_block_len` bits of `a` and `b`:
/// the first output holds the even-indexed block pairs, the second the
/// odd-indexed ones. Sub-byte blocks use masked delta swaps; byte and wider
/// blocks use shuffle + unpack.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		// 1-, 2- and 4-bit blocks: delta swap with the matching bit mask.
		0 => unsafe {
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		},
		1 => unsafe {
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		},
		2 => unsafe {
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		},
		// Byte and wider blocks: gather even/odd blocks into the two halves
		// of each register, then unpack lo/hi across the operands.
		3 => unsafe {
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		},
		4 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		},
		5 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		},
		// 64-bit blocks need no pre-shuffle.
		6 => unsafe {
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		},
		_ => panic!("unsupported block length"),
	}
}
999
/// Interleaves `BLOCK_LEN`-bit blocks of `a` and `b` via a masked delta swap:
/// `t` holds exactly the bits that must move between the two operands and is
/// XOR-ed into both, shifted appropriately.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	unsafe {
		let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
		let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
		let b_prime = _mm_xor_si128(b, t);
		(a_prime, b_prime)
	}
}
1013
// Scalar-iteration strategies over M128, chosen per element type.
impl_iteration!(M128,
	@strategy BitIterationStrategy, U1,
	@strategy FallbackStrategy, U2, U4,
	@strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
1019
#[cfg(test)]
mod tests {
	use binius_utils::bytes::BytesMut;
	use proptest::{arbitrary::any, proptest};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::underlier::single_element_mask_bits;

	// Asserts that M128 -> T -> M128 is lossless.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	// Extracts element `index` of width `2^log_block_len` bits from `value`.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		(value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		// Shifts must agree with the plain u128 shifts for all in-range amounts.
		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		// Even-indexed blocks of the pair land in `c`, odd-indexed in `d`.
		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}

		#[test]
		fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_lo_128b_lanes(b, height);
			for i in 0..128>>(height + 1) {
				assert_eq!(get(result, height, 2*i), get(a, height, i));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i));
			}
		}

		#[test]
		fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_hi_128b_lanes(b, height);
			let half_block_count = 128>>(height + 1);
			for i in 0..half_block_count {
				assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}

	#[test]
	fn test_serialize_and_deserialize_m128() {
		let mode = SerializationMode::Native;

		let mut rng = StdRng::from_seed([0; 32]);

		let original_value = M128::from(rng.random::<u128>());

		let mut buf = BytesMut::new();
		original_value.serialize(&mut buf, mode).unwrap();

		let deserialized_value = M128::deserialize(buf.freeze(), mode).unwrap();

		assert_eq!(original_value, deserialized_value);
	}
}
1151}