binius_field/arch/x86_64/simd/
simd_arithmetic.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{any::TypeId, arch::x86_64::*};
4
5use crate::{
6	BinaryField, PackedField, TowerField,
7	aes_field::AESTowerField8b,
8	arch::{
9		SimdStrategy,
10		portable::{
11			packed::PackedPrimitiveType, packed_arithmetic::PackedTowerField,
12			reuse_multiply_arithmetic::Alpha,
13		},
14	},
15	arithmetic_traits::{
16		MulAlpha, TaggedInvertOrZero, TaggedMul, TaggedMulAlpha, TaggedPackedTransformationFactory,
17		TaggedSquare,
18	},
19	linear_transformation::{FieldLinearTransformation, Transformation},
20	packed::PackedBinaryField,
21	underlier::{UnderlierType, UnderlierWithBitOps, WithUnderlier},
22};
23
24pub trait TowerSimdType: Sized + Copy + UnderlierWithBitOps {
25	/// Blend odd and even elements
26	fn blend_odd_even<Scalar: BinaryField>(a: Self, b: Self) -> Self;
27	/// Set alpha to even elements
28	fn set_alpha_even<Scalar: BinaryField>(self) -> Self;
29	/// Apply `mask` to `a` (set zeros at positions where high bit of the `mask` is 0).
30	fn apply_mask<Scalar: BinaryField>(mask: Self, a: Self) -> Self;
31
32	/// Bit xor operation
33	fn xor(a: Self, b: Self) -> Self;
34
35	/// Shuffle 8-bit elements within 128-bit lanes
36	fn shuffle_epi8(a: Self, b: Self) -> Self;
37
38	/// Byte shifts within 128-bit lanes
39	fn bslli_epi128<const IMM8: i32>(self) -> Self;
40	fn bsrli_epi128<const IMM8: i32>(self) -> Self;
41
42	/// Initialize value with a single element
43	fn set1_epi128(val: __m128i) -> Self;
44	fn set_epi_64(val: i64) -> Self;
45
46	#[inline(always)]
47	fn dup_shuffle<Scalar: BinaryField>() -> Self {
48		let shuffle_mask_128 = unsafe {
49			match Scalar::N_BITS.ilog2() {
50				3 => _mm_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0),
51				4 => _mm_set_epi8(13, 12, 13, 12, 9, 8, 9, 8, 5, 4, 5, 4, 1, 0, 1, 0),
52				5 => _mm_set_epi8(11, 10, 9, 8, 11, 10, 9, 8, 3, 2, 1, 0, 3, 2, 1, 0),
53				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0),
54				_ => panic!("unsupported bit count"),
55			}
56		};
57
58		Self::set1_epi128(shuffle_mask_128)
59	}
60
61	#[inline(always)]
62	fn flip_shuffle<Scalar: BinaryField>() -> Self {
63		let flip_mask_128 = unsafe {
64			match Scalar::N_BITS.ilog2() {
65				3 => _mm_set_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1),
66				4 => _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2),
67				5 => _mm_set_epi8(11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4),
68				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8),
69				_ => panic!("unsupported bit count"),
70			}
71		};
72
73		Self::set1_epi128(flip_mask_128)
74	}
75
76	/// Creates mask to propagate the highest bit form mask to other element bytes
77	#[inline(always)]
78	fn make_epi8_mask_shuffle<Scalar: BinaryField>() -> Self {
79		let epi8_mask_128 = unsafe {
80			match Scalar::N_BITS.ilog2() {
81				4 => _mm_set_epi8(15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1),
82				5 => _mm_set_epi8(15, 15, 15, 15, 11, 11, 11, 11, 7, 7, 7, 7, 3, 3, 3, 3),
83				6 => _mm_set_epi8(15, 15, 15, 15, 15, 15, 15, 15, 7, 7, 7, 7, 7, 7, 7, 7),
84				7 => _mm_set1_epi8(15),
85				_ => panic!("unsupported bit count"),
86			}
87		};
88
89		Self::set1_epi128(epi8_mask_128)
90	}
91
92	#[inline(always)]
93	fn alpha<Scalar: BinaryField>() -> Self {
94		let alpha_128 = {
95			match Scalar::N_BITS.ilog2() {
96				3 => {
97					// Compiler will optimize this if out for each instantiation
98					let type_id = TypeId::of::<Scalar>();
99					let value = if type_id == TypeId::of::<AESTowerField8b>() {
100						0xd3u8 as i8
101					} else {
102						panic!("tower field not supported")
103					};
104					unsafe { _mm_set1_epi8(value) }
105				}
106				4 => unsafe { _mm_set1_epi16(0x0100) },
107				5 => unsafe { _mm_set1_epi32(0x00010000) },
108				6 => unsafe { _mm_set1_epi64x(0x0000000100000000) },
109				_ => panic!("unsupported bit count"),
110			}
111		};
112
113		Self::set1_epi128(alpha_128)
114	}
115
116	#[inline(always)]
117	fn even_mask<Scalar: BinaryField>() -> Self {
118		let mask_128 = {
119			match Scalar::N_BITS.ilog2() {
120				3 => unsafe { _mm_set1_epi16(0x00FF) },
121				4 => unsafe { _mm_set1_epi32(0x0000FFFF) },
122				5 => unsafe { _mm_set1_epi64x(0x00000000FFFFFFFF) },
123				6 => unsafe { _mm_set_epi64x(0, -1) },
124				_ => panic!("unsupported bit count"),
125			}
126		};
127
128		Self::set1_epi128(mask_128)
129	}
130}
131
132impl<U: UnderlierType + TowerSimdType, Scalar: TowerField> Alpha
133	for PackedPrimitiveType<U, Scalar>
134{
135	#[inline(always)]
136	fn alpha() -> Self {
137		U::alpha::<Scalar>().into()
138	}
139}
140
141#[inline(always)]
142fn blend_odd_even<U, PT>(a: PT, b: PT) -> PT
143where
144	U: TowerSimdType,
145	PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
146{
147	PT::from_underlier(U::blend_odd_even::<PT::Scalar>(a.to_underlier(), b.to_underlier()))
148}
149
150#[inline(always)]
151fn xor<U, PT>(a: PT, b: PT) -> PT
152where
153	U: TowerSimdType,
154	PT: WithUnderlier<Underlier = U>,
155{
156	PT::from_underlier(U::xor(a.to_underlier(), b.to_underlier()))
157}
158
159#[inline(always)]
160fn duplicate_odd<U, PT>(val: PT) -> PT
161where
162	U: TowerSimdType,
163	PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
164{
165	PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::dup_shuffle::<PT::Scalar>()))
166}
167
168#[inline(always)]
169fn flip_even_odd<U, PT>(val: PT) -> PT
170where
171	U: TowerSimdType,
172	PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
173{
174	PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::flip_shuffle::<PT::Scalar>()))
175}
176
177impl<U, Scalar: TowerField> TaggedMul<SimdStrategy> for PackedPrimitiveType<U, Scalar>
178where
179	Self: PackedTowerField<Underlier = U>,
180	U: TowerSimdType + UnderlierType,
181{
182	fn mul(self, rhs: Self) -> Self {
183		// This fallback is needed to generically use SimdStrategy in benchmarks.
184		if Scalar::TOWER_LEVEL <= 3 {
185			return self * rhs;
186		}
187
188		let a = self.as_packed_subfield();
189		let b = rhs.as_packed_subfield();
190
191		// [a0_lo * b0_lo, a0_hi * b0_hi, a1_lo * b1_lo, a1_h1 * b1_hi, ...]
192		let z0_even_z2_odd = a * b;
193
194		// [a0_lo, b0_lo, a1_lo, b1_lo, ...]
195		// [a0_hi, b0_hi, a1_hi, b1_hi, ...]
196		let (lo, hi) = a.interleave(b, 0);
197		// [a0_lo + a0_hi, b0_lo + b0_hi, a1_lo + a1_hi, b1lo + b1_hi, ...]
198		let lo_plus_hi_a_even_b_odd = lo + hi;
199
200		let alpha_even_z2_odd = <Self as PackedTowerField>::PackedDirectSubfield::from_underlier(
201			z0_even_z2_odd
202				.to_underlier()
203				.set_alpha_even::<<Self as PackedTowerField>::DirectSubfield>(),
204		);
205		let (lhs, rhs) = lo_plus_hi_a_even_b_odd.interleave(alpha_even_z2_odd, 0);
206		let z1_xor_z0z2_even_z2a_odd = lhs * rhs;
207
208		let z1_xor_z0z2 = duplicate_odd(z1_xor_z0z2_even_z2a_odd);
209		let zero_even_z1_xor_z2a_xor_z0z2_odd = xor(z1_xor_z0z2_even_z2a_odd, z1_xor_z0z2);
210
211		let z2_even_z0_odd = flip_even_odd(z0_even_z2_odd);
212		let z0z2 = xor(z0_even_z2_odd, z2_even_z0_odd);
213
214		Self::from_packed_subfield(xor(zero_even_z1_xor_z2a_xor_z0z2_odd, z0z2))
215	}
216}
217
218impl<U, Scalar: TowerField> TaggedMulAlpha<SimdStrategy> for PackedPrimitiveType<U, Scalar>
219where
220	Self: PackedTowerField<Underlier = U> + MulAlpha,
221	<Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
222	U: TowerSimdType + UnderlierType,
223{
224	#[inline]
225	fn mul_alpha(self) -> Self {
226		// This fallback is needed to generically use SimdStrategy in benchmarks.
227		if Scalar::TOWER_LEVEL <= 3 {
228			return MulAlpha::mul_alpha(self);
229		}
230
231		let a_0_a_1 = self.as_packed_subfield();
232		let a_0_mul_alpha_a_1_mul_alpha = a_0_a_1.mul_alpha();
233
234		let a_1_a_0 = flip_even_odd(self.as_packed_subfield());
235		let a0_plus_a1_alpha = xor(a_0_mul_alpha_a_1_mul_alpha, a_1_a_0);
236
237		Self::from_packed_subfield(blend_odd_even(a0_plus_a1_alpha, a_1_a_0))
238	}
239}
240
241impl<U, Scalar: TowerField> TaggedSquare<SimdStrategy> for PackedPrimitiveType<U, Scalar>
242where
243	Self: PackedTowerField<Underlier = U>,
244	<Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
245	U: TowerSimdType + UnderlierType,
246{
247	fn square(self) -> Self {
248		// This fallback is needed to generically use SimdStrategy in benchmarks.
249		if Scalar::TOWER_LEVEL <= 3 {
250			return PackedField::square(self);
251		}
252
253		let a_0_a_1 = self.as_packed_subfield();
254		let a_0_sq_a_1_sq = PackedField::square(a_0_a_1);
255		let a_1_sq_a_0_sq = flip_even_odd(a_0_sq_a_1_sq);
256		let a_0_sq_plus_a_1_sq = a_0_sq_a_1_sq + a_1_sq_a_0_sq;
257		let a_1_mul_alpha = a_0_sq_a_1_sq.mul_alpha();
258
259		Self::from_packed_subfield(blend_odd_even(a_1_mul_alpha, a_0_sq_plus_a_1_sq))
260	}
261}
262
263impl<U, Scalar: TowerField> TaggedInvertOrZero<SimdStrategy> for PackedPrimitiveType<U, Scalar>
264where
265	Self: PackedTowerField<Underlier = U>,
266	<Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
267	U: TowerSimdType + UnderlierType,
268{
269	fn invert_or_zero(self) -> Self {
270		// This fallback is needed to generically use SimdStrategy in benchmarks.
271		if Scalar::TOWER_LEVEL <= 3 {
272			return PackedField::invert_or_zero(self);
273		}
274
275		let a_0_a_1 = self.as_packed_subfield();
276		let a_1_a_0 = flip_even_odd(a_0_a_1);
277		let a_1_mul_alpha = a_1_a_0.mul_alpha();
278		let a_0_plus_a1_mul_alpha = xor(a_0_a_1, a_1_mul_alpha);
279		let a_1_sq_a_0_sq = PackedField::square(a_1_a_0);
280		let delta = xor(a_1_sq_a_0_sq, a_0_plus_a1_mul_alpha * a_0_a_1);
281		let delta_inv = PackedField::invert_or_zero(delta);
282		let delta_inv_delta_inv = duplicate_odd(delta_inv);
283		let delta_multiplier = blend_odd_even(a_0_a_1, a_0_plus_a1_mul_alpha);
284
285		Self::from_packed_subfield(delta_inv_delta_inv * delta_multiplier)
286	}
287}
288
/// SIMD packed field transformation.
/// The idea is similar to `PackedTransformation` but we use SIMD instructions
/// to multiply a component with zeros/ones by a basis vector.
pub struct SimdTransformation<OP> {
	/// One packed value per basis element of the linear transformation, each holding that
	/// basis element broadcast to every position.
	bases: Vec<OP>,
	/// Packed value with only the highest bit of every scalar set; used to extract the
	/// current input bit in the position SIMD masks read.
	ones: OP,
}
296
297#[allow(private_bounds)]
298impl<OP> SimdTransformation<OP>
299where
300	OP: PackedBinaryField + WithUnderlier<Underlier: TowerSimdType + UnderlierWithBitOps>,
301{
302	pub fn new<Data: AsRef<[OP::Scalar]> + Sync>(
303		transformation: FieldLinearTransformation<OP::Scalar, Data>,
304	) -> Self {
305		Self {
306			bases: transformation
307				.bases()
308				.iter()
309				.map(|base| OP::broadcast(*base))
310				.collect(),
311			// Set ones to the highest bit
312			// This is the format that is used in SIMD masks
313			ones: OP::one().mutate_underlier(|underlier| underlier << (OP::Scalar::N_BITS - 1)),
314		}
315	}
316}
317
318impl<U, IP, OP, IF, OF> Transformation<IP, OP> for SimdTransformation<OP>
319where
320	IP: PackedField<Scalar = IF> + WithUnderlier<Underlier = U>,
321	OP: PackedField<Scalar = OF> + WithUnderlier<Underlier = U>,
322	IF: BinaryField,
323	OF: BinaryField,
324	U: UnderlierWithBitOps + TowerSimdType,
325{
326	fn transform(&self, input: &IP) -> OP {
327		let mut result = OP::zero();
328		let ones = self.ones.to_underlier();
329		let mut input = input.to_underlier();
330
331		// Unlike `PackedTransformation`, we iterate from the highest bit to lowest one
332		// keeping current component in the highest bit.
333		for base in self.bases.iter().rev() {
334			let bases_mask = input & ones;
335			let component = U::apply_mask::<OP::Scalar>(bases_mask, base.to_underlier());
336			result += OP::from_underlier(component);
337			input = input << 1;
338		}
339
340		result
341	}
342}
343
344impl<IP, OP> TaggedPackedTransformationFactory<SimdStrategy, OP> for IP
345where
346	IP: PackedBinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
347	OP: PackedBinaryField + WithUnderlier<Underlier = IP::Underlier>,
348	IP::Underlier: TowerSimdType,
349{
350	type PackedTransformation<Data: AsRef<[<OP>::Scalar]> + Sync> = SimdTransformation<OP>;
351
352	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
353		transformation: FieldLinearTransformation<OP::Scalar, Data>,
354	) -> Self::PackedTransformation<Data> {
355		SimdTransformation::new(transformation)
356	}
357}
358
// Tests for the `SimdStrategy` implementations above. The suites themselves come from the
// shared macros in `crate::test_utils`.
#[cfg(test)]
mod tests {
	use super::*;
	use crate::test_utils::{
		define_invert_tests, define_multiply_tests, define_square_tests,
		define_transformation_tests,
	};

