// binius_field/arch/x86_64/m128.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	arch::x86_64::*,
5	array,
6	ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not, Shl, Shr},
7};
8
9use binius_utils::{
10	DeserializeBytes, SerializationError, SerializeBytes,
11	bytes::{Buf, BufMut},
12	serialization::{assert_enough_data_for, assert_enough_space_for},
13};
14use bytemuck::{Pod, Zeroable};
15use rand::{
16	Rng,
17	distr::{Distribution, StandardUniform},
18};
19use seq_macro::seq;
20
21use crate::{
22	BinaryField,
23	arch::{
24		binary_utils::make_func_to_i8,
25		portable::{
26			packed::{PackedPrimitiveType, impl_pack_scalar},
27			packed_arithmetic::{
28				UnderlierWithBitConstants, interleave_mask_even, interleave_mask_odd,
29			},
30		},
31	},
32	arithmetic_traits::Broadcast,
33	underlier::{
34		Divisible, NumCast, SmallU, SpreadToByte, U2, U4, UnderlierType, UnderlierWithBitOps,
35		WithUnderlier, impl_divisible_bitmask, mapget, spread_fallback,
36	},
37};
38
/// 128-bit value that is used for 128-bit SIMD operations
///
/// `#[repr(transparent)]` wrapper over [`__m128i`] so that bitwise operator
/// traits and the crate's underlier traits can be implemented on top of the
/// raw intrinsic type.
#[derive(Copy, Clone)]
#[repr(transparent)]
pub struct M128(pub(super) __m128i);
43
impl M128 {
	/// Const-context constructor from a `u128`.
	///
	/// `From<u128>` uses an intrinsic load and cannot be `const`, so this
	/// copies the bits instead.
	#[inline(always)]
	pub const fn from_u128(val: u128) -> Self {
		let mut result = Self::ZERO;
		// SAFETY: `__m128i` and `u128` are both 16 bytes and any bit pattern
		// is a valid `__m128i`, so a bitwise copy is sound.
		unsafe {
			result.0 = std::mem::transmute_copy(&val);
		}

		result
	}
}
55
56impl From<__m128i> for M128 {
57	#[inline(always)]
58	fn from(value: __m128i) -> Self {
59		Self(value)
60	}
61}
62
impl From<u128> for M128 {
	/// Loads the `u128` bit pattern into the vector.
	fn from(value: u128) -> Self {
		Self(unsafe {
			// Safety: u128 is 16-byte aligned
			// The local `value` therefore satisfies the alignment requirement
			// of the aligned load `_mm_load_si128`; the assert enforces this
			// at runtime.
			assert_eq!(align_of::<u128>(), 16);
			_mm_load_si128(&raw const value as *const __m128i)
		})
	}
}
72
73impl From<u64> for M128 {
74	fn from(value: u64) -> Self {
75		Self::from(value as u128)
76	}
77}
78
79impl From<u32> for M128 {
80	fn from(value: u32) -> Self {
81		Self::from(value as u128)
82	}
83}
84
85impl From<u16> for M128 {
86	fn from(value: u16) -> Self {
87		Self::from(value as u128)
88	}
89}
90
91impl From<u8> for M128 {
92	fn from(value: u8) -> Self {
93		Self::from(value as u128)
94	}
95}
96
97impl<const N: usize> From<SmallU<N>> for M128 {
98	fn from(value: SmallU<N>) -> Self {
99		Self::from(value.val() as u128)
100	}
101}
102
impl From<M128> for u128 {
	/// Stores the vector's bit pattern into a `u128`.
	fn from(value: M128) -> Self {
		let mut result = 0u128;
		unsafe {
			// Safety: u128 is 16-byte aligned
			// The local `result` therefore satisfies the alignment requirement
			// of the aligned store `_mm_store_si128`.
			assert_eq!(align_of::<u128>(), 16);
			_mm_store_si128(&raw mut result as *mut __m128i, value.0)
		};
		result
	}
}
114
115impl From<M128> for __m128i {
116	#[inline(always)]
117	fn from(value: M128) -> Self {
118		value.0
119	}
120}
121
122impl SerializeBytes for M128 {
123	fn serialize(&self, mut write_buf: impl BufMut) -> Result<(), SerializationError> {
124		assert_enough_space_for(&write_buf, std::mem::size_of::<Self>())?;
125
126		let raw_value: u128 = (*self).into();
127
128		write_buf.put_u128_le(raw_value);
129		Ok(())
130	}
131}
132
133impl DeserializeBytes for M128 {
134	fn deserialize(mut read_buf: impl Buf) -> Result<Self, SerializationError>
135	where
136		Self: Sized,
137	{
138		assert_enough_data_for(&read_buf, std::mem::size_of::<Self>())?;
139
140		let raw_value = read_buf.get_u128_le();
141
142		Ok(Self::from(raw_value))
143	}
144}
145
// Bitmask-based `Divisible` support for 1-, 2- and 4-bit sub-elements
// (see `underlier::impl_divisible_bitmask`); wider widths get explicit
// intrinsic-based impls below.
impl_divisible_bitmask!(M128, 1, 2, 4);
// Packed-scalar plumbing shared with the portable backend.
impl_pack_scalar!(M128);
148
149impl<U: NumCast<u128>> NumCast<M128> for U {
150	#[inline(always)]
151	fn num_cast_from(val: M128) -> Self {
152		Self::num_cast_from(u128::from(val))
153	}
154}
155
156impl Default for M128 {
157	#[inline(always)]
158	fn default() -> Self {
159		Self(unsafe { _mm_setzero_si128() })
160	}
161}
162
163impl BitAnd for M128 {
164	type Output = Self;
165
166	#[inline(always)]
167	fn bitand(self, rhs: Self) -> Self::Output {
168		Self(unsafe { _mm_and_si128(self.0, rhs.0) })
169	}
170}
171
172impl BitAndAssign for M128 {
173	#[inline(always)]
174	fn bitand_assign(&mut self, rhs: Self) {
175		*self = *self & rhs
176	}
177}
178
179impl BitOr for M128 {
180	type Output = Self;
181
182	#[inline(always)]
183	fn bitor(self, rhs: Self) -> Self::Output {
184		Self(unsafe { _mm_or_si128(self.0, rhs.0) })
185	}
186}
187
188impl BitOrAssign for M128 {
189	#[inline(always)]
190	fn bitor_assign(&mut self, rhs: Self) {
191		*self = *self | rhs
192	}
193}
194
195impl BitXor for M128 {
196	type Output = Self;
197
198	#[inline(always)]
199	fn bitxor(self, rhs: Self) -> Self::Output {
200		Self(unsafe { _mm_xor_si128(self.0, rhs.0) })
201	}
202}
203
204impl BitXorAssign for M128 {
205	#[inline(always)]
206	fn bitxor_assign(&mut self, rhs: Self) {
207		*self = *self ^ rhs;
208	}
209}
210
211impl Not for M128 {
212	type Output = Self;
213
214	fn not(self) -> Self::Output {
215		const ONES: __m128i = m128_from_u128!(u128::MAX);
216
217		self ^ Self(ONES)
218	}
219}
220
/// `std::cmp::max` isn't const, so we need our own implementation
pub(crate) const fn max_i32(left: i32, right: i32) -> i32 {
	if right >= left { right } else { left }
}
225
/// This solution shows 4X better performance.
/// We have to use macro because parameter `count` in _mm_slli_epi64/_mm_srli_epi64 should be passed
/// as constant and Rust currently doesn't allow passing expressions (`count - 64`) where variable
/// is a generic constant parameter. Source: <https://stackoverflow.com/questions/34478328/the-best-way-to-shift-a-m128i/34482688#34482688>
macro_rules! bitshift_128b {
	($val:expr, $shift:ident, $byte_shift:ident, $bit_shift_64:ident, $bit_shift_64_opposite:ident, $or:ident) => {
		unsafe {
			// `carry` is the input moved by a whole 8 bytes (64 bits) in the shift
			// direction; it supplies the bits that cross the 64-bit lane boundary.
			let carry = $byte_shift($val, 8);
			// Shift amounts 64..128: the result comes entirely from the carried
			// lane, shifted by the remaining `N - 64` bits.
			seq!(N in 64..128 {
				if $shift == N {
					return $bit_shift_64(
						carry,
						crate::arch::x86_64::m128::max_i32((N - 64) as i32, 0) as _,
					).into();
				}
			});
			// Shift amounts 0..64: combine the per-lane shift of the value with
			// the boundary-crossing bits (`carry` shifted the opposite way by
			// `64 - N`).
			seq!(N in 0..64 {
				if $shift == N {
					let carry = $bit_shift_64_opposite(
						carry,
						crate::arch::x86_64::m128::max_i32((64 - N) as i32, 0) as _,
					);

					let val = $bit_shift_64($val, N);
					return $or(val, carry).into();
				}
			});

			// Shift amounts >= 128 fall through to the zero value.
			return Default::default()
		}
	};
}

