binius_field/arch/x86_64/
m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializationMode, SerializeBytes,
11	bytes::{Buf, BufMut},
12	serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable, must_cast};
15use rand::{Rng, RngCore};
16use seq_macro::seq;
17use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
18
19use crate::{
20	BinaryField,
21	arch::{
22		binary_utils::{as_array_mut, as_array_ref, make_func_to_i8},
23		portable::{
24			packed::{PackedPrimitiveType, impl_pack_scalar},
25			packed_arithmetic::{
26				UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
27			},
28		},
29	},
30	arithmetic_traits::Broadcast,
31	tower_levels::TowerLevel,
32	underlier::{
33		NumCast, Random, SmallU, SpreadToByte, U1, U2, U4, UnderlierType, UnderlierWithBitOps,
34		WithUnderlier, impl_divisible, impl_iteration, spread_fallback, transpose_128b_values,
35		unpack_hi_128b_fallback, unpack_lo_128b_fallback,
36	},
37};
38
/// 128-bit value that is used for 128-bit SIMD operations
///
/// A `#[repr(transparent)]` newtype around `__m128i` so that foreign traits
/// (bit operations, comparison, serialization, …) can be implemented for the
/// SSE register type. The layout guarantee makes casts between the two sound.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
43
impl M128 {
	/// Creates an `M128` from a `u128` in a `const` context.
	///
	/// The SSE load intrinsics are not `const fn`, so we bit-copy instead:
	/// both `u128` and `__m128i` are 16 bytes, making `transmute_copy` a
	/// plain byte-for-byte reinterpretation.
	#[inline(always)]
	pub const fn from_u128(val: u128) -> Self {
		let mut result = Self::ZERO;
		unsafe {
			result.0 = std::mem::transmute_copy(&val);
		}

		result
	}
}
55
impl From<__m128i> for M128 {
	#[inline(always)]
	fn from(value: __m128i) -> Self {
		Self(value)
	}
}

impl From<u128> for M128 {
	fn from(value: u128) -> Self {
		// Unaligned load is used because a `u128` on the stack is only
		// guaranteed 8-byte alignment on some targets.
		Self(unsafe { _mm_loadu_si128(&raw const value as *const __m128i) })
	}
}

// Narrower integers are zero-extended to 128 bits before loading.
impl From<u64> for M128 {
	fn from(value: u64) -> Self {
		Self::from(value as u128)
	}
}

impl From<u32> for M128 {
	fn from(value: u32) -> Self {
		Self::from(value as u128)
	}
}

impl From<u16> for M128 {
	fn from(value: u16) -> Self {
		Self::from(value as u128)
	}
}

impl From<u8> for M128 {
	fn from(value: u8) -> Self {
		Self::from(value as u128)
	}
}

impl<const N: usize> From<SmallU<N>> for M128 {
	fn from(value: SmallU<N>) -> Self {
		Self::from(value.val() as u128)
	}
}

impl From<M128> for u128 {
	fn from(value: M128) -> Self {
		// Store through an unaligned store for the same alignment reason as
		// the `From<u128>` direction above.
		let mut result = 0u128;
		unsafe { _mm_storeu_si128(&raw mut result as *mut __m128i, value.0) };

		result
	}
}

impl From<M128> for __m128i {
	#[inline(always)]
	fn from(value: M128) -> Self {
		value.0
	}
}
114
impl SerializeBytes for M128 {
	// Serializes the value as 16 little-endian bytes; the mode is irrelevant
	// for a plain integer payload.
	fn serialize(
		&self,
		mut write_buf: impl BufMut,
		_mode: SerializationMode,
	) -> Result<(), SerializationError> {
		assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;

		let raw_value: u128 = (*self).into();

		write_buf.put_u128_le(raw_value);
		Ok(())
	}
}
129
impl DeserializeBytes for M128 {
	// Inverse of `serialize`: reads 16 little-endian bytes back into a value.
	fn deserialize(
		mut read_buf: impl Buf,
		_mode: SerializationMode,
	) -> Result<Self, SerializationError>
	where
		Self: Sized,
	{
		assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;

		let raw_value = read_buf.get_u128_le();

		Ok(Self::from(raw_value))
	}
}
145
// Let `M128` be reinterpreted as arrays of each narrower unsigned integer type,
// and let it serve as the underlier of packed field types.
impl_divisible!(@pairs M128, u128, u64, u32, u16, u8);
impl_pack_scalar!(M128);
148
// Any type that can be narrowed from `u128` can be narrowed from `M128` by
// first extracting the 128-bit integer value.
impl<U: NumCast<u128>> NumCast<M128> for U {
	#[inline(always)]
	fn num_cast_from(val: M128) -> Self {
		Self::num_cast_from(u128::from(val))
	}
}
155
impl Default for M128 {
	// All-zero register; equivalent to `Self::ZERO`.
	#[inline(always)]
	fn default() -> Self {
		Self(unsafe { _mm_setzero_si128() })
	}
}
162
// Bitwise operators map directly onto the corresponding 128-bit SSE2
// intrinsics; the `*Assign` variants delegate to the value forms.