	// Each macro takes the function under test and a trait bound.
	// NOTE(review): presumably the bound selects which packed types the generated tests run
	// against — confirm in `crate::test_utils`.
	define_multiply_tests!(TaggedMul<SimdStrategy>::mul, TaggedMul<SimdStrategy>);

	define_square_tests!(TaggedSquare<SimdStrategy>::square, TaggedSquare<SimdStrategy>);

	define_invert_tests!(
		TaggedInvertOrZero<SimdStrategy>::invert_or_zero,
		TaggedInvertOrZero<SimdStrategy>
	);

	// Helper trait naming "can build a `SimdStrategy` transformation into itself", so that it
	// can be handed to `define_transformation_tests!` as a single identifier.
	// NOTE(review): `allow(unused)` is presumably needed because the trait is only referenced
	// as a bound inside the macro expansion — confirm.
	#[allow(unused)]
	trait SelfPackedTransformationFactory:
		TaggedPackedTransformationFactory<SimdStrategy, Self>
	{
	}

	// Blanket implementation: every type satisfying the supertrait gets the helper trait.
	impl<T: TaggedPackedTransformationFactory<SimdStrategy, Self>> SelfPackedTransformationFactory
		for T
	{
	}

	define_transformation_tests!(SelfPackedTransformationFactory);
}