1use std::{any::TypeId, arch::x86_64::*};
4
5use crate::{
6 aes_field::AESTowerField8b,
7 arch::{
8 portable::{
9 packed::PackedPrimitiveType, packed_arithmetic::PackedTowerField,
10 reuse_multiply_arithmetic::Alpha,
11 },
12 SimdStrategy,
13 },
14 arithmetic_traits::{
15 MulAlpha, TaggedInvertOrZero, TaggedMul, TaggedMulAlpha, TaggedPackedTransformationFactory,
16 TaggedSquare,
17 },
18 linear_transformation::{FieldLinearTransformation, Transformation},
19 packed::PackedBinaryField,
20 underlier::{UnderlierType, UnderlierWithBitOps, WithUnderlier},
21 BinaryField, BinaryField8b, PackedField, TowerField,
22};
23
24pub trait TowerSimdType: Sized + Copy + UnderlierWithBitOps {
25 fn blend_odd_even<Scalar: BinaryField>(a: Self, b: Self) -> Self;
27 fn set_alpha_even<Scalar: BinaryField>(self) -> Self;
29 fn apply_mask<Scalar: BinaryField>(mask: Self, a: Self) -> Self;
31
32 fn xor(a: Self, b: Self) -> Self;
34
35 fn shuffle_epi8(a: Self, b: Self) -> Self;
37
38 fn bslli_epi128<const IMM8: i32>(self) -> Self;
40 fn bsrli_epi128<const IMM8: i32>(self) -> Self;
41
42 fn set1_epi128(val: __m128i) -> Self;
44 fn set_epi_64(val: i64) -> Self;
45
46 #[inline(always)]
47 fn dup_shuffle<Scalar: BinaryField>() -> Self {
48 let shuffle_mask_128 = unsafe {
49 match Scalar::N_BITS.ilog2() {
50 3 => _mm_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0),
51 4 => _mm_set_epi8(13, 12, 13, 12, 9, 8, 9, 8, 5, 4, 5, 4, 1, 0, 1, 0),
52 5 => _mm_set_epi8(11, 10, 9, 8, 11, 10, 9, 8, 3, 2, 1, 0, 3, 2, 1, 0),
53 6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0),
54 _ => panic!("unsupported bit count"),
55 }
56 };
57
58 Self::set1_epi128(shuffle_mask_128)
59 }
60
61 #[inline(always)]
62 fn flip_shuffle<Scalar: BinaryField>() -> Self {
63 let flip_mask_128 = unsafe {
64 match Scalar::N_BITS.ilog2() {
65 3 => _mm_set_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1),
66 4 => _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2),
67 5 => _mm_set_epi8(11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4),
68 6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8),
69 _ => panic!("unsupported bit count"),
70 }
71 };
72
73 Self::set1_epi128(flip_mask_128)
74 }
75
76 #[inline(always)]
78 fn make_epi8_mask_shuffle<Scalar: BinaryField>() -> Self {
79 let epi8_mask_128 = unsafe {
80 match Scalar::N_BITS.ilog2() {
81 4 => _mm_set_epi8(15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1),
82 5 => _mm_set_epi8(15, 15, 15, 15, 11, 11, 11, 11, 7, 7, 7, 7, 3, 3, 3, 3),
83 6 => _mm_set_epi8(15, 15, 15, 15, 15, 15, 15, 15, 7, 7, 7, 7, 7, 7, 7, 7),
84 7 => _mm_set1_epi8(15),
85 _ => panic!("unsupported bit count"),
86 }
87 };
88
89 Self::set1_epi128(epi8_mask_128)
90 }
91
92 #[inline(always)]
93 fn alpha<Scalar: BinaryField>() -> Self {
94 let alpha_128 = unsafe {
95 match Scalar::N_BITS.ilog2() {
96 3 => {
97 let type_id = TypeId::of::<Scalar>();
99 let value = if type_id == TypeId::of::<BinaryField8b>() {
100 0x10
101 } else if type_id == TypeId::of::<AESTowerField8b>() {
102 0xd3u8 as i8
103 } else {
104 panic!("tower field not supported")
105 };
106 _mm_set1_epi8(value)
107 }
108 4 => _mm_set1_epi16(0x0100),
109 5 => _mm_set1_epi32(0x00010000),
110 6 => _mm_set1_epi64x(0x0000000100000000),
111 _ => panic!("unsupported bit count"),
112 }
113 };
114
115 Self::set1_epi128(alpha_128)
116 }
117
118 #[inline(always)]
119 fn even_mask<Scalar: BinaryField>() -> Self {
120 let mask_128 = unsafe {
121 match Scalar::N_BITS.ilog2() {
122 3 => _mm_set1_epi16(0x00FF),
123 4 => _mm_set1_epi32(0x0000FFFF),
124 5 => _mm_set1_epi64x(0x00000000FFFFFFFF),
125 6 => _mm_set_epi64x(0, -1),
126 _ => panic!("unsupported bit count"),
127 }
128 };
129
130 Self::set1_epi128(mask_128)
131 }
132}
133
134impl<U: UnderlierType + TowerSimdType, Scalar: TowerField> Alpha
135 for PackedPrimitiveType<U, Scalar>
136{
137 #[inline(always)]
138 fn alpha() -> Self {
139 U::alpha::<Scalar>().into()
140 }
141}
142
143#[inline(always)]
144fn blend_odd_even<U, PT>(a: PT, b: PT) -> PT
145where
146 U: TowerSimdType,
147 PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
148{
149 PT::from_underlier(U::blend_odd_even::<PT::Scalar>(a.to_underlier(), b.to_underlier()))
150}
151
152#[inline(always)]
153fn xor<U, PT>(a: PT, b: PT) -> PT
154where
155 U: TowerSimdType,
156 PT: WithUnderlier<Underlier = U>,
157{
158 PT::from_underlier(U::xor(a.to_underlier(), b.to_underlier()))
159}
160
161#[inline(always)]
162fn duplicate_odd<U, PT>(val: PT) -> PT
163where
164 U: TowerSimdType,
165 PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
166{
167 PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::dup_shuffle::<PT::Scalar>()))
168}
169
170#[inline(always)]
171fn flip_even_odd<U, PT>(val: PT) -> PT
172where
173 U: TowerSimdType,
174 PT: PackedField<Scalar: TowerField> + WithUnderlier<Underlier = U>,
175{
176 PT::from_underlier(U::shuffle_epi8(val.to_underlier(), U::flip_shuffle::<PT::Scalar>()))
177}
178
179impl<U, Scalar: TowerField> TaggedMul<SimdStrategy> for PackedPrimitiveType<U, Scalar>
180where
181 Self: PackedTowerField<Underlier = U>,
182 U: TowerSimdType + UnderlierType,
183{
184 fn mul(self, rhs: Self) -> Self {
185 if Scalar::TOWER_LEVEL <= 3 {
187 return self * rhs;
188 }
189
190 let a = self.as_packed_subfield();
191 let b = rhs.as_packed_subfield();
192
193 let z0_even_z2_odd = a * b;
195
196 let (lo, hi) = a.interleave(b, 0);
199 let lo_plus_hi_a_even_b_odd = lo + hi;
201
202 let alpha_even_z2_odd = <Self as PackedTowerField>::PackedDirectSubfield::from_underlier(
203 z0_even_z2_odd
204 .to_underlier()
205 .set_alpha_even::<<Self as PackedTowerField>::DirectSubfield>(),
206 );
207 let (lhs, rhs) = lo_plus_hi_a_even_b_odd.interleave(alpha_even_z2_odd, 0);
208 let z1_xor_z0z2_even_z2a_odd = lhs * rhs;
209
210 let z1_xor_z0z2 = duplicate_odd(z1_xor_z0z2_even_z2a_odd);
211 let zero_even_z1_xor_z2a_xor_z0z2_odd = xor(z1_xor_z0z2_even_z2a_odd, z1_xor_z0z2);
212
213 let z2_even_z0_odd = flip_even_odd(z0_even_z2_odd);
214 let z0z2 = xor(z0_even_z2_odd, z2_even_z0_odd);
215
216 Self::from_packed_subfield(xor(zero_even_z1_xor_z2a_xor_z0z2_odd, z0z2))
217 }
218}
219
220impl<U, Scalar: TowerField> TaggedMulAlpha<SimdStrategy> for PackedPrimitiveType<U, Scalar>
221where
222 Self: PackedTowerField<Underlier = U> + MulAlpha,
223 <Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
224 U: TowerSimdType + UnderlierType,
225{
226 #[inline]
227 fn mul_alpha(self) -> Self {
228 if Scalar::TOWER_LEVEL <= 3 {
230 return MulAlpha::mul_alpha(self);
231 }
232
233 let a_0_a_1 = self.as_packed_subfield();
234 let a_0_mul_alpha_a_1_mul_alpha = a_0_a_1.mul_alpha();
235
236 let a_1_a_0 = flip_even_odd(self.as_packed_subfield());
237 let a0_plus_a1_alpha = xor(a_0_mul_alpha_a_1_mul_alpha, a_1_a_0);
238
239 Self::from_packed_subfield(blend_odd_even(a0_plus_a1_alpha, a_1_a_0))
240 }
241}
242
243impl<U, Scalar: TowerField> TaggedSquare<SimdStrategy> for PackedPrimitiveType<U, Scalar>
244where
245 Self: PackedTowerField<Underlier = U>,
246 <Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
247 U: TowerSimdType + UnderlierType,
248{
249 fn square(self) -> Self {
250 if Scalar::TOWER_LEVEL <= 3 {
252 return PackedField::square(self);
253 }
254
255 let a_0_a_1 = self.as_packed_subfield();
256 let a_0_sq_a_1_sq = PackedField::square(a_0_a_1);
257 let a_1_sq_a_0_sq = flip_even_odd(a_0_sq_a_1_sq);
258 let a_0_sq_plus_a_1_sq = a_0_sq_a_1_sq + a_1_sq_a_0_sq;
259 let a_1_mul_alpha = a_0_sq_a_1_sq.mul_alpha();
260
261 Self::from_packed_subfield(blend_odd_even(a_1_mul_alpha, a_0_sq_plus_a_1_sq))
262 }
263}
264
265impl<U, Scalar: TowerField> TaggedInvertOrZero<SimdStrategy> for PackedPrimitiveType<U, Scalar>
266where
267 Self: PackedTowerField<Underlier = U>,
268 <Self as PackedTowerField>::PackedDirectSubfield: MulAlpha,
269 U: TowerSimdType + UnderlierType,
270{
271 fn invert_or_zero(self) -> Self {
272 if Scalar::TOWER_LEVEL <= 3 {
274 return PackedField::invert_or_zero(self);
275 }
276
277 let a_0_a_1 = self.as_packed_subfield();
278 let a_1_a_0 = flip_even_odd(a_0_a_1);
279 let a_1_mul_alpha = a_1_a_0.mul_alpha();
280 let a_0_plus_a1_mul_alpha = xor(a_0_a_1, a_1_mul_alpha);
281 let a_1_sq_a_0_sq = PackedField::square(a_1_a_0);
282 let delta = xor(a_1_sq_a_0_sq, a_0_plus_a1_mul_alpha * a_0_a_1);
283 let delta_inv = PackedField::invert_or_zero(delta);
284 let delta_inv_delta_inv = duplicate_odd(delta_inv);
285 let delta_multiplier = blend_odd_even(a_0_a_1, a_0_plus_a1_mul_alpha);
286
287 Self::from_packed_subfield(delta_inv_delta_inv * delta_multiplier)
288 }
289}
290
291pub struct SimdTransformation<OP> {
295 bases: Vec<OP>,
296 ones: OP,
297}
298
299#[allow(private_bounds)]
300impl<OP> SimdTransformation<OP>
301where
302 OP: PackedBinaryField + WithUnderlier<Underlier: TowerSimdType + UnderlierWithBitOps>,
303{
304 pub fn new<Data: AsRef<[OP::Scalar]> + Sync>(
305 transformation: FieldLinearTransformation<OP::Scalar, Data>,
306 ) -> Self {
307 Self {
308 bases: transformation
309 .bases()
310 .iter()
311 .map(|base| OP::broadcast(*base))
312 .collect(),
313 ones: OP::one().mutate_underlier(|underlier| underlier << (OP::Scalar::N_BITS - 1)),
316 }
317 }
318}
319
320impl<U, IP, OP, IF, OF> Transformation<IP, OP> for SimdTransformation<OP>
321where
322 IP: PackedField<Scalar = IF> + WithUnderlier<Underlier = U>,
323 OP: PackedField<Scalar = OF> + WithUnderlier<Underlier = U>,
324 IF: BinaryField,
325 OF: BinaryField,
326 U: UnderlierWithBitOps + TowerSimdType,
327{
328 fn transform(&self, input: &IP) -> OP {
329 let mut result = OP::zero();
330 let ones = self.ones.to_underlier();
331 let mut input = input.to_underlier();
332
333 for base in self.bases.iter().rev() {
336 let bases_mask = input & ones;
337 let component = U::apply_mask::<OP::Scalar>(bases_mask, base.to_underlier());
338 result += OP::from_underlier(component);
339 input = input << 1;
340 }
341
342 result
343 }
344}
345
346impl<IP, OP> TaggedPackedTransformationFactory<SimdStrategy, OP> for IP
347where
348 IP: PackedBinaryField + WithUnderlier<Underlier: UnderlierWithBitOps>,
349 OP: PackedBinaryField + WithUnderlier<Underlier = IP::Underlier>,
350 IP::Underlier: TowerSimdType,
351{
352 type PackedTransformation<Data: AsRef<[<OP>::Scalar]> + Sync> = SimdTransformation<OP>;
353
354 fn make_packed_transformation<Data: AsRef<[OP::Scalar]> + Sync>(
355 transformation: FieldLinearTransformation<OP::Scalar, Data>,
356 ) -> Self::PackedTransformation<Data> {
357 SimdTransformation::new(transformation)
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::{
        define_invert_tests, define_mul_alpha_tests, define_multiply_tests, define_square_tests,
        define_transformation_tests,
    };

    // Each macro receives the operation under test and the trait bound it requires.
    define_multiply_tests!(TaggedMul<SimdStrategy>::mul, TaggedMul<SimdStrategy>);

    define_square_tests!(TaggedSquare<SimdStrategy>::square, TaggedSquare<SimdStrategy>);

    define_invert_tests!(
        TaggedInvertOrZero<SimdStrategy>::invert_or_zero,
        TaggedInvertOrZero<SimdStrategy>
    );

    define_mul_alpha_tests!(TaggedMulAlpha<SimdStrategy>::mul_alpha, TaggedMulAlpha<SimdStrategy>);

    /// Single-parameter alias for "can build a SIMD transformation onto itself", so it can be
    /// passed to `define_transformation_tests!` as a plain trait name.
    // The trait is only named inside the macro expansion, which the lint may not see as a use.
    #[allow(unused)]
    trait SelfPackedTransformationFactory:
        TaggedPackedTransformationFactory<SimdStrategy, Self>
    {
    }

    impl<T> SelfPackedTransformationFactory for T where
        T: TaggedPackedTransformationFactory<SimdStrategy, T>
    {
    }

    define_transformation_tests!(SelfPackedTransformationFactory);
}