impl BitAnd for M128 {
	type Output = Self;

	#[inline(always)]
	fn bitand(self, rhs: Self) -> Self::Output {
		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
	}
}

impl BitAndAssign for M128 {
	#[inline(always)]
	fn bitand_assign(&mut self, rhs: Self) {
		*self = *self & rhs
	}
}

impl BitOr for M128 {
	type Output = Self;

	#[inline(always)]
	fn bitor(self, rhs: Self) -> Self::Output {
		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
	}
}

impl BitOrAssign for M128 {
	#[inline(always)]
	fn bitor_assign(&mut self, rhs: Self) {
		*self = *self | rhs
	}
}

impl BitXor for M128 {
	type Output = Self;

	#[inline(always)]
	fn bitxor(self, rhs: Self) -> Self::Output {
		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
	}
}

impl BitXorAssign for M128 {
	#[inline(always)]
	fn bitxor_assign(&mut self, rhs: Self) {
		*self = *self ^ rhs;
	}
}
210
211impl Not for M128 {
212	type Output = Self;
213
214	fn not(self) -> Self::Output {
215		const ONES: __m128i = m128_from_u128!(u128::MAX);
216
217		self ^ Self(ONES)
218	}
219}
220
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	if right > left { right } else { left }
}
225
/// Full 128-bit logical shift built from 64-bit-lane shift intrinsics.
///
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed
/// as constant and Rust currently doesn't allow passing expressions (`count - 64`) where variable
/// is a generic constant parameter. Source: https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			// `carry` is the value byte-shifted by 8 bytes (64 bits): it holds the
			// bits that cross the 64-bit lane boundary.
			let carry = $byte_shift($val, 8);
			// Shifts of 64..128 bits: only the carried lane survives; shift it by
			// the remaining amount. `seq!` expands one `if` per constant so the
			// intrinsic always receives a compile-time count.
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			// Shifts of 0..64 bits: combine the per-lane shift with the bits that
			// crossed over from the neighboring lane (shifted the opposite way).
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// Shift amounts >= 128 produce zero.
			return Default::default()
		}
	};
}
258
259pub(crate) use bitshift_128b;
260
impl Shr<usize> for M128 {
	type Output = Self;

	// Logical right shift of the whole 128-bit value; shifts >= 128 yield zero.
	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
271
impl Shl<usize> for M128 {
	type Output = Self;

	// Logical left shift of the whole 128-bit value; shifts >= 128 yield zero.
	#[inline(always)]
	fn shl(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
	}
}
282
impl PartialEq for M128 {
	fn eq(&self, other: &Self) -> bool {
		// XOR the operands; they are equal iff every resulting bit is zero,
		// which `_mm_test_all_zeros` (SSE4.1 PTEST) reports as 1.
		unsafe {
			let neq = _mm_xor_si128(self.0, other.0);
			_mm_test_all_zeros(neq, neq) == 1
		}
	}
}

impl Eq for M128 {}
293
impl PartialOrd for M128 {
	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
		Some(self.cmp(other))
	}
}