pub(crate) use bitshift_128b;
260
impl Shr<usize> for M128 {
	type Output = Self;

	/// Logical right shift of the full 128-bit value by `rhs` bits.
	/// Unlike `u128 >> rhs`, shift amounts of 128 or more yield zero instead of
	/// panicking (the macro falls through to `Default::default()`).
	#[inline(always)]
	fn shr(self, rhs: usize) -> Self::Output {
		// This implementation is effective when `rhs` is known at compile-time.
		// In our code this is always the case.
		bitshift_128b!(self.0, rhs, _mm_bsrli_si128, _mm_srli_epi64, _mm_slli_epi64, _mm_or_si128)
	}
}
271
272impl Shl<usize> for M128 {
273	type Output = Self;
274
275	#[inline(always)]
276	fn shl(self, rhs: usize) -> Self::Output {
277		// This implementation is effective when `rhs` is known at compile-time.
278		// In our code this is always the case.
279		bitshift_128b!(self.0, rhs, _mm_bslli_si128, _mm_slli_epi64, _mm_srli_epi64, _mm_or_si128);
280	}
281}
282
283impl PartialEq for M128 {
284	fn eq(&self, other: &Self) -> bool {
285		unsafe {
286			let neq = _mm_xor_si128(self.0, other.0);
287			_mm_test_all_zeros(neq, neq) == 1
288		}
289	}
290}
291
292impl Eq for M128 {}
293
294impl PartialOrd for M128 {
295	fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
296		Some(self.cmp(other))
297	}
298}
299
300impl Ord for M128 {
301	fn cmp(&self, other: &Self) -> std::cmp::Ordering {
302		u128::from(*self).cmp(&u128::from(*other))
303	}
304}
305
impl Distribution<M128> for StandardUniform {
	/// Samples a uniformly random 128-bit vector.
	fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> M128 {
		// NOTE(review): `rng.random()` is inferred to produce the wrapped
		// `__m128i` directly, relying on rand's support for that type.
		M128(rng.random())
	}
}
311
312impl std::fmt::Display for M128 {
313	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314		let data: u128 = (*self).into();
315		write!(f, "{data:02X?}")
316	}
317}
318
319impl std::fmt::Debug for M128 {
320	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321		write!(f, "M128({self})")
322	}
323}
324
/// 16-byte-aligned wrapper so a `u128` can be reinterpreted as a `__m128i`
/// in const contexts (used by `m128_from_u128!`).
#[repr(align(16))]
pub struct AlignedData(pub [u128; 1]);
327
// Builds a `__m128i` constant from a `u128` expression by reading through a
// 16-byte-aligned copy; works in const contexts where intrinsic loads don't.
macro_rules! m128_from_u128 {
	($val:expr) => {{
		let aligned_data = $crate::arch::x86_64::m128::AlignedData([$val]);
		// SAFETY: `AlignedData` is `#[repr(align(16))]` and 16 bytes long, so
		// the pointer is valid and sufficiently aligned for `__m128i`.
		unsafe { *(aligned_data.0.as_ptr() as *const core::arch::x86_64::__m128i) }
	}};
}

