binius_field/arch/x86_64/simd/
simd_arithmetic.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{any::TypeId, arch::x86_64::*};
4
5use crate::{
6	BinaryField, TowerField,
7	aes_field::AESTowerField8b,
8	arch::portable::{packed::PackedPrimitiveType, reuse_multiply_arithmetic::Alpha},
9	underlier::{UnderlierType, UnderlierWithBitOps},
10};
11
12pub trait TowerSimdType: Sized + Copy + UnderlierWithBitOps {
13	/// Blend odd and even elements
14	fn blend_odd_even<Scalar: BinaryField>(a: Self, b: Self) -> Self;
15	/// Set alpha to even elements
16	fn set_alpha_even<Scalar: BinaryField>(self) -> Self;
17	/// Apply `mask` to `a` (set zeros at positions where high bit of the `mask` is 0).
18	fn apply_mask<Scalar: BinaryField>(mask: Self, a: Self) -> Self;
19
20	/// Bit xor operation
21	fn xor(a: Self, b: Self) -> Self;
22
23	/// Shuffle 8-bit elements within 128-bit lanes
24	fn shuffle_epi8(a: Self, b: Self) -> Self;
25
26	/// Byte shifts within 128-bit lanes
27	fn bslli_epi128<const IMM8: i32>(self) -> Self;
28	fn bsrli_epi128<const IMM8: i32>(self) -> Self;
29
30	/// Initialize value with a single element
31	fn set1_epi128(val: __m128i) -> Self;
32	fn set_epi_64(val: i64) -> Self;
33
34	#[inline(always)]
35	fn dup_shuffle<Scalar: BinaryField>() -> Self {
36		let shuffle_mask_128 = unsafe {
37			match Scalar::N_BITS.ilog2() {
38				3 => _mm_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0),
39				4 => _mm_set_epi8(13, 12, 13, 12, 9, 8, 9, 8, 5, 4, 5, 4, 1, 0, 1, 0),
40				5 => _mm_set_epi8(11, 10, 9, 8, 11, 10, 9, 8, 3, 2, 1, 0, 3, 2, 1, 0),
41				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0),
42				_ => panic!("unsupported bit count"),
43			}
44		};
45
46		Self::set1_epi128(shuffle_mask_128)
47	}
48
49	#[inline(always)]
50	fn flip_shuffle<Scalar: BinaryField>() -> Self {
51		let flip_mask_128 = unsafe {
52			match Scalar::N_BITS.ilog2() {
53				3 => _mm_set_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1),
54				4 => _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2),
55				5 => _mm_set_epi8(11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4),
56				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8),
57				_ => panic!("unsupported bit count"),
58			}
59		};
60
61		Self::set1_epi128(flip_mask_128)
62	}
63
64	/// Creates mask to propagate the highest bit form mask to other element bytes
65	#[inline(always)]
66	fn make_epi8_mask_shuffle<Scalar: BinaryField>() -> Self {
67		let epi8_mask_128 = unsafe {
68			match Scalar::N_BITS.ilog2() {
69				4 => _mm_set_epi8(15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1),
70				5 => _mm_set_epi8(15, 15, 15, 15, 11, 11, 11, 11, 7, 7, 7, 7, 3, 3, 3, 3),
71				6 => _mm_set_epi8(15, 15, 15, 15, 15, 15, 15, 15, 7, 7, 7, 7, 7, 7, 7, 7),
72				7 => _mm_set1_epi8(15),
73				_ => panic!("unsupported bit count"),
74			}
75		};
76
77		Self::set1_epi128(epi8_mask_128)
78	}
79
80	#[inline(always)]
81	fn alpha<Scalar: BinaryField>() -> Self {
82		let alpha_128 = {
83			match Scalar::N_BITS.ilog2() {
84				3 => {
85					// Compiler will optimize this if out for each instantiation
86					let type_id = TypeId::of::<Scalar>();
87					let value = if type_id == TypeId::of::<AESTowerField8b>() {
88						0xd3u8 as i8
89					} else {
90						panic!("tower field not supported")
91					};
92					unsafe { _mm_set1_epi8(value) }
93				}
94				4 => unsafe { _mm_set1_epi16(0x0100) },
95				5 => unsafe { _mm_set1_epi32(0x00010000) },
96				6 => unsafe { _mm_set1_epi64x(0x0000000100000000) },
97				_ => panic!("unsupported bit count"),
98			}
99		};
100
101		Self::set1_epi128(alpha_128)
102	}
103
104	#[inline(always)]
105	fn even_mask<Scalar: BinaryField>() -> Self {
106		let mask_128 = {
107			match Scalar::N_BITS.ilog2() {
108				3 => unsafe { _mm_set1_epi16(0x00FF) },
109				4 => unsafe { _mm_set1_epi32(0x0000FFFF) },
110				5 => unsafe { _mm_set1_epi64x(0x00000000FFFFFFFF) },
111				6 => unsafe { _mm_set_epi64x(0, -1) },
112				_ => panic!("unsupported bit count"),
113			}
114		};
115
116		Self::set1_epi128(mask_128)
117	}
118}
119
120impl<U: UnderlierType + TowerSimdType, Scalar: TowerField> Alpha
121	for PackedPrimitiveType<U, Scalar>
122{
123	#[inline(always)]
124	fn alpha() -> Self {
125		U::alpha::<Scalar>().into()
126	}
127}