1use std::{any::TypeId, arch::x86_64::*};
4
5use crate::{
6 BinaryField, TowerField,
7 aes_field::AESTowerField8b,
8 arch::portable::{packed::PackedPrimitiveType, reuse_multiply_arithmetic::Alpha},
9 underlier::{UnderlierType, UnderlierWithBitOps},
10};
11
12pub trait TowerSimdType: Sized + Copy + UnderlierWithBitOps {
13 fn blend_odd_even<Scalar: BinaryField>(a: Self, b: Self) -> Self;
15 fn set_alpha_even<Scalar: BinaryField>(self) -> Self;
17 fn apply_mask<Scalar: BinaryField>(mask: Self, a: Self) -> Self;
19
20 fn xor(a: Self, b: Self) -> Self;
22
23 fn shuffle_epi8(a: Self, b: Self) -> Self;
25
26 fn bslli_epi128<const IMM8: i32>(self) -> Self;
28 fn bsrli_epi128<const IMM8: i32>(self) -> Self;
29
30 fn set1_epi128(val: __m128i) -> Self;
32 fn set_epi_64(val: i64) -> Self;
33
34 #[inline(always)]
35 fn dup_shuffle<Scalar: BinaryField>() -> Self {
36 let shuffle_mask_128 = unsafe {
37 match Scalar::N_BITS.ilog2() {
38 3 => _mm_set_epi8(14, 14, 12, 12, 10, 10, 8, 8, 6, 6, 4, 4, 2, 2, 0, 0),
39 4 => _mm_set_epi8(13, 12, 13, 12, 9, 8, 9, 8, 5, 4, 5, 4, 1, 0, 1, 0),
40 5 => _mm_set_epi8(11, 10, 9, 8, 11, 10, 9, 8, 3, 2, 1, 0, 3, 2, 1, 0),
41 6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, 1, 0),
42 _ => panic!("unsupported bit count"),
43 }
44 };
45
46 Self::set1_epi128(shuffle_mask_128)
47 }
48
49 #[inline(always)]
50 fn flip_shuffle<Scalar: BinaryField>() -> Self {
51 let flip_mask_128 = unsafe {
52 match Scalar::N_BITS.ilog2() {
53 3 => _mm_set_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1),
54 4 => _mm_set_epi8(13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2),
55 5 => _mm_set_epi8(11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4),
56 6 => _mm_set_epi8(7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8),
57 _ => panic!("unsupported bit count"),
58 }
59 };
60
61 Self::set1_epi128(flip_mask_128)
62 }
63
64 #[inline(always)]
66 fn make_epi8_mask_shuffle<Scalar: BinaryField>() -> Self {
67 let epi8_mask_128 = unsafe {
68 match Scalar::N_BITS.ilog2() {
69 4 => _mm_set_epi8(15, 15, 13, 13, 11, 11, 9, 9, 7, 7, 5, 5, 3, 3, 1, 1),
70 5 => _mm_set_epi8(15, 15, 15, 15, 11, 11, 11, 11, 7, 7, 7, 7, 3, 3, 3, 3),
71 6 => _mm_set_epi8(15, 15, 15, 15, 15, 15, 15, 15, 7, 7, 7, 7, 7, 7, 7, 7),
72 7 => _mm_set1_epi8(15),
73 _ => panic!("unsupported bit count"),
74 }
75 };
76
77 Self::set1_epi128(epi8_mask_128)
78 }
79
80 #[inline(always)]
81 fn alpha<Scalar: BinaryField>() -> Self {
82 let alpha_128 = {
83 match Scalar::N_BITS.ilog2() {
84 3 => {
85 let type_id = TypeId::of::<Scalar>();
87 let value = if type_id == TypeId::of::<AESTowerField8b>() {
88 0xd3u8 as i8
89 } else {
90 panic!("tower field not supported")
91 };
92 unsafe { _mm_set1_epi8(value) }
93 }
94 4 => unsafe { _mm_set1_epi16(0x0100) },
95 5 => unsafe { _mm_set1_epi32(0x00010000) },
96 6 => unsafe { _mm_set1_epi64x(0x0000000100000000) },
97 _ => panic!("unsupported bit count"),
98 }
99 };
100
101 Self::set1_epi128(alpha_128)
102 }
103
104 #[inline(always)]
105 fn even_mask<Scalar: BinaryField>() -> Self {
106 let mask_128 = {
107 match Scalar::N_BITS.ilog2() {
108 3 => unsafe { _mm_set1_epi16(0x00FF) },
109 4 => unsafe { _mm_set1_epi32(0x0000FFFF) },
110 5 => unsafe { _mm_set1_epi64x(0x00000000FFFFFFFF) },
111 6 => unsafe { _mm_set_epi64x(0, -1) },
112 _ => panic!("unsupported bit count"),
113 }
114 };
115
116 Self::set1_epi128(mask_128)
117 }
118}
119
120impl<U: UnderlierType + TowerSimdType, Scalar: TowerField> Alpha
121 for PackedPrimitiveType<U, Scalar>
122{
123 #[inline(always)]
124 fn alpha() -> Self {
125 U::alpha::<Scalar>().into()
126 }
127}