impl Ord for M128 {
	// Ordering is defined by the numeric value of the 128-bit integer the
	// register represents, via the `u128` round-trip conversion.
	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
		u128::from(*self).cmp(&u128::from(*other))
	}
}
305
impl ConstantTimeEq for M128 {
	// Branchless equality: PTEST produces 0/1 without data-dependent control
	// flow, which is then wrapped into a `Choice`.
	fn ct_eq(&self, other: &Self) -> Choice {
		unsafe {
			let neq = _mm_xor_si128(self.0, other.0);
			Choice::from(_mm_test_all_zeros(neq, neq) as u8)
		}
	}
}
314
impl ConditionallySelectable for M128 {
	// Delegates to the constant-time `u128` selection and converts back.
	fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
		ConditionallySelectable::conditional_select(&u128::from(*a), &u128::from(*b), choice).into()
	}
}
320
impl Random for M128 {
	// Draws a uniformly random 128-bit value from the provided RNG.
	fn random(mut rng: impl RngCore) -> Self {
		let val: u128 = rng.r#gen();
		val.into()
	}
}
327
328impl std::fmt::Display for M128 {
329	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
330		let data: u128 = (*self).into();
331		write!(f, "{data:02X?}")
332	}
333}
334
335impl std::fmt::Debug for M128 {
336	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337		write!(f, "M128({self})")
338	}
339}
340
// 16-byte-aligned wrapper so the pointer cast in `m128_from_u128!` below is
// valid for an aligned `__m128i` read in a const context.
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
343
/// Builds a `__m128i` constant from a `u128` expression in const contexts,
/// where the load intrinsics are unavailable.
macro_rules! m128_from_u128 {
	($val:expr) => {{
		// The aligned wrapper guarantees the pointer satisfies __m128i alignment.
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}
350
351pub(super) use m128_from_u128;
352
impl UnderlierType for M128 {
	// 2^7 = 128 bits.
	const LOG_BITS: usize = 7;
}
356
impl UnderlierWithBitOps for M128 {
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

	// Broadcasts a single bit to every bit of the value: 0 -> all zeros,
	// 1 -> all ones (1u8.wrapping_neg() == 0xFF).
	#[inline(always)]
	fn fill_with_bit(val: u8) -> Self {
		assert!(val == 0 || val == 1);
		Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
	}

	// Builds the value element-wise: `f(i)` yields the i-th T-wide element,
	// with index 0 in the least significant position.
	#[inline(always)]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		match T::BITS {
			1 | 2 | 4 => {
				// Sub-byte elements are first packed into bytes by the helper.
				let mut f = make_func_to_i8::<T, Self>(f);

				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			8 => {
				let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			16 => {
				let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
				unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
			}
			32 => {
				let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
				unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
			}
			64 => {
				let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
				unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
			}
			128 => Self::from(f(0)),
			_ => panic!("unsupported bit count"),
		}
	}

	// Extracts the i-th T-wide element by viewing the register as an array of
	// the matching integer width. SAFETY contract (from the trait): the caller
	// must ensure `i` is within bounds for T's element count.
	#[inline(always)]
	unsafe fn get_subvalue<T>(&self, i: usize) -> T
	where
		T: UnderlierType + NumCast<Self>,
	{
		match T::BITS {
			1 | 2 | 4 => {
				// Sub-byte case: locate the containing byte, then shift the
				// element down; `num_cast_from` truncates to T::BITS.
				let elements_in_8 = 8 / T::BITS;
				let mut value_u8 = as_array_ref::<_, u8, 16, _>(self, |arr| unsafe {
					*arr.get_unchecked(i / elements_in_8)
				});

				let shift = (i % elements_in_8) * T::BITS;
				value_u8 >>= shift;

				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			8 => {
				let value_u8 =
					as_array_ref::<_, u8, 16, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u8)))
			}
			16 => {
				let value_u16 =
					as_array_ref::<_, u16, 8, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u16)))
			}
			32 => {
				let value_u32 =
					as_array_ref::<_, u32, 4, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u32)))
			}
			64 => {
				let value_u64 =
					as_array_ref::<_, u64, 2, _>(self, |arr| unsafe { *arr.get_unchecked(i) });
				T::from_underlier(T::num_cast_from(Self::from(value_u64)))
			}
			128 => T::from_underlier(T::num_cast_from(*self)),
			_ => panic!("unsupported bit count"),
		}
	}

	// Writes the i-th T-wide element in place. SAFETY contract (from the
	// trait): the caller must ensure `i` is within bounds.
	#[inline(always)]
	unsafe fn set_subvalue<T>(&mut self, i: usize, val: T)
	where
		T: UnderlierWithBitOps,
		Self: From<T>,
	{
		match T::BITS {
			1 | 2 | 4 => {
				// Sub-byte case: read-modify-write the containing byte with a
				// shifted mask so neighboring elements are preserved.
				let elements_in_8 = 8 / T::BITS;
				let mask = (1u8 << T::BITS) - 1;
				let shift = (i % elements_in_8) * T::BITS;
				let val = u8::num_cast_from(Self::from(val)) << shift;
				let mask = mask << shift;

				as_array_mut::<_, u8, 16>(self, |array| unsafe {
					let element = array.get_unchecked_mut(i / elements_in_8);
					*element &= !mask;
					*element |= val;
				});
			}
			8 => as_array_mut::<_, u8, 16>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u8::num_cast_from(Self::from(val));
			}),
			16 => as_array_mut::<_, u16, 8>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u16::num_cast_from(Self::from(val));
			}),
			32 => as_array_mut::<_, u32, 4>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u32::num_cast_from(Self::from(val));
			}),
			64 => as_array_mut::<_, u64, 2>(self, |array| unsafe {
				*array.get_unchecked_mut(i) = u64::num_cast_from(Self::from(val));
			}),
			128 => {
				*self = Self::from(val);
			}
			_ => panic!("unsupported bit count"),
		}
	}

	// Takes the block of 2^log_block_len T-elements at `block_idx` and spreads
	// it across the whole register, repeating each element equally often.
	// Dispatches on the element width (T::LOG_BITS) and block length, using
	// the cheapest available intrinsic per case.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: From<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: each selected bit is expanded to a full run of
			// ones/zeros of the appropriate width.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => unsafe {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				},
				2 => unsafe {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				},
				3 => unsafe {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				},
				4 => unsafe {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				},
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 2-bit elements: spread each U2 to a byte, then replicate bytes.
			1 => match log_block_len {
				0 => unsafe {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				},
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 4-bit elements: same strategy as U2 via `spread_to_byte`.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					unsafe { _mm_set1_epi8(value as i8).into() }
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 8-bit elements: a single byte shuffle with a precomputed mask.
			3 => match log_block_len {
				0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
				1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
				2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
				3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					unsafe { _mm_set1_epi16(value as i16).into() }
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				},
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				},
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// 128-bit element: the whole value is the single block.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}

	// A single M128 is one 128-bit lane, so lane shifts are plain shifts.
	#[inline]
	fn shl_128b_lanes(self, shift: usize) -> Self {
		self << shift
	}

	#[inline]
	fn shr_128b_lanes(self, shift: usize) -> Self {
		self >> shift
	}

	// Interleaves the low halves of `self` and `other` at the given block
	// size; byte-and-wider blocks map to native unpack intrinsics, sub-byte
	// blocks use the portable fallback.
	#[inline]
	fn unpack_lo_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			0..3 => unpack_lo_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpacklo_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpacklo_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpacklo_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpacklo_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}

	// Same as `unpack_lo_128b_lanes` but for the high halves.
	#[inline]
	fn unpack_hi_128b_lanes(self, other: Self, log_block_len: usize) -> Self {
		match log_block_len {
			0..3 => unpack_hi_128b_fallback(self, other, log_block_len),
			3 => unsafe { _mm_unpackhi_epi8(self.0, other.0).into() },
			4 => unsafe { _mm_unpackhi_epi16(self.0, other.0).into() },
			5 => unsafe { _mm_unpackhi_epi32(self.0, other.0).into() },
			6 => unsafe { _mm_unpackhi_epi64(self.0, other.0).into() },
			_ => panic!("unsupported block length"),
		}
	}

	#[inline]
	fn transpose_bytes_from_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		transpose_128b_values::<Self, TL>(values, 0);
	}

	// Converts row-major bytes into byte-sliced form: first permute bytes
	// within each register (shuffle mask depends on the tower width), then
	// transpose across registers.
	#[inline]
	fn transpose_bytes_to_byte_sliced<TL: TowerLevel>(values: &mut TL::Data<Self>)
	where
		u8: NumCast<Self>,
		Self: From<u8>,
	{
		if TL::LOG_WIDTH == 0 {
			return;
		}

		match TL::LOG_WIDTH {
			1 => unsafe {
				let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			2 => unsafe {
				let shuffle = _mm_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			3 => unsafe {
				let shuffle = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
				for v in values.as_mut().iter_mut() {
					*v = _mm_shuffle_epi8(v.0, shuffle).into();
				}
			},
			4 => {}
			_ => unreachable!("Log width must be less than 5"),
		}

		transpose_128b_values::<_, TL>(values, 4 - TL::LOG_WIDTH);
	}
}
811
// SAFETY: M128 is `#[repr(transparent)]` over `__m128i`, a 16-byte vector of
// plain bits — the all-zero pattern is valid and every bit pattern is allowed.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

// SAFETY: M128 is a plain value type with no interior mutability or pointers.
unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
819
// Precomputed `_mm_shuffle_epi8` masks used by `spread` for 8-bit elements,
// one table per block length (2^0..2^3 bytes), indexed by block index.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
824
// Builds, at compile time, the byte-shuffle masks that replicate the block at
// `block_idx` across all 16 output bytes. Each mask byte holds the source byte
// index for `_mm_shuffle_epi8`. `t_log_bits` is the log2 bit width of one
// element (must be >= 3, i.e. at least a byte).
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	let element_log_width = t_log_bits - 3;

	// Width of one element in bytes.
	let element_width = 1 << element_log_width;

	// Size of one block in bytes, and how many times it repeats in 16 bytes.
	let block_size = 1 << (log_block_len + element_log_width);
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	// `while` loops because `for` is not allowed in const fns.
	while block_idx < BLOCK_IDX_AMOUNT {
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			// Each element of the block is repeated `repeat` times in a row.
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
860
861impl UnderlierWithBitConstants for M128 {
862	const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
863		Self::from_u128(interleave_mask_even!(u128, 0)),
864		Self::from_u128(interleave_mask_even!(u128, 1)),
865		Self::from_u128(interleave_mask_even!(u128, 2)),
866		Self::from_u128(interleave_mask_even!(u128, 3)),
867		Self::from_u128(interleave_mask_even!(u128, 4)),
868		Self::from_u128(interleave_mask_even!(u128, 5)),
869		Self::from_u128(interleave_mask_even!(u128, 6)),
870	];
871
872	const INTERLEAVE_ODD_MASK: &'static [Self] = &[
873		Self::from_u128(interleave_mask_odd!(u128, 0)),
874		Self::from_u128(interleave_mask_odd!(u128, 1)),
875		Self::from_u128(interleave_mask_odd!(u128, 2)),
876		Self::from_u128(interleave_mask_odd!(u128, 3)),
877		Self::from_u128(interleave_mask_odd!(u128, 4)),
878		Self::from_u128(interleave_mask_odd!(u128, 5)),
879		Self::from_u128(interleave_mask_odd!(u128, 6)),
880	];
881
882	#[inline(always)]
883	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
884		unsafe {
885			let (c, d) = interleave_bits(
886				Into::<Self>::into(self).into(),
887				Into::<Self>::into(other).into(),
888				log_block_len,
889			);
890			(Self::from(c), Self::from(d))
891		}
892	}
893}
894
// Convenience conversions between raw register / integer values and the
// packed-field wrapper over M128.
impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
	fn from(value: __m128i) -> Self {
		M128::from(value).into()
	}
}

impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
	fn from(value: u128) -> Self {
		M128::from(value).into()
	}
}

impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
		value.to_underlier().into()
	}
}
912
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
	u128: From<Scalar::Underlier>,
{
	// Replicates `scalar` into every scalar slot of the packed value.
	#[inline(always)]
	fn broadcast(scalar: Scalar) -> Self {
		let tower_level = Scalar::N_BITS.ilog2() as usize;
		let mut value = u128::from(scalar.to_underlier());
		// For scalars narrower than a byte, duplicate the bits until a full
		// byte is filled; then a byte-level broadcast covers the register.
		for n in tower_level..3 {
			value |= value << (1 << n);
		}

		// Reinterpret the u128 as __m128i, then broadcast the low element of
		// the matching width (VPBROADCAST* intrinsics).
		let value = must_cast(value);
		let value = match tower_level {
			0..=3 => unsafe { _mm_broadcastb_epi8(value) },
			4 => unsafe { _mm_broadcastw_epi16(value) },
			5 => unsafe { _mm_broadcastd_epi32(value) },
			6 => unsafe { _mm_broadcastq_epi64(value) },
			7 => value,
			_ => unreachable!(),
		};

		value.into()
	}
}
938
/// Interleaves blocks of 2^`log_block_len` bits of `a` and `b`.
///
/// Returns `(even, odd)` where `even` holds the even-indexed blocks of both
/// inputs interleaved and `odd` the odd-indexed ones. Sub-byte blocks use the
/// masked delta-swap in `interleave_bits_imm`; byte-and-wider blocks use a
/// shuffle that groups even/odd blocks followed by unpack lo/hi.
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		0 => unsafe {
			// 0x55 = 01010101: selects even-indexed bits within each byte.
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		},
		1 => unsafe {
			// 0x33 = 00110011: selects even-indexed 2-bit blocks.
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		},
		2 => unsafe {
			// 0x0F = 00001111: selects even-indexed nibbles.
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		},
		3 => unsafe {
			// Gather even bytes into the low half and odd bytes into the high
			// half, then unpack to interleave across the two operands.
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		},
		4 => unsafe {
			// Same approach for 16-bit blocks.
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		},
		5 => unsafe {
			// Same approach for 32-bit blocks.
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		},
		6 => unsafe {
			// 64-bit blocks: a single unpack pair suffices, no shuffle needed.
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		},
		_ => panic!("unsupported block length"),
	}
}
986
/// Interleaves `BLOCK_LEN`-bit blocks of `a` and `b` within each 64-bit lane
/// using the masked delta-swap trick:
/// `t = ((a >> k) ^ b) & mask; a' = a ^ (t << k); b' = b ^ t`.
/// `mask` must select the even-indexed blocks within each byte.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	unsafe {
		let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
		let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
		let b_prime = _mm_xor_si128(b, t);
		(a_prime, b_prime)
	}
}
1000
// Generate scalar-iteration impls over M128: single bits use the dedicated
// bit-iteration strategy, U2/U4 the generic fallback, and byte-and-wider
// widths the array-view (divisible) strategy.
impl_iteration!(M128,
	@strategy BitIterationStrategy, U1,
	@strategy FallbackStrategy, U2, U4,
	@strategy DivisibleStrategy, u8, u16, u32, u64, u128, M128,
);
1006
#[cfg(test)]
mod tests {
	use binius_utils::bytes::BytesMut;
	use proptest::{arbitrary::any, proptest};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::underlier::single_element_mask_bits;