pub(super) use m128_from_u128;
336
impl UnderlierType for M128 {
	// 2^7 = 128 bits.
	const LOG_BITS: usize = 7;
}
340
impl UnderlierWithBitOps for M128 {
	const ZERO: Self = { Self(m128_from_u128!(0)) };
	const ONE: Self = { Self(m128_from_u128!(1)) };
	const ONES: Self = { Self(m128_from_u128!(u128::MAX)) };

	/// Broadcasts a single bit: 0 -> all zeros, 1 -> all ones.
	#[inline(always)]
	fn fill_with_bit(val: u8) -> Self {
		assert!(val == 0 || val == 1);
		// `0u8.wrapping_neg() == 0x00`, `1u8.wrapping_neg() == 0xFF`.
		Self(unsafe { _mm_set1_epi8(val.wrapping_neg() as i8) })
	}

	/// Builds the vector from a per-element closure, dispatching on `T::BITS`.
	///
	/// Note: `_mm_set_epi*` takes its arguments from the highest lane down,
	/// hence the reversed `f(15) .. f(0)` argument order.
	#[inline(always)]
	fn from_fn<T>(mut f: impl FnMut(usize) -> T) -> Self
	where
		T: UnderlierType,
		Self: From<T>,
	{
		match T::BITS {
			// Sub-byte elements are packed into bytes by `make_func_to_i8`.
			1 | 2 | 4 => {
				let mut f = make_func_to_i8::<T, Self>(f);

				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			8 => {
				let mut f = |i| u8::num_cast_from(Self::from(f(i))) as i8;
				unsafe {
					_mm_set_epi8(
						f(15),
						f(14),
						f(13),
						f(12),
						f(11),
						f(10),
						f(9),
						f(8),
						f(7),
						f(6),
						f(5),
						f(4),
						f(3),
						f(2),
						f(1),
						f(0),
					)
				}
				.into()
			}
			16 => {
				let mut f = |i| u16::num_cast_from(Self::from(f(i))) as i16;
				unsafe { _mm_set_epi16(f(7), f(6), f(5), f(4), f(3), f(2), f(1), f(0)) }.into()
			}
			32 => {
				let mut f = |i| u32::num_cast_from(Self::from(f(i))) as i32;
				unsafe { _mm_set_epi32(f(3), f(2), f(1), f(0)) }.into()
			}
			64 => {
				let mut f = |i| u64::num_cast_from(Self::from(f(i))) as i64;
				unsafe { _mm_set_epi64x(f(1), f(0)) }.into()
			}
			128 => Self::from(f(0)),
			_ => panic!("unsupported bit count"),
		}
	}

	/// Takes the block of `2^log_block_len` `T`-elements starting at
	/// `block_idx` and spreads it across the whole vector (each element
	/// repeated to fill its share of the 128 bits), dispatching on the
	/// element width `T::LOG_BITS`.
	#[inline(always)]
	unsafe fn spread<T>(self, log_block_len: usize, block_idx: usize) -> Self
	where
		T: UnderlierWithBitOps + NumCast<Self>,
		Self: Divisible<T> + From<T>,
	{
		match T::LOG_BITS {
			// 1-bit elements: each selected bit becomes a fully-set or
			// fully-cleared region built from `fill_with_bit`.
			0 => match log_block_len {
				0 => Self::fill_with_bit(((u128::from(self) >> block_idx) & 1) as _),
				1 => unsafe {
					let bits: [u8; 2] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 2 + i)) & 1) as _);

					_mm_set_epi64x(
						u64::fill_with_bit(bits[1]) as i64,
						u64::fill_with_bit(bits[0]) as i64,
					)
					.into()
				},
				2 => unsafe {
					let bits: [u8; 4] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 4 + i)) & 1) as _);

					_mm_set_epi32(
						u32::fill_with_bit(bits[3]) as i32,
						u32::fill_with_bit(bits[2]) as i32,
						u32::fill_with_bit(bits[1]) as i32,
						u32::fill_with_bit(bits[0]) as i32,
					)
					.into()
				},
				3 => unsafe {
					let bits: [u8; 8] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 8 + i)) & 1) as _);

					_mm_set_epi16(
						u16::fill_with_bit(bits[7]) as i16,
						u16::fill_with_bit(bits[6]) as i16,
						u16::fill_with_bit(bits[5]) as i16,
						u16::fill_with_bit(bits[4]) as i16,
						u16::fill_with_bit(bits[3]) as i16,
						u16::fill_with_bit(bits[2]) as i16,
						u16::fill_with_bit(bits[1]) as i16,
						u16::fill_with_bit(bits[0]) as i16,
					)
					.into()
				},
				4 => unsafe {
					let bits: [u8; 16] =
						array::from_fn(|i| ((u128::from(self) >> (block_idx * 16 + i)) & 1) as _);

					_mm_set_epi8(
						u8::fill_with_bit(bits[15]) as i8,
						u8::fill_with_bit(bits[14]) as i8,
						u8::fill_with_bit(bits[13]) as i8,
						u8::fill_with_bit(bits[12]) as i8,
						u8::fill_with_bit(bits[11]) as i8,
						u8::fill_with_bit(bits[10]) as i8,
						u8::fill_with_bit(bits[9]) as i8,
						u8::fill_with_bit(bits[8]) as i8,
						u8::fill_with_bit(bits[7]) as i8,
						u8::fill_with_bit(bits[6]) as i8,
						u8::fill_with_bit(bits[5]) as i8,
						u8::fill_with_bit(bits[4]) as i8,
						u8::fill_with_bit(bits[3]) as i8,
						u8::fill_with_bit(bits[2]) as i8,
						u8::fill_with_bit(bits[1]) as i8,
						u8::fill_with_bit(bits[0]) as i8,
					)
					.into()
				},
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 2-bit elements (U2): each element is first replicated within a
			// byte by `spread_to_byte`, then the bytes fill the vector.
			1 => match log_block_len {
				0 => unsafe {
					let value =
						U2::new((u128::from(self) >> (block_idx * 2)) as _).spread_to_byte();

					_mm_set1_epi8(value as i8).into()
				},
				1 => {
					let bytes: [u8; 2] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 4 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 8])
				}
				2 => {
					let bytes: [u8; 4] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 8 + i * 2)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 4])
				}
				3 => {
					let bytes: [u8; 8] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 16 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i / 2])
				}
				4 => {
					let bytes: [u8; 16] = array::from_fn(|i| {
						U2::new((u128::from(self) >> (block_idx * 32 + i * 2)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| bytes[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 4-bit elements (U4): same strategy as U2.
			2 => match log_block_len {
				0 => {
					let value =
						U4::new((u128::from(self) >> (block_idx * 4)) as _).spread_to_byte();

					unsafe { _mm_set1_epi8(value as i8).into() }
				}
				1 => {
					let values: [u8; 2] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 8 + i * 4)) as _).spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 8])
				}
				2 => {
					let values: [u8; 4] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 16 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 4])
				}
				3 => {
					let values: [u8; 8] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 32 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i / 2])
				}
				4 => {
					let values: [u8; 16] = array::from_fn(|i| {
						U4::new((u128::from(self) >> (block_idx * 64 + i * 4)) as _)
							.spread_to_byte()
					});

					Self::from_fn::<u8>(|i| values[i])
				}
				_ => unsafe { spread_fallback(self, log_block_len, block_idx) },
			},
			// 8-bit elements: a single byte shuffle with a precomputed index
			// mask does the whole spread.
			3 => match log_block_len {
				0 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_0[block_idx].0).into() },
				1 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_1[block_idx].0).into() },
				2 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_2[block_idx].0).into() },
				3 => unsafe { _mm_shuffle_epi8(self.0, LOG_B8_3[block_idx].0).into() },
				4 => self,
				_ => panic!("unsupported block length"),
			},
			// 16-bit elements.
			4 => match log_block_len {
				0 => {
					let value = (u128::from(self) >> (block_idx * 16)) as u16;

					unsafe { _mm_set1_epi16(value as i16).into() }
				}
				1 => {
					let values: [u16; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 32 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 4])
				}
				2 => {
					let values: [u16; 4] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 16)) as u16);

					Self::from_fn::<u16>(|i| values[i / 2])
				}
				3 => self,
				_ => panic!("unsupported block length"),
			},
			// 32-bit elements.
			5 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 32)) as u32;

					_mm_set1_epi32(value as i32).into()
				},
				1 => {
					let values: [u32; 2] =
						array::from_fn(|i| (u128::from(self) >> (block_idx * 64 + i * 32)) as u32);

					Self::from_fn::<u32>(|i| values[i / 2])
				}
				2 => self,
				_ => panic!("unsupported block length"),
			},
			// 64-bit elements.
			6 => match log_block_len {
				0 => unsafe {
					let value = (u128::from(self) >> (block_idx * 64)) as u64;

					_mm_set1_epi64x(value as i64).into()
				},
				1 => self,
				_ => panic!("unsupported block length"),
			},
			// 128-bit element: the whole vector is the single block.
			7 => self,
			_ => panic!("unsupported bit length"),
		}
	}
}
635
// SAFETY: `M128` is `#[repr(transparent)]` over `__m128i`, a plain 16-byte
// vector value with no pointers or interior mutability; the all-zero bit
// pattern is valid and any bit pattern is a valid value.
unsafe impl Zeroable for M128 {}

