1use std::{
4 any::TypeId,
5 fmt::{Debug, Display, Formatter},
6 iter::{Product, Sum},
7 marker::PhantomData,
8 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
9};
10
11use binius_utils::{
12 DeserializeBytes, SerializationError, SerializeBytes,
13 bytes::{Buf, BufMut},
14};
15use bytemuck::{Pod, Zeroable};
16
17use super::{
18 Error, PackedExtension, PackedSubfield,
19 arithmetic_traits::InvertOrZero,
20 binary_field::{BinaryField, BinaryField1b, binary_field, impl_field_extension},
21 binary_field_arithmetic::TowerFieldArithmetic,
22 mul_by_binary_field_1b,
23};
24use crate::{
25 ExtensionField, Field, TowerField, binary_field_arithmetic::impl_arithmetic_using_packed,
26 linear_transformation::Transformation, underlier::U1,
27};
28
29binary_field!(pub AESTowerField8b(u8), 0xD0);
37
38unsafe impl Pod for AESTowerField8b {}
39
40impl_field_extension!(BinaryField1b(U1) < @3 => AESTowerField8b(u8));
41
42mul_by_binary_field_1b!(AESTowerField8b);
43
44impl_arithmetic_using_packed!(AESTowerField8b);
45
46impl TowerField for AESTowerField8b {
47 fn min_tower_level(self) -> usize {
48 match self {
49 Self::ZERO | Self::ONE => 0,
50 _ => 3,
51 }
52 }
53
54 fn mul_primitive(self, iota: usize) -> Result<Self, Error> {
55 match iota {
56 0..=1 => Ok(self * ISOMORPHIC_ALPHAS[iota]),
57 2 => Ok(self.multiply_alpha()),
58 _ => Err(Error::ExtensionDegreeMismatch),
59 }
60 }
61}
62
63#[inline(always)]
65pub fn is_aes_tower<F: TowerField>() -> bool {
66 TypeId::of::<F>() == TypeId::of::<AESTowerField8b>()
67}
68
/// Lifts a transformation between packed subfields into one between packed extensions.
///
/// The wrapped transformation `T` maps packed `IF` elements to packed `OF` elements. The
/// [`Transformation`] impl for this type does three things:
///
/// 1. It views an `IEP` input as packed `IF`.
/// 2. It applies `T`.
/// 3. It reinterprets the packed `OF` result as `OEP`.
pub struct SubfieldTransformer<IF, OF, T> {
    /// Transformation applied to the subfield view of the input.
    inner_transform: T,
    _ip_pd: PhantomData<IF>,
    _op_pd: PhantomData<OF>,
}

impl<IF, OF, T> SubfieldTransformer<IF, OF, T> {
    /// Wraps `inner_transform`, a transformation from packed `IF` to packed `OF` elements.
    ///
    /// The struct's fields are private, so this is the only way to construct a
    /// `SubfieldTransformer`.
    pub const fn new(inner_transform: T) -> Self {
        Self {
            inner_transform,
            _ip_pd: PhantomData,
            _op_pd: PhantomData,
        }
    }
}
78
79impl<IF, OF, IEP, OEP, T> Transformation<IEP, OEP> for SubfieldTransformer<IF, OF, T>
80where
81 IF: Field,
82 OF: Field,
83 IEP: PackedExtension<IF>,
84 OEP: PackedExtension<OF>,
85 T: Transformation<PackedSubfield<IEP, IF>, PackedSubfield<OEP, OF>>,
86{
87 fn transform(&self, input: &IEP) -> OEP {
88 OEP::cast_ext(self.inner_transform.transform(IEP::cast_base_ref(input)))
89 }
90}
91
92const ISOMORPHIC_ALPHAS: [AESTowerField8b; 3] = [
94 AESTowerField8b(0xBC),
95 AESTowerField8b(0xB0),
96 AESTowerField8b(0xD3),
97];
98
99macro_rules! serialize_deserialize_non_canonical {
100 ($field:ident, canonical=$canonical:ident) => {
101 impl SerializeBytes for $field {
102 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
103 self.0.serialize(write_buf)
104 }
105 }
106
107 impl DeserializeBytes for $field {
108 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
109 where
110 Self: Sized,
111 {
112 Ok(Self(DeserializeBytes::deserialize(read_buf)?))
113 }
114 }
115 };
116}
117
118serialize_deserialize_non_canonical!(AESTowerField8b, canonical = BinaryField8b);
119
#[cfg(test)]
mod tests {
    use binius_utils::{SerializeBytes, bytes::BytesMut};
    use proptest::{arbitrary::any, proptest};
    use rand::prelude::*;

    use super::*;
    use crate::{
        Random, binary_field::tests::is_binary_field_valid_generator, underlier::WithUnderlier,
    };

    /// Asserts that `square` agrees with multiplying the element by itself.
    fn check_square(f: impl Field) {
        assert_eq!(f.square(), f * f);
    }

    proptest! {
        #[test]
        fn test_square_8(a in any::<u8>()) {
            check_square(AESTowerField8b::from(a))
        }
    }

    /// Asserts that zero has no inverse and every non-zero element's inverse is a true
    /// multiplicative inverse.
    fn check_invert(f: impl Field) {
        let inverse = f.invert();
        if f.is_zero() {
            assert!(inverse.is_none());
        } else {
            let inverse = inverse.expect("non-zero field element must be invertible");
            assert_eq!(inverse * f, Field::ONE);
        }
    }