	// Asserts that the conversion M128 -> T -> M128 is lossless.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	// Extracts the `index`-th block of 2^log_block_len bits as an M128 value.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		(value >> (index << log_block_len)) & single_element_mask_bits::<M128>(1 << log_block_len)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		// Shifts must agree with the reference u128 shifts for all in-range amounts.
		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		// Checks the (even, odd) block layout produced by `interleave_bits`.
		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}

		#[test]
		fn test_unpack_lo(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_lo_128b_lanes(b, height);
			for i in 0..128>>(height + 1) {
				assert_eq!(get(result, height, 2*i), get(a, height, i));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i));
			}
		}

		#[test]
		fn test_unpack_hi(a in any::<u128>(), b in any::<u128>(), height in 1usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let result = a.unpack_hi_128b_lanes(b, height);
			let half_block_count = 128>>(height + 1);
			for i in 0..half_block_count {
				assert_eq!(get(result, height, 2*i), get(a, height, i + half_block_count));
				assert_eq!(get(result, height, 2*i+1), get(b, height, i + half_block_count));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}

	#[test]
	fn test_serialize_and_deserialize_m128() {
		let mode = SerializationMode::Native;

		let mut rng = StdRng::from_seed([0; 32]);

		let original_value = M128::from(rng.r#gen::<u128>());

		let mut buf = BytesMut::new();
		original_value.serialize(&mut buf, mode).unwrap();

		let deserialized_value = M128::deserialize(buf.freeze(), mode).unwrap();

		assert_eq!(original_value, deserialized_value);
	}
}