unsafe impl Pod for M128 {}

unsafe impl Send for M128 {}

unsafe impl Sync for M128 {}
643
// Precomputed `_mm_shuffle_epi8` index masks used by `spread` for 8-bit
// elements (`T::LOG_BITS == 3`); one table per `log_block_len` in 0..=3,
// indexed by `block_idx`.
static LOG_B8_0: [M128; 16] = precompute_spread_mask::<16>(0, 3);
static LOG_B8_1: [M128; 8] = precompute_spread_mask::<8>(1, 3);
static LOG_B8_2: [M128; 4] = precompute_spread_mask::<4>(2, 3);
static LOG_B8_3: [M128; 2] = precompute_spread_mask::<2>(3, 3);
648
/// Precomputes the byte-shuffle masks used by `spread` to replicate one block
/// of byte-or-wider elements across the whole 16-byte vector.
///
/// `t_log_bits` is the log2 bit width of the element (must be >= 3 so elements
/// are whole bytes), `log_block_len` the log2 element count per block. Entry
/// `block_idx` of the result selects that block's elements and repeats each of
/// them `repeat` times, in order.
const fn precompute_spread_mask<const BLOCK_IDX_AMOUNT: usize>(
	log_block_len: usize,
	t_log_bits: usize,
) -> [M128; BLOCK_IDX_AMOUNT] {
	// log2 of the element width in bytes.
	let element_log_width = t_log_bits - 3;

	let element_width = 1 << element_log_width;

	// Block size in bytes and the number of times each element is repeated.
	let block_size = 1 << (log_block_len + element_log_width);
	let repeat = 1 << (4 - element_log_width - log_block_len);
	let mut masks = [[0u8; 16]; BLOCK_IDX_AMOUNT];

	// `const fn` cannot use iterators or `for`, hence the manual while-loops.
	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		let base = block_idx * block_size;
		let mut j = 0;
		while j < 16 {
			// Byte j selects: start of block + (which element, after dividing
			// out the repetition) + offset within the element.
			masks[block_idx][j] =
				(base + ((j / element_width) / repeat) * element_width + j % element_width) as u8;
			j += 1;
		}
		block_idx += 1;
	}
	let mut m128_masks = [M128::ZERO; BLOCK_IDX_AMOUNT];

	let mut block_idx = 0;

	while block_idx < BLOCK_IDX_AMOUNT {
		m128_masks[block_idx] = M128::from_u128(u128::from_le_bytes(masks[block_idx]));
		block_idx += 1;
	}

	m128_masks
}
684
685impl UnderlierWithBitConstants for M128 {
686	const INTERLEAVE_EVEN_MASK: &'static [Self] = &[
687		Self::from_u128(interleave_mask_even!(u128, 0)),
688		Self::from_u128(interleave_mask_even!(u128, 1)),
689		Self::from_u128(interleave_mask_even!(u128, 2)),
690		Self::from_u128(interleave_mask_even!(u128, 3)),
691		Self::from_u128(interleave_mask_even!(u128, 4)),
692		Self::from_u128(interleave_mask_even!(u128, 5)),
693		Self::from_u128(interleave_mask_even!(u128, 6)),
694	];
695
696	const INTERLEAVE_ODD_MASK: &'static [Self] = &[
697		Self::from_u128(interleave_mask_odd!(u128, 0)),
698		Self::from_u128(interleave_mask_odd!(u128, 1)),
699		Self::from_u128(interleave_mask_odd!(u128, 2)),
700		Self::from_u128(interleave_mask_odd!(u128, 3)),
701		Self::from_u128(interleave_mask_odd!(u128, 4)),
702		Self::from_u128(interleave_mask_odd!(u128, 5)),
703		Self::from_u128(interleave_mask_odd!(u128, 6)),
704	];
705
706	#[inline(always)]
707	fn interleave(self, other: Self, log_block_len: usize) -> (Self, Self) {
708		unsafe {
709			let (c, d) = interleave_bits(
710				Into::<Self>::into(self).into(),
711				Into::<Self>::into(other).into(),
712				log_block_len,
713			);
714			(Self::from(c), Self::from(d))
715		}
716	}
717}
718
719impl<Scalar: BinaryField> From<__m128i> for PackedPrimitiveType<M128, Scalar> {
720	fn from(value: __m128i) -> Self {
721		M128::from(value).into()
722	}
723}
724
725impl<Scalar: BinaryField> From<u128> for PackedPrimitiveType<M128, Scalar> {
726	fn from(value: u128) -> Self {
727		M128::from(value).into()
728	}
729}
730
731impl<Scalar: BinaryField> From<PackedPrimitiveType<M128, Scalar>> for __m128i {
732	fn from(value: PackedPrimitiveType<M128, Scalar>) -> Self {
733		value.to_underlier().into()
734	}
735}
736
impl<Scalar: BinaryField> Broadcast<Scalar> for PackedPrimitiveType<M128, Scalar>
where
	u128: From<Scalar::Underlier>,
{
	/// Repeats `scalar` across every `Scalar::N_BITS`-wide lane of the vector.
	#[inline(always)]
	fn broadcast(scalar: Scalar) -> Self {
		// log2 of the scalar bit width.
		let tower_level = Scalar::N_BITS.ilog2() as usize;
		match tower_level {
			// Sub-byte scalars: first replicate the scalar within one byte by
			// repeated shift-OR doubling, then broadcast that byte.
			0..=3 => {
				let mut value = u128::from(scalar.to_underlier()) as u8;
				for n in tower_level..3 {
					value |= value << (1 << n);
				}

				unsafe { _mm_set1_epi8(value as i8) }.into()
			}
			4 => {
				let value = u128::from(scalar.to_underlier()) as u16;
				unsafe { _mm_set1_epi16(value as i16) }.into()
			}
			5 => {
				let value = u128::from(scalar.to_underlier()) as u32;
				unsafe { _mm_set1_epi32(value as i32) }.into()
			}
			6 => {
				let value = u128::from(scalar.to_underlier()) as u64;
				unsafe { _mm_set1_epi64x(value as i64) }.into()
			}
			// A 128-bit scalar fills the whole vector on its own.
			7 => {
				let value = u128::from(scalar.to_underlier());
				value.into()
			}
			_ => {
				unreachable!("invalid tower level")
			}
		}
	}
}
775
/// Interleaves blocks of `2^log_block_len` bits of `a` and `b`.
///
/// The first result takes the even-indexed blocks of both inputs, the second
/// the odd-indexed ones (see `test_interleave_bits` for the exact layout).
#[inline]
unsafe fn interleave_bits(a: __m128i, b: __m128i, log_block_len: usize) -> (__m128i, __m128i) {
	match log_block_len {
		// 1-, 2- and 4-bit blocks: masked XOR-swap within 64-bit lanes; the
		// mask selects every other block (0x55 = 01010101, etc.).
		0 => unsafe {
			let mask = _mm_set1_epi8(0x55i8);
			interleave_bits_imm::<1>(a, b, mask)
		},
		1 => unsafe {
			let mask = _mm_set1_epi8(0x33i8);
			interleave_bits_imm::<2>(a, b, mask)
		},
		2 => unsafe {
			let mask = _mm_set1_epi8(0x0fi8);
			interleave_bits_imm::<4>(a, b, mask)
		},
		// Byte and wider blocks: shuffle even blocks into the low half and odd
		// blocks into the high half, then unpack lo/hi across the two inputs.
		3 => unsafe {
			let shuffle = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi8(a, b);
			let b_prime = _mm_unpackhi_epi8(a, b);
			(a_prime, b_prime)
		},
		4 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 11, 10, 7, 6, 3, 2, 13, 12, 9, 8, 5, 4, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi16(a, b);
			let b_prime = _mm_unpackhi_epi16(a, b);
			(a_prime, b_prime)
		},
		5 => unsafe {
			let shuffle = _mm_set_epi8(15, 14, 13, 12, 7, 6, 5, 4, 11, 10, 9, 8, 3, 2, 1, 0);
			let a = _mm_shuffle_epi8(a, shuffle);
			let b = _mm_shuffle_epi8(b, shuffle);
			let a_prime = _mm_unpacklo_epi32(a, b);
			let b_prime = _mm_unpackhi_epi32(a, b);
			(a_prime, b_prime)
		},
		// 64-bit blocks: plain unpack, no shuffle needed.
		6 => unsafe {
			let a_prime = _mm_unpacklo_epi64(a, b);
			let b_prime = _mm_unpackhi_epi64(a, b);
			(a_prime, b_prime)
		},
		_ => panic!("unsupported block length"),
	}
}
823
/// Interleaves `BLOCK_LEN`-bit blocks (for blocks smaller than a byte) using
/// the masked XOR-swap trick within each 64-bit lane.
#[inline]
unsafe fn interleave_bits_imm<const BLOCK_LEN: i32>(
	a: __m128i,
	b: __m128i,
	mask: __m128i,
) -> (__m128i, __m128i) {
	unsafe {
		// `t` marks, at odd block positions of `b`'s frame, where the blocks
		// of `a` and `b` that must trade places differ.
		let t = _mm_and_si128(_mm_xor_si128(_mm_srli_epi64::<BLOCK_LEN>(a), b), mask);
		// XORing `t` (shifted back for `a`) into both operands performs the swap.
		let a_prime = _mm_xor_si128(a, _mm_slli_epi64::<BLOCK_LEN>(t));
		let b_prime = _mm_xor_si128(b, t);
		(a_prime, b_prime)
	}
}
837
838// Divisible implementations using SIMD extract/insert intrinsics
839
840impl Divisible<u128> for M128 {
841	const LOG_N: usize = 0;
842
843	#[inline]
844	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone {
845		std::iter::once(u128::from(value))
846	}
847
848	#[inline]
849	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
850		std::iter::once(u128::from(*value))
851	}
852
853	#[inline]
854	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u128> + Send + Clone + '_ {
855		slice.iter().map(|&v| u128::from(v))
856	}
857
858	#[inline]
859	fn get(self, index: usize) -> u128 {
860		assert!(index == 0, "index out of bounds");
861		u128::from(self)
862	}
863
864	#[inline]
865	fn set(self, index: usize, val: u128) -> Self {
866		assert!(index == 0, "index out of bounds");
867		Self::from(val)
868	}
869}
870
impl Divisible<u64> for M128 {
	// Two 64-bit halves per vector.
	const LOG_N: usize = 1;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u64> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	/// Extracts half `index`. The intrinsic lane index must be a constant,
	/// hence the explicit match instead of passing `index` through.
	#[inline]
	fn get(self, index: usize) -> u64 {
		unsafe {
			match index {
				0 => _mm_extract_epi64(self.0, 0) as u64,
				1 => _mm_extract_epi64(self.0, 1) as u64,
				_ => panic!("index out of bounds"),
			}
		}
	}

