binius_field/arch/x86_64/simd/
simd_arithmetic.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{any::TypeId, arch::x86_64::*};
4
5use crate::{
6	aes_field::AESTowerField8b,
7	arch::{
8		portable::{
9			packed::PackedPrimitiveType, packed_arithmetic::PackedTowerField,
10			reuse_multiply_arithmetic::Alpha,
11		},
12		SimdStrategy,
13	},
14	arithmetic_traits::{
15		MulAlpha, TaggedInvertOrZero, TaggedMul, TaggedMulAlpha, TaggedPackedTransformationFactory,
16		TaggedSquare,
17	},
18	linear_transformation::{FieldLinearTransformation, Transformation},
19	packed::PackedBinaryField,
20	underlier::{UnderlierType, UnderlierWithBitOps, WithUnderlier},
21	BinaryField, BinaryField8b, PackedField, TowerField,
22};
23
24pub trait TowerSimdType: Sized + Copy + UnderlierWithBitOps {
25	/// Blend odd and even elements
26	fn blend_odd_even<Scalar: BinaryField>(a: Self, b: Self) -> Self;
27	/// Set alpha to even elements
28	fn set_alpha_even<Scalar: BinaryField>(self) -> Self;
29	/// Apply `mask` to `a` (set zeros at positions where high bit of the `mask` is 0).
30	fn apply_mask<Scalar: BinaryField>(mask: Self, a: Self) -> Self;
31
32	/// Bit xor operation
33	fn xor(a: Self, b: Self) -> Self;
34
35	/// Shuffle 8-bit elements within 128-bit lanes
36	fn shuffle_epi8(a: Self, b: Self) -> Self;
37
38	/// Byte shifts within 128-bit lanes
39	fn bslli_epi128<const IMM8: i32>(self) -> Self;
40	fn bsrli_epi128<const IMM8: i32>(self) -> Self;
41
42	/// Initialize value with a single element
43	fn set1_epi128(val: __m128i) -> Self;
44	fn set_epi_64(val: i64) -> Self;
45
46	#[inline(always)]
47	fn dup_shuffle<Scalar: BinaryField>() -> Self {
48		let shuffle_mask_128 = unsafe {
49			match Scalar::N_BITS.ilog2() {
50				3 => _mm_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0),
51				4 => _mm_set_epi8(13, 12, 13, 12, 9, 8, 9, 8, 5, 4, 5, 4, 1, 0, 1, 0),
52				5 => _mm_set_epi8(11, 10, 9, 8, 11, 10, 9, 8, 3, 2, 1, 0, 3, 2, 1, 0),
53				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0),
54				_ => panic!("unsupported bit count"),
55			}
56		};
57
58		Self::set1_epi128(shuffle_mask_128)
59	}
60
61	#[inline(always)]
62	fn flip_shuffle<Scalar: BinaryField>() -> Self {
63		let flip_mask_128 = unsafe {
64			match Scalar::N_BITS.ilog2() {
65				3 => _mm_set_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1),
66				4 => _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2),
67				5 => _mm_set_epi8(11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4),
68				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8),
69				_ => panic!("unsupported bit count"),
70			}
71		};
72
73		Self::set1_epi128(flip_mask_128)
74	}
75
76	/// Creates mask to propagate the highest bit form mask to other element bytes
77	#[inline(always)]
78	fn make_epi8_mask_shuffle<Scalar: BinaryField>() -> Self {
79		let epi8_mask_128 = unsafe {
80			match Scalar::N_BITS.ilog2() {
81				4 => _mm_set_epi8(15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1),
82				5 => _mm_set_epi8(15, 15, 15, 15, 11, 11, 11, 11, 7, 7, 7, 7, 3, 3, 3, 3),
83				6 => _mm_set_epi8(15, 15, 15, 15, 15, 15, 15, 15, 7, 7, 7, 7, 7, 7, 7, 7),
84				7 => _mm_set1_epi8(15),
85				_ => panic!("unsupported bit count"),
86			}
87		};
88
89		Self::set1_epi128(epi8_mask_128)
90	}
91
92	#[inline(always)]
93	fn alpha<Scalar: BinaryField>() -> Self {
94		let alpha_128 = unsafe {
95			match Scalar::N_BITS.ilog2() {
96				3 => {
97					// Compiler will optimize this if out for each instantiation
98					let type_id = TypeId::of::<Scalar>();
99					let value = if type_id == TypeId::of::<BinaryField8b>() {
100						0x10
101					} else if type_id == TypeId::of::<AESTowerField8b>() {
102						0xd3u8 as i8
103					} else {
104						panic!("tower field not supported")
105					};
106					_mm_set1_epi8(value)
107				}
108				4 => _mm_set1_epi16(0x0100),
109				5 => _mm_set1_epi32(0x00010000),
110				6 => _mm_set1_epi64x(0x0000000100000000),
111				_ => panic!("unsupported bit count"),
112			}
113		};
114
115		Self::set1_epi128(alpha_128)
116	}
117
118	#[inline(always)]
119	fn even_mask<Scalar: BinaryField>() -> Self {
120		let mask_128 = unsafe {
121			match Scalar::N_BITS.ilog2() {
122				3 => _mm_set1_epi16(0x00FF),
123				4 => _mm_set1_epi32(0x0000FFFF),
124				5 => _mm_set1_epi64x(0x00000000FFFFFFFF),
125				6 => _mm_set_epi64x(0, -1),
126				_ => panic!("unsupported bit count"),
127			}
128		};
129
130		Self::set1_epi128(mask_128)
131	}
132}
133
134impl<U: UnderlierType + TowerSimdType, Scalar: TowerField> Alpha
135	for PackedPrimitiveType<U, Scalar>
136{
137	#[inline(always)]
138	fn alpha() -> Self {
139		U::alpha::<Scalar>().into()
140	}
141}
142
143#[inline(always)]
144fn blend_odd_even<U, PT>(a: PT, b: PT) -> PT
145where
146	U: TowerSimdType,
147	PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
148{
149	PT::from_underlier(U::blend_odd_even::<PT::Scalar>(a.to_underlier(), b.to_underlier()))
150}
151
152#[inline(always)]
153fn xor<U, PT>(a: PT, b: PT) -> PT
154where
155	U: TowerSimdType,
156	PT: WithUnderlier<Underlier = U>,
157{
158	PT::from_underlier(U::xor(a.to_underlier(), b.to_underlier()))
159}
160
161#[inline(always)]
162fn duplicate_odd<U, PT>(val: PT) -> PT
163where
164	U: TowerSimdType,
165	PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
166{
167	PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::dup_shuffle::<PT::Scalar>()))
168}
169
170#[inline(always)]
171fn flip_even_odd<U, PT>(val: PT) -> PT
172where
173	U: TowerSimdType,
174	PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
175{
176	PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::flip_shuffle::<PT::Scalar>()))
177}
178
179impl<U, Scalar: TowerField> TaggedMul<SimdStrategy> for PackedPrimitiveType<U, Scalar>
180where
181	Self: PackedTowerField<Underlier = U>,
182	U: TowerSimdType + UnderlierType,
183{
184	fn mul(self, rhs: Self) -> Self {
185		// This fallback is needed to generically use SimdStrategy in benchmarks.
186		if Scalar::TOWER_LEVEL <= 3 {
187			return self * rhs;
188		}
189
190		let a = self.as_packed_subfield();
191		let b = rhs.as_packed_subfield();
192
193		// [a0_lo * b0_lo, a0_hi * b0_hi, a1_lo * b1_lo, a1_h1 * b1_hi, ...]
194		let z0_even_z2_odd = a * b;
195
196		// [a0_lo, b0_lo, a1_lo, b1_lo, ...]
197		// [a0_hi, b0_hi, a1_hi, b1_hi, ...]
198		let (lo, hi) = a.interleave(b, 0);
199		// [a0_lo + a0_hi, b0_lo + b0_hi, a1_lo + a1_hi, b1lo + b1_hi, ...]
200		let lo_plus_hi_a_even_b_odd = lo + hi;
201
202		let alpha_even_z2_odd = <Self as PackedTowerField>::PackedDirectSubfield::from_underlier(
203			z0_even_z2_odd
204				.to_underlier()
205				.set_alpha_even::<<Self as PackedTowerField>::DirectSubfield>(),
206		);
207		let (lhs, rhs) = lo_plus_hi_a_even_b_odd.interleave(alpha_even_z2_odd, 0);
208		let z1_xor_z0z2_even_z2a_odd = lhs * rhs;
209
210		let z1_xor_z0z2 = duplicate_odd(z1_xor_z0z2_even_z2a_odd);
211		let zero_even_z1_xor_z2a_xor_z0z2_odd = xor(z1_xor_z0z2_even_z2a_odd, z1_xor_z0z2);
212
213		let z2_even_z0_odd = flip_even_odd(z0_even_z2_odd);
214		let z0z2 = xor(z0_even_z2_odd, z2_even_z0_odd);
215
216		Self::from_packed_subfield(xor(zero_even_z1_xor_z2a_xor_z0z2_odd, z0z2))
217	}
218}
219
220impl<U, Scalar: TowerField> TaggedMulAlpha<SimdStrategy> for PackedPrimitiveType<U, Scalar>
221where
222	Self: PackedTowerField<Underlier = U> + MulAlpha,
223	<Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
224	U: TowerSimdType + UnderlierType,
225{
226	#[inline]
227	fn mul_alpha(self) -> Self {
228		// This fallback is needed to generically use SimdStrategy in benchmarks.
229		if Scalar::TOWER_LEVEL <= 3 {
230			return MulAlpha::mul_alpha(self);
231		}
232
233		let a_0_a_1 = self.as_packed_subfield();
234		let a_0_mul_alpha_a_1_mul_alpha = a_0_a_1.mul_alpha();
235
236		let a_1_a_0 = flip_even_odd(self.as_packed_subfield());
237		let a0_plus_a1_alpha = xor(a_0_mul_alpha_a_1_mul_alpha, a_1_a_0);
238
239		Self::from_packed_subfield(blend_odd_even(a0_plus_a1_alpha, a_1_a_0))
240	}
241}
242
243impl<U, Scalar: TowerField> TaggedSquare<SimdStrategy> for PackedPrimitiveType<U, Scalar>
244where
245	Self: PackedTowerField<Underlier = U>,
246	<Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
247	U: TowerSimdType + UnderlierType,
248{
249	fn square(self) -> Self {
250		// This fallback is needed to generically use SimdStrategy in benchmarks.
251		if Scalar::TOWER_LEVEL <= 3 {
252			return PackedField::square(self);
253		}
254
255		let a_0_a_1 = self.as_packed_subfield();
256		let a_0_sq_a_1_sq = PackedField::square(a_0_a_1);
257		let a_1_sq_a_0_sq = flip_even_odd(a_0_sq_a_1_sq);
258		let a_0_sq_plus_a_1_sq = a_0_sq_a_1_sq + a_1_sq_a_0_sq;
259		let a_1_mul_alpha = a_0_sq_a_1_sq.mul_alpha();
260
261		Self::from_packed_subfield(blend_odd_even(a_1_mul_alpha, a_0_sq_plus_a_1_sq))
262	}
263}
264
265impl<U, Scalar: TowerField> TaggedInvertOrZero<SimdStrategy> for PackedPrimitiveType<U, Scalar>
266where
267	Self: PackedTowerField<Underlier = U>,
268	<Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
269	U: TowerSimdType + UnderlierType,
270{
271	fn invert_or_zero(self) -> Self {
272		// This fallback is needed to generically use SimdStrategy in benchmarks.
273		if Scalar::TOWER_LEVEL <= 3 {
274			return PackedField::invert_or_zero(self);
275		}
276
277		let a_0_a_1 = self.as_packed_subfield();
278		let a_1_a_0 = flip_even_odd(a_0_a_1);
279		let a_1_mul_alpha = a_1_a_0.mul_alpha();
280		let a_0_plus_a1_mul_alpha = xor(a_0_a_1, a_1_mul_alpha);
281		let a_1_sq_a_0_sq = PackedField::square(a_1_a_0);
282		let delta = xor(a_1_sq_a_0_sq, a_0_plus_a1_mul_alpha * a_0_a_1);
283		let delta_inv = PackedField::invert_or_zero(delta);
284		let delta_inv_delta_inv = duplicate_odd(delta_inv);
285		let delta_multiplier = blend_odd_even(a_0_a_1, a_0_plus_a1_mul_alpha);
286
287		Self::from_packed_subfield(delta_inv_delta_inv * delta_multiplier)
288	}
289}
290
/// SIMD packed field transformation.
///
/// The idea is similar to `PackedTransformation`, but SIMD instructions are used
/// to multiply a component consisting of zeros/ones by a basis vector.
#[derive(Debug)]
pub struct SimdTransformation<OP> {
	/// One packed value per basis element of the linear transformation.
	bases: Vec<OP>,
	/// Mask that is ANDed with the (progressively shifted) input to select the bit that is
	/// currently being processed.
	ones: OP,
}
298
299#[allow(private_bounds)]
300impl<OP> SimdTransformation<OP>
301where
302	OP: PackedBinaryField + WithUnderlier<Underlier: TowerSimdType + UnderlierWithBitOps>,
303{
304	pub fn new<Data: AsRef<[OP::Scalar]> + Sync>(
305		transformation: FieldLinearTransformation<OP::Scalar, Data>,
306	) -> Self {
307		Self {
308			bases: transformation
309				.bases()
310				.iter()
311				.map(|base| OP::broadcast(*base))
312				.collect(),
313			// Set ones to the highest bit
314			// This is the format that is used in SIMD masks
315			ones: OP::one().mutate_underlier(|underlier| underlier << (OP::Scalar::N_BITS - 1)),
316		}
317	}
318}
319
320impl<U, IP, OP, IF, OF> Transformation<IP, OP> for SimdTransformation<OP>
321where
322	IP: PackedField<Scalar = IF> + WithUnderlier<Underlier = U>,
323	OP: PackedField<Scalar = OF> + WithUnderlier<Underlier = U>,
324	IF: BinaryField,
325	OF: BinaryField,
326	U: UnderlierWithBitOps + TowerSimdType,
327{
328	fn transform(&self, input: &IP) -> OP {
329		let mut result = OP::zero();
330		let ones = self.ones.to_underlier();
331		let mut input = input.to_underlier();
332
333		// Unlike `PackedTransformation`, we iterate from the highest bit to lowest one
334		// keeping current component in the highest bit.
335		for base in self.bases.iter().rev() {
336			let bases_mask = input & ones;
337			let component = U::apply_mask::<OP::Scalar>(bases_mask, base.to_underlier());
338			result += OP::from_underlier(component);
339			input = input << 1;
340		}
341
342		result
343	}
344}
345
346impl<IP, OP> TaggedPackedTransformationFactory<SimdStrategy, OP> for IP
347where
348	IP: PackedBinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
349	OP: PackedBinaryField + WithUnderlier<Underlier = IP::Underlier>,
350	IP::Underlier: TowerSimdType,
351{
352	type PackedTransformation<Data: AsRef<[<OP>::Scalar]> + Sync> = SimdTransformation<OP>;
353
354	fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
355		transformation: FieldLinearTransformation<OP::Scalar, Data>,
356	) -> Self::PackedTransformation<Data> {
357		SimdTransformation::new(transformation)
358	}
359}
360
#[cfg(test)]
mod tests {
	use super::*;
	use crate::test_utils::{
		define_invert_tests, define_mul_alpha_tests, define_multiply_tests, define_square_tests,
		define_transformation_tests,
	};

	// Arithmetic test suites instantiated for the `SimdStrategy` implementations above.
	define_multiply_tests!(TaggedMul<SimdStrategy>::mul, TaggedMul<SimdStrategy>);

	define_square_tests!(TaggedSquare<SimdStrategy>::square, TaggedSquare<SimdStrategy>);

	define_invert_tests!(
		TaggedInvertOrZero<SimdStrategy>::invert_or_zero,
		TaggedInvertOrZero<SimdStrategy>
	);

	define_mul_alpha_tests!(TaggedMulAlpha<SimdStrategy>::mul_alpha, TaggedMulAlpha<SimdStrategy>);

	/// Single-name alias for "can build a `SimdStrategy` transformation onto itself", so
	/// that the bound can be passed to `define_transformation_tests!` as one trait.
	// The trait only appears inside macro-generated code, hence the lint suppression.
	#[allow(unused)]
	trait SelfPackedTransformationFactory:
		TaggedPackedTransformationFactory<SimdStrategy, Self>
	{
	}

	impl<T> SelfPackedTransformationFactory for T where
		T: TaggedPackedTransformationFactory<SimdStrategy, T>
	{
	}

	define_transformation_tests!(SelfPackedTransformationFactory);
}