1use std::{any::TypeId, arch::x86_64::*};
4
5use crate::{
6 BinaryField, PackedField, TowerField,
7 aes_field::AESTowerField8b,
8 arch::{
9 SimdStrategy,
10 portable::{
11 packed::PackedPrimitiveType, packed_arithmetic::PackedTowerField,
12 reuse_multiply_arithmetic::Alpha,
13 },
14 },
15 arithmetic_traits::{
16 MulAlpha, TaggedInvertOrZero, TaggedMul, TaggedMulAlpha, TaggedPackedTransformationFactory,
17 TaggedSquare,
18 },
19 linear_transformation::{FieldLinearTransformation, Transformation},
20 packed::PackedBinaryField,
21 underlier::{UnderlierType, UnderlierWithBitOps, WithUnderlier},
22};
23
/// SIMD primitives used by the packed tower-field arithmetic in this module.
///
/// Implemented by x86_64 vector underliers. Methods generic over `Scalar` treat the
/// vector as a sequence of `Scalar::N_BITS`-wide elements; "even" and "odd" refer to
/// the element index, element 0 being the least significant one.
pub trait TowerSimdType: Sized + Copy + UnderlierWithBitOps {
	/// Merges `a` and `b` by element parity at `Scalar` granularity.
	///
	/// NOTE(review): the callers in this file rely on odd-indexed elements being taken
	/// from `a` and even-indexed elements from `b` — confirm against the implementations.
	fn blend_odd_even<Scalar: BinaryField>(a: Self, b: Self) -> Self;
	/// Replaces the even-indexed `Scalar` elements with the tower "alpha" constant,
	/// keeping the odd-indexed ones.
	///
	/// NOTE(review): inferred from the name and from `TaggedMul::mul`, which uses it to
	/// turn `z0_even_z2_odd` into `alpha_even_z2_odd` — confirm.
	fn set_alpha_even<Scalar: BinaryField>(self) -> Self;
	/// Masks the `Scalar` elements of `a` with `mask`.
	///
	/// `SimdTransformation::transform` passes a mask in which only the top bit of each
	/// element can be set; presumably elements whose top mask bit is clear become zero
	/// and the others pass through unchanged — confirm.
	fn apply_mask<Scalar: BinaryField>(mask: Self, a: Self) -> Self;

	/// Bitwise XOR of `a` and `b`.
	fn xor(a: Self, b: Self) -> Self;

	/// Byte shuffle of `a` using the index bytes in `b`.
	///
	/// The masks built by the default methods below only contain indices `0..=15` and
	/// are built per 128-bit lane, so this presumably follows the per-lane
	/// `_mm_shuffle_epi8` (`pshufb`) semantics — confirm.
	fn shuffle_epi8(a: Self, b: Self) -> Self;

	/// Shift left by `IMM8` bytes (presumably within each 128-bit lane, like
	/// `_mm_bslli_si128`).
	fn bslli_epi128<const IMM8: i32>(self) -> Self;
	/// Shift right by `IMM8` bytes (presumably within each 128-bit lane, like
	/// `_mm_bsrli_si128`).
	fn bsrli_epi128<const IMM8: i32>(self) -> Self;

	/// Builds `Self` from a 128-bit pattern; the default methods below presumably rely
	/// on `val` being repeated in every 128-bit lane — confirm.
	fn set1_epi128(val: __m128i) -> Self;
	/// Builds `Self` from a single `i64` (presumably broadcast to every 64-bit lane —
	/// confirm).
	fn set_epi_64(val: i64) -> Self;