	/// Returns a copy with half `index` replaced by `val`; same constant-index
	/// constraint as `get`.
	#[inline]
	fn set(self, index: usize, val: u64) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi64(self.0, val as i64, 0)),
				1 => Self(_mm_insert_epi64(self.0, val as i64, 1)),
				_ => panic!("index out of bounds"),
			}
		}
	}
}
911
impl Divisible<u32> for M128 {
	// Four 32-bit parts per vector.
	const LOG_N: usize = 2;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u32> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	/// Extracts lane `index`; the intrinsic lane index must be a constant,
	/// hence the explicit match.
	#[inline]
	fn get(self, index: usize) -> u32 {
		unsafe {
			match index {
				0 => _mm_extract_epi32(self.0, 0) as u32,
				1 => _mm_extract_epi32(self.0, 1) as u32,
				2 => _mm_extract_epi32(self.0, 2) as u32,
				3 => _mm_extract_epi32(self.0, 3) as u32,
				_ => panic!("index out of bounds"),
			}
		}
	}

	/// Returns a copy with lane `index` replaced by `val`.
	#[inline]
	fn set(self, index: usize, val: u32) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi32(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi32(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi32(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi32(self.0, val as i32, 3)),
				_ => panic!("index out of bounds"),
			}
		}
	}
}
956
impl Divisible<u16> for M128 {
	// Eight 16-bit parts per vector.
	const LOG_N: usize = 3;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u16> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	/// Extracts lane `index`; the intrinsic lane index must be a constant,
	/// hence the explicit match.
	#[inline]
	fn get(self, index: usize) -> u16 {
		unsafe {
			match index {
				0 => _mm_extract_epi16(self.0, 0) as u16,
				1 => _mm_extract_epi16(self.0, 1) as u16,
				2 => _mm_extract_epi16(self.0, 2) as u16,
				3 => _mm_extract_epi16(self.0, 3) as u16,
				4 => _mm_extract_epi16(self.0, 4) as u16,
				5 => _mm_extract_epi16(self.0, 5) as u16,
				6 => _mm_extract_epi16(self.0, 6) as u16,
				7 => _mm_extract_epi16(self.0, 7) as u16,
				_ => panic!("index out of bounds"),
			}
		}
	}