    proptest! {
        #[test]
        fn test_invert_8(a in any::<u8>()) {
            check_invert(AESTowerField8b::from(a))
        }
    }

    /// Asserts that `ONE` is a two-sided multiplicative identity for `f`.
    fn check_mul_by_one<F: Field>(f: F) {
        assert_eq!(F::ONE * f, f);
        assert_eq!(f * F::ONE, f);
    }

    /// Asserts that multiplying `f_1` and `f_2` commutes.
    fn check_commutative<F: Field>(f_1: F, f_2: F) {
        assert_eq!(f_1 * f_2, f_2 * f_1);
    }

    /// Asserts multiplicative associativity and left distributivity over addition.
    fn check_associativity_and_linearity<F: Field>(f_1: F, f_2: F, f_3: F) {
        assert_eq!(f_1 * (f_2 * f_3), (f_1 * f_2) * f_3);
        assert_eq!(f_1 * (f_2 + f_3), f_1 * f_2 + f_1 * f_3);
    }

    /// Runs the identity and commutativity checks on each operand pair, and the
    /// associativity/linearity check on every ordering of the three operands.
    fn check_mul<F: Field>(f_1: F, f_2: F, f_3: F) {
        check_mul_by_one(f_1);
        check_mul_by_one(f_2);
        check_mul_by_one(f_3);

        check_commutative(f_1, f_2);
        check_commutative(f_1, f_3);
        check_commutative(f_2, f_3);

        check_associativity_and_linearity(f_1, f_2, f_3);
        check_associativity_and_linearity(f_1, f_3, f_2);
        check_associativity_and_linearity(f_2, f_1, f_3);
        check_associativity_and_linearity(f_2, f_3, f_1);
        check_associativity_and_linearity(f_3, f_1, f_2);
        check_associativity_and_linearity(f_3, f_2, f_1);
    }

    proptest! {
        #[test]
        fn test_mul_8(a in any::<u8>(), b in any::<u8>(), c in any::<u8>()) {
            check_mul(AESTowerField8b::from(a), AESTowerField8b::from(b), AESTowerField8b::from(c))
        }
    }

    #[test]
    fn test_multiplicative_generators() {
        assert!(is_binary_field_valid_generator::<AESTowerField8b>());
    }

    /// Checks `val.mul_primitive(iota)` against an independently computed expectation.
    ///
    /// For `iota <= 2` the expectation is multiplication by `ISOMORPHIC_ALPHAS[iota]`. Otherwise
    /// it is multiplication by `basis_checked(1 << iota)`, whose failure must be reported as
    /// `Error::ExtensionDegreeMismatch`.
    fn test_mul_primitive<F: TowerField + WithUnderlier<Underlier: From<u8>>>(val: F, iota: usize) {
        let result = val.mul_primitive(iota);
        let expected = match iota {
            0..=2 => {
                Ok(val
                    * F::from_underlier(F::Underlier::from(ISOMORPHIC_ALPHAS[iota].to_underlier())))
            }
            _ => <F as ExtensionField<BinaryField1b>>::basis_checked(1 << iota).map(|b| val * b),
        };
        match (result, expected) {
            (Ok(actual), Ok(expected)) => assert_eq!(actual, expected),
            (Err(err), Err(_)) => assert!(matches!(err, Error::ExtensionDegreeMismatch)),
            (actual, expected) => {
                panic!("mul_primitive({iota}) returned {actual:?}, expected {expected:?}")
            }
        }
    }

    proptest! {
        #[test]
        fn test_mul_primitive_8b(val in 0u8.., iota in 3usize..8) {
            test_mul_primitive::<AESTowerField8b>(val.into(), iota)
        }
    }

    /// Exhaustively covers the low tower levels, which go through `ISOMORPHIC_ALPHAS` and
    /// `multiply_alpha` rather than the basis lookup. The proptest above only samples
    /// `iota >= 3`.
    #[test]
    fn test_mul_primitive_8b_low_levels() {
        for val in 0..=u8::MAX {
            for iota in 0..=2 {
                test_mul_primitive::<AESTowerField8b>(val.into(), iota);
            }
        }
    }

    /// The alpha table must satisfy the tower relations `a_0^2 = a_0 + 1` and
    /// `a_i^2 = a_i * a_{i-1} + 1`, otherwise it does not describe a consistent tower basis.
    #[test]
    fn test_isomorphic_alphas_satisfy_tower_relations() {
        let [alpha_0, alpha_1, alpha_2] = ISOMORPHIC_ALPHAS;
        let one = AESTowerField8b::ONE;
        assert_eq!(alpha_0.square(), alpha_0 + one);
        assert_eq!(alpha_1.square(), alpha_1 * alpha_0 + one);
        assert_eq!(alpha_2.square(), alpha_2 * alpha_1 + one);
    }

    #[test]
    fn test_serialization() {
        let mut buffer = BytesMut::new();
        let mut rng = StdRng::seed_from_u64(0);
        let aes8 = AESTowerField8b::random(&mut rng);

        SerializeBytes::serialize(&aes8, &mut buffer).unwrap();

        let mut read_buffer = buffer.freeze();

        assert_eq!(AESTowerField8b::deserialize(&mut read_buffer).unwrap(), aes8);
    }
}