	/// Byte-shuffle mask that copies the even-indexed `Scalar` element of every
	/// (even, odd) pair into both positions of the pair.
	///
	/// # Panics
	///
	/// If `Scalar::N_BITS.ilog2()` is not in `3..=6` (for power-of-two widths: anything
	/// other than 8, 16, 32 or 64 bits).
	#[inline(always)]
	fn dup_shuffle<Scalar: BinaryField>() -> Self {
		// `_mm_set_epi8` lists bytes from the most significant (byte 15) down to byte 0,
		// so every index below points at the low (even) element of its pair.
		// SAFETY: `_mm_set_epi8` only requires SSE2, which is part of the x86_64 baseline.
		let shuffle_mask_128 = unsafe {
			match Scalar::N_BITS.ilog2() {
				3 => _mm_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0),
				4 => _mm_set_epi8(13, 12, 13, 12, 9, 8, 9, 8, 5, 4, 5, 4, 1, 0, 1, 0),
				5 => _mm_set_epi8(11, 10, 9, 8, 11, 10, 9, 8, 3, 2, 1, 0, 3, 2, 1, 0),
				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0),
				_ => panic!("unsupported bit count"),
			}
		};

		Self::set1_epi128(shuffle_mask_128)
	}

	/// Byte-shuffle mask that swaps the two `Scalar` elements of every (even, odd) pair.
	///
	/// # Panics
	///
	/// If `Scalar::N_BITS.ilog2()` is not in `3..=6`.
	#[inline(always)]
	fn flip_shuffle<Scalar: BinaryField>() -> Self {
		// Arguments run from byte 15 down to byte 0; each element takes its neighbour's bytes.
		// SAFETY: `_mm_set_epi8` only requires SSE2, which is part of the x86_64 baseline.
		let flip_mask_128 = unsafe {
			match Scalar::N_BITS.ilog2() {
				3 => _mm_set_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1),
				4 => _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2),
				5 => _mm_set_epi8(11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4),
				6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8),
				_ => panic!("unsupported bit count"),
			}
		};

		Self::set1_epi128(flip_mask_128)
	}

	/// Byte-shuffle mask that copies the most significant byte of each `Scalar` element
	/// into every byte of that element (16- to 128-bit elements).
	///
	/// NOTE(review): presumably used so that a per-byte top-bit test acts on whole
	/// elements — confirm in the implementations.
	///
	/// # Panics
	///
	/// If `Scalar::N_BITS.ilog2()` is not in `4..=7`.
	#[inline(always)]
	fn make_epi8_mask_shuffle<Scalar: BinaryField>() -> Self {
		// SAFETY: `_mm_set_epi8` / `_mm_set1_epi8` only require SSE2 (x86_64 baseline).
		let epi8_mask_128 = unsafe {
			match Scalar::N_BITS.ilog2() {
				4 => _mm_set_epi8(15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1),
				5 => _mm_set_epi8(15, 15, 15, 15, 11, 11, 11, 11, 7, 7, 7, 7, 3, 3, 3, 3),
				6 => _mm_set_epi8(15, 15, 15, 15, 15, 15, 15, 15, 7, 7, 7, 7, 7, 7, 7, 7),
				7 => _mm_set1_epi8(15),
				_ => panic!("unsupported bit count"),
			}
		};

		Self::set1_epi128(epi8_mask_128)
	}

	/// Vector holding the tower "alpha" constant in every `Scalar` element.
	///
	/// For 16-, 32- and 64-bit elements this is the element with only bit `N_BITS / 2`
	/// set (`0x0100`, `0x0001_0000`, `0x1_0000_0000`). For 8-bit elements only
	/// [`AESTowerField8b`] is supported, with the value `0xd3`.
	///
	/// # Panics
	///
	/// For any 8-bit field other than `AESTowerField8b`, and if
	/// `Scalar::N_BITS.ilog2()` is not in `3..=6`.
	#[inline(always)]
	fn alpha<Scalar: BinaryField>() -> Self {
		// SAFETY (all arms): the `_mm_set1_*` intrinsics only require SSE2 (x86_64 baseline).
		let alpha_128 = {
			match Scalar::N_BITS.ilog2() {
				3 => {
					// Only the AES tower is handled at 8 bits; other 8-bit fields panic.
					let type_id = TypeId::of::<Scalar>();
					let value = if type_id == TypeId::of::<AESTowerField8b>() {
						0xd3u8 as i8
					} else {
						panic!("tower field not supported")
					};
					unsafe { _mm_set1_epi8(value) }
				}
				4 => unsafe { _mm_set1_epi16(0x0100) },
				5 => unsafe { _mm_set1_epi32(0x00010000) },
				6 => unsafe { _mm_set1_epi64x(0x0000000100000000) },
				_ => panic!("unsupported bit count"),
			}
		};

		Self::set1_epi128(alpha_128)
	}

	/// Mask with all bits set in even-indexed `Scalar` elements and cleared in
	/// odd-indexed ones.
	///
	/// # Panics
	///
	/// If `Scalar::N_BITS.ilog2()` is not in `3..=6`.
	#[inline(always)]
	fn even_mask<Scalar: BinaryField>() -> Self {
		// SAFETY (all arms): these `_mm_set*` intrinsics only require SSE2 (x86_64 baseline).
		let mask_128 = {
			match Scalar::N_BITS.ilog2() {
				3 => unsafe { _mm_set1_epi16(0x00FF) },
				4 => unsafe { _mm_set1_epi32(0x0000FFFF) },
				5 => unsafe { _mm_set1_epi64x(0x00000000FFFFFFFF) },
				// `_mm_set_epi64x(high, low)`: the low 64-bit element is all ones.
				6 => unsafe { _mm_set_epi64x(0, -1) },
				_ => panic!("unsupported bit count"),
			}
		};

		Self::set1_epi128(mask_128)
	}
}
131
132impl<U: UnderlierType + TowerSimdType, Scalar: TowerField> Alpha
133 for PackedPrimitiveType<U, Scalar>
134{
135 #[inline(always)]
136 fn alpha() -> Self {
137 U::alpha::<Scalar>().into()
138 }
139}
140
141#[inline(always)]
142fn blend_odd_even<U, PT>(a: PT, b: PT) -> PT
143where
144 U: TowerSimdType,
145 PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
146{
147 PT::from_underlier(U::blend_odd_even::<PT::Scalar>(a.to_underlier(), b.to_underlier()))
148}
149
150#[inline(always)]
151fn xor<U, PT>(a: PT, b: PT) -> PT
152where
153 U: TowerSimdType,
154 PT: WithUnderlier<Underlier = U>,
155{
156 PT::from_underlier(U::xor(a.to_underlier(), b.to_underlier()))
157}
158
159#[inline(always)]
160fn duplicate_odd<U, PT>(val: PT) -> PT
161where
162 U: TowerSimdType,
163 PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
164{
165 PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::dup_shuffle::<PT::Scalar>()))
166}
167
168#[inline(always)]
169fn flip_even_odd<U, PT>(val: PT) -> PT
170where
171 U: TowerSimdType,
172 PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
173{
174 PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::flip_shuffle::<PT::Scalar>()))
175}
176
177impl<U, Scalar: TowerField> TaggedMul<SimdStrategy> for PackedPrimitiveType<U, Scalar>
178where
179 Self: PackedTowerField<Underlier = U>,
180 U: TowerSimdType + UnderlierType,
181{
182 fn mul(self, rhs: Self) -> Self {
183 if Scalar::TOWER_LEVEL <= 3 {
185 return self * rhs;
186 }
187
188 let a = self.as_packed_subfield();
189 let b = rhs.as_packed_subfield();
190
191 let z0_even_z2_odd = a * b;
193
194 let (lo, hi) = a.interleave(b, 0);
197 let lo_plus_hi_a_even_b_odd = lo + hi;
199
200 let alpha_even_z2_odd = <Self as PackedTowerField>::PackedDirectSubfield::from_underlier(
201 z0_even_z2_odd
202 .to_underlier()
203 .set_alpha_even::<<Self as PackedTowerField>::DirectSubfield>(),
204 );
205 let (lhs, rhs) = lo_plus_hi_a_even_b_odd.interleave(alpha_even_z2_odd, 0);
206 let z1_xor_z0z2_even_z2a_odd = lhs * rhs;
207
208 let z1_xor_z0z2 = duplicate_odd(z1_xor_z0z2_even_z2a_odd);
209 let zero_even_z1_xor_z2a_xor_z0z2_odd = xor(z1_xor_z0z2_even_z2a_odd, z1_xor_z0z2);
210
211 let z2_even_z0_odd = flip_even_odd(z0_even_z2_odd);
212 let z0z2 = xor(z0_even_z2_odd, z2_even_z0_odd);
213
214 Self::from_packed_subfield(xor(zero_even_z1_xor_z2a_xor_z0z2_odd, z0z2))
215 }
216}
217
218impl<U, Scalar: TowerField> TaggedMulAlpha<SimdStrategy> for PackedPrimitiveType<U, Scalar>
219where
220 Self: PackedTowerField<Underlier = U> + MulAlpha,
221 <Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
222 U: TowerSimdType + UnderlierType,
223{
224 #[inline]
225 fn mul_alpha(self) -> Self {
226 if Scalar::TOWER_LEVEL <= 3 {
228 return MulAlpha::mul_alpha(self);
229 }
230
231 let a_0_a_1 = self.as_packed_subfield();
232 let a_0_mul_alpha_a_1_mul_alpha = a_0_a_1.mul_alpha();
233
234 let a_1_a_0 = flip_even_odd(self.as_packed_subfield());
235 let a0_plus_a1_alpha = xor(a_0_mul_alpha_a_1_mul_alpha, a_1_a_0);
236
237 Self::from_packed_subfield(blend_odd_even(a0_plus_a1_alpha, a_1_a_0))
238 }
239}
240
241impl<U, Scalar: TowerField> TaggedSquare<SimdStrategy> for PackedPrimitiveType<U, Scalar>
242where
243 Self: PackedTowerField<Underlier = U>,
244 <Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
245 U: TowerSimdType + UnderlierType,
246{
247 fn square(self) -> Self {
248 if Scalar::TOWER_LEVEL <= 3 {
250 return PackedField::square(self);
251 }
252
253 let a_0_a_1 = self.as_packed_subfield();
254 let a_0_sq_a_1_sq = PackedField::square(a_0_a_1);
255 let a_1_sq_a_0_sq = flip_even_odd(a_0_sq_a_1_sq);
256 let a_0_sq_plus_a_1_sq = a_0_sq_a_1_sq + a_1_sq_a_0_sq;
257 let a_1_mul_alpha = a_0_sq_a_1_sq.mul_alpha();
258
259 Self::from_packed_subfield(blend_odd_even(a_1_mul_alpha, a_0_sq_plus_a_1_sq))
260 }
261}
262
263impl<U, Scalar: TowerField> TaggedInvertOrZero<SimdStrategy> for PackedPrimitiveType<U, Scalar>
264where
265 Self: PackedTowerField<Underlier = U>,
266 <Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
267 U: TowerSimdType + UnderlierType,
268{
269 fn invert_or_zero(self) -> Self {
270 if Scalar::TOWER_LEVEL <= 3 {
272 return PackedField::invert_or_zero(self);
273 }
274
275 let a_0_a_1 = self.as_packed_subfield();
276 let a_1_a_0 = flip_even_odd(a_0_a_1);
277 let a_1_mul_alpha = a_1_a_0.mul_alpha();
278 let a_0_plus_a1_mul_alpha = xor(a_0_a_1, a_1_mul_alpha);
279 let a_1_sq_a_0_sq = PackedField::square(a_1_a_0);
280 let delta = xor(a_1_sq_a_0_sq, a_0_plus_a1_mul_alpha * a_0_a_1);
281 let delta_inv = PackedField::invert_or_zero(delta);
282 let delta_inv_delta_inv = duplicate_odd(delta_inv);
283 let delta_multiplier = blend_odd_even(a_0_a_1, a_0_plus_a1_mul_alpha);
284
285 Self::from_packed_subfield(delta_inv_delta_inv * delta_multiplier)
286 }
287}
288
/// Packed evaluator for a `FieldLinearTransformation` producing values of packed type `OP`.
pub struct SimdTransformation<OP> {
	/// Each basis element of the transformation broadcast into an `OP` value
	/// (built by `SimdTransformation::new`).
	bases: Vec<OP>,
	/// `OP::one()` shifted left by `OP::Scalar::N_BITS - 1`; `transform` ANDs the input
	/// with it to isolate one bit per element.
	ones: OP,
}
296
297#[allow(private_bounds)]
298impl<OP> SimdTransformation<OP>
299where
300 OP: PackedBinaryField + WithUnderlier<Underlier: TowerSimdType + UnderlierWithBitOps>,
301{
302 pub fn new<Data: AsRef<[OP::Scalar]> + Sync>(
303 transformation: FieldLinearTransformation<OP::Scalar, Data>,
304 ) -> Self {
305 Self {
306 bases: transformation
307 .bases()
308 .iter()
309 .map(|base| OP::broadcast(*base))
310 .collect(),
311 ones: OP::one().mutate_underlier(|underlier| underlier << (OP::Scalar::N_BITS - 1)),
314 }
315 }
316}
317
318impl<U, IP, OP, IF, OF> Transformation<IP, OP> for SimdTransformation<OP>
319where
320 IP: PackedField<Scalar = IF> + WithUnderlier<Underlier = U>,
321 OP: PackedField<Scalar = OF> + WithUnderlier<Underlier = U>,
322 IF: BinaryField,
323 OF: BinaryField,
324 U: UnderlierWithBitOps + TowerSimdType,
325{
326 fn transform(&self, input: &IP) -> OP {
327 let mut result = OP::zero();
328 let ones = self.ones.to_underlier();
329 let mut input = input.to_underlier();
330
331 for base in self.bases.iter().rev() {
334 let bases_mask = input & ones;
335 let component = U::apply_mask::<OP::Scalar>(bases_mask, base.to_underlier());
336 result += OP::from_underlier(component);
337 input = input << 1;
338 }
339
340 result
341 }
342}
343
344impl<IP, OP> TaggedPackedTransformationFactory<SimdStrategy, OP> for IP
345where
346 IP: PackedBinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
347 OP: PackedBinaryField + WithUnderlier<Underlier = IP::Underlier>,
348 IP::Underlier: TowerSimdType,
349{
350 type PackedTransformation<Data: AsRef<[<OP>::Scalar]> + Sync> = SimdTransformation<OP>;
351
352 fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
353 transformation: FieldLinearTransformation<OP::Scalar, Data>,
354 ) -> Self::PackedTransformation<Data> {
355 SimdTransformation::new(transformation)
356 }
357}
358
#[cfg(test)]
mod tests {
	use super::*;
	use crate::test_utils::{
		define_invert_tests, define_multiply_tests, define_square_tests,
		define_transformation_tests,
	};

	// Run the shared multiplication / squaring / inversion test suites from
	// `crate::test_utils` against the `SimdStrategy` implementations in this module.
	define_multiply_tests!(TaggedMul<SimdStrategy>::mul, TaggedMul<SimdStrategy>);

	define_square_tests!(TaggedSquare<SimdStrategy>::square, TaggedSquare<SimdStrategy>);

	define_invert_tests!(
		TaggedInvertOrZero<SimdStrategy>::invert_or_zero,
		TaggedInvertOrZero<SimdStrategy>
	);

	/// Single-bound alias for "can build a SIMD transformation into itself", so it can be
	/// passed to `define_transformation_tests!`.
	// NOTE(review): `unused` is presumably allowed because the trait is only referenced
	// from the macro expansion — confirm.
	#[allow(unused)]
	trait SelfPackedTransformationFactory:
		TaggedPackedTransformationFactory<SimdStrategy, Self>
	{
	}

	// Blanket impl: every type that can transform into itself satisfies the alias.
	impl<T: TaggedPackedTransformationFactory<SimdStrategy, Self>> SelfPackedTransformationFactory
		for T
	{
	}

	define_transformation_tests!(SelfPackedTransformationFactory);
}