	/// Returns a copy with lane `index` replaced by `val`.
	#[inline]
	fn set(self, index: usize, val: u16) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi16(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi16(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi16(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi16(self.0, val as i32, 3)),
				4 => Self(_mm_insert_epi16(self.0, val as i32, 4)),
				5 => Self(_mm_insert_epi16(self.0, val as i32, 5)),
				6 => Self(_mm_insert_epi16(self.0, val as i32, 6)),
				7 => Self(_mm_insert_epi16(self.0, val as i32, 7)),
				_ => panic!("index out of bounds"),
			}
		}
	}
}
1009
impl Divisible<u8> for M128 {
	// Sixteen 8-bit parts per vector.
	const LOG_N: usize = 4;

	#[inline]
	fn value_iter(value: Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone {
		mapget::value_iter(value)
	}

	#[inline]
	fn ref_iter(value: &Self) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
		mapget::value_iter(*value)
	}

	#[inline]
	fn slice_iter(slice: &[Self]) -> impl ExactSizeIterator<Item = u8> + Send + Clone + '_ {
		mapget::slice_iter(slice)
	}

	/// Extracts byte `index`; the intrinsic lane index must be a constant,
	/// hence the explicit match.
	#[inline]
	fn get(self, index: usize) -> u8 {
		unsafe {
			match index {
				0 => _mm_extract_epi8(self.0, 0) as u8,
				1 => _mm_extract_epi8(self.0, 1) as u8,
				2 => _mm_extract_epi8(self.0, 2) as u8,
				3 => _mm_extract_epi8(self.0, 3) as u8,
				4 => _mm_extract_epi8(self.0, 4) as u8,
				5 => _mm_extract_epi8(self.0, 5) as u8,
				6 => _mm_extract_epi8(self.0, 6) as u8,
				7 => _mm_extract_epi8(self.0, 7) as u8,
				8 => _mm_extract_epi8(self.0, 8) as u8,
				9 => _mm_extract_epi8(self.0, 9) as u8,
				10 => _mm_extract_epi8(self.0, 10) as u8,
				11 => _mm_extract_epi8(self.0, 11) as u8,
				12 => _mm_extract_epi8(self.0, 12) as u8,
				13 => _mm_extract_epi8(self.0, 13) as u8,
				14 => _mm_extract_epi8(self.0, 14) as u8,
				15 => _mm_extract_epi8(self.0, 15) as u8,
				_ => panic!("index out of bounds"),
			}
		}
	}

	/// Returns a copy with byte `index` replaced by `val`.
	#[inline]
	fn set(self, index: usize, val: u8) -> Self {
		unsafe {
			match index {
				0 => Self(_mm_insert_epi8(self.0, val as i32, 0)),
				1 => Self(_mm_insert_epi8(self.0, val as i32, 1)),
				2 => Self(_mm_insert_epi8(self.0, val as i32, 2)),
				3 => Self(_mm_insert_epi8(self.0, val as i32, 3)),
				4 => Self(_mm_insert_epi8(self.0, val as i32, 4)),
				5 => Self(_mm_insert_epi8(self.0, val as i32, 5)),
				6 => Self(_mm_insert_epi8(self.0, val as i32, 6)),
				7 => Self(_mm_insert_epi8(self.0, val as i32, 7)),
				8 => Self(_mm_insert_epi8(self.0, val as i32, 8)),
				9 => Self(_mm_insert_epi8(self.0, val as i32, 9)),
				10 => Self(_mm_insert_epi8(self.0, val as i32, 10)),
				11 => Self(_mm_insert_epi8(self.0, val as i32, 11)),
				12 => Self(_mm_insert_epi8(self.0, val as i32, 12)),
				13 => Self(_mm_insert_epi8(self.0, val as i32, 13)),
				14 => Self(_mm_insert_epi8(self.0, val as i32, 14)),
				15 => Self(_mm_insert_epi8(self.0, val as i32, 15)),
				_ => panic!("index out of bounds"),
			}
		}
	}
}
1078
#[cfg(test)]
mod tests {
	use binius_utils::bytes::BytesMut;
	use proptest::{arbitrary::any, proptest};
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;

	/// Checks that `M128 -> T -> M128` round-trips losslessly.
	fn check_roundtrip<T>(val: M128)
	where
		T: From<M128>,
		M128: From<T>,
	{
		assert_eq!(M128::from(T::from(val)), val);
	}

	#[test]
	fn test_constants() {
		assert_eq!(M128::default(), M128::ZERO);
		assert_eq!(M128::from(0u128), M128::ZERO);
		assert_eq!(M128::from(1u128), M128::ONE);
	}

	/// Extracts block `index` of width `2^log_block_len` bits from `value`.
	///
	/// Fix: the mask is now the full block mask `(1 << 2^log_block_len) - 1`
	/// instead of the single bit `1 << log_block_len`, so
	/// `test_interleave_bits` compares whole blocks rather than one bit of
	/// each block.
	fn get(value: M128, log_block_len: usize, index: usize) -> M128 {
		// `log_block_len <= 6` in the callers, so `1 << log_block_len <= 64`
		// and the shift below stays within u128.
		let block_mask = (1u128 << (1 << log_block_len)) - 1;
		(value >> (index << log_block_len)) & M128::from(block_mask)
	}

	proptest! {
		#[test]
		fn test_conversion(a in any::<u128>()) {
			check_roundtrip::<u128>(a.into());
			check_roundtrip::<__m128i>(a.into());
		}

		#[test]
		fn test_binary_bit_operations(a in any::<u128>(), b in any::<u128>()) {
			assert_eq!(M128::from(a & b), M128::from(a) & M128::from(b));
			assert_eq!(M128::from(a | b), M128::from(a) | M128::from(b));
			assert_eq!(M128::from(a ^ b), M128::from(a) ^ M128::from(b));
		}

		#[test]
		fn test_negate(a in any::<u128>()) {
			assert_eq!(M128::from(!a), !M128::from(a))
		}

		#[test]
		fn test_shifts(a in any::<u128>(), b in 0..128usize) {
			assert_eq!(M128::from(a << b), M128::from(a) << b);
			assert_eq!(M128::from(a >> b), M128::from(a) >> b);
		}

		#[test]
		fn test_interleave_bits(a in any::<u128>(), b in any::<u128>(), height in 0usize..7) {
			let a = M128::from(a);
			let b = M128::from(b);

			let (c, d) = unsafe {interleave_bits(a.0, b.0, height)};
			let (c, d) = (M128::from(c), M128::from(d));

			// Even blocks of both inputs land in `c`, odd blocks in `d`.
			for i in (0..128>>height).step_by(2) {
				assert_eq!(get(c, height, i), get(a, height, i));
				assert_eq!(get(c, height, i+1), get(b, height, i));
				assert_eq!(get(d, height, i), get(a, height, i+1));
				assert_eq!(get(d, height, i+1), get(b, height, i+1));
			}
		}
	}

	#[test]
	fn test_fill_with_bit() {
		assert_eq!(M128::fill_with_bit(1), M128::from(u128::MAX));
		assert_eq!(M128::fill_with_bit(0), M128::from(0u128));
	}

	#[test]
	fn test_eq() {
		let a = M128::from(0u128);
		let b = M128::from(42u128);
		let c = M128::from(u128::MAX);

		assert_eq!(a, a);
		assert_eq!(b, b);
		assert_eq!(c, c);

		assert_ne!(a, b);
		assert_ne!(a, c);
		assert_ne!(b, c);
	}

	#[test]
	fn test_serialize_and_deserialize_m128() {
		let mut rng = StdRng::from_seed([0; 32]);

		let original_value = M128::from(rng.random::<u128>());

		let mut buf = BytesMut::new();
		original_value.serialize(&mut buf).unwrap();

		let deserialized_value = M128::deserialize(buf.freeze()).unwrap();

		assert_eq!(original_value, deserialized_value);
	}
}