1use std::{
4 fmt::{Debug, Display, Formatter},
5 iter::{Product, Sum},
6 marker::PhantomData,
7 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
8};
9
10use binius_utils::{
11 DeserializeBytes, SerializationError, SerializeBytes,
12 bytes::{Buf, BufMut},
13};
14use bytemuck::{Pod, Zeroable};
15
16use super::{
17 Error, PackedExtension, PackedSubfield,
18 arithmetic_traits::InvertOrZero,
19 binary_field::{BinaryField, BinaryField1b, binary_field, impl_field_extension},
20 binary_field_arithmetic::TowerFieldArithmetic,
21 mul_by_binary_field_1b,
22};
23use crate::{
24 ExtensionField, Field, TowerField, binary_field_arithmetic::impl_arithmetic_using_packed,
25 linear_transformation::Transformation, underlier::U1,
26};
27
// Declares the public field type `AESTowerField8b`, a tuple wrapper around a `u8`
// (elements are built as `AESTowerField8b(0x..)` and the byte is read as `self.0` below).
// NOTE(review): `0xD0` is presumably the multiplicative generator exercised by
// `test_multiplicative_generators` — confirm against the `binary_field!` definition.
binary_field!(pub AESTowerField8b(u8), 0xD0);

// SAFETY: assumes `binary_field!` emits a `#[repr(transparent)]` (or otherwise padding-free)
// newtype over `u8` that is `Copy + 'static` and implements `Zeroable`; since every bit
// pattern of a `u8` is valid, the wrapper then meets `Pod`'s requirements. Confirm against
// the macro definition before changing the type's layout.
unsafe impl Pod for AESTowerField8b {}

// Declares `AESTowerField8b` as an extension field of `BinaryField1b`.
// NOTE(review): `@3` presumably encodes log2 of the extension degree (2^3 = 8) — confirm.
impl_field_extension!(BinaryField1b(U1) < @3 => AESTowerField8b(u8));

// Presumably generates multiplication of `AESTowerField8b` by `BinaryField1b` scalars
// (inferred from the macro name — verify).
mul_by_binary_field_1b!(AESTowerField8b);

// Presumably derives the scalar arithmetic (including the `multiply_alpha` used by
// `mul_primitive`) from the packed-field implementation — verify against the macro.
impl_arithmetic_using_packed!(AESTowerField8b);
44
45impl TowerField for AESTowerField8b {
46 fn min_tower_level(self) -> usize {
47 match self {
48 Self::ZERO | Self::ONE => 0,
49 _ => 3,
50 }
51 }
52
53 fn mul_primitive(self, iota: usize) -> Result<Self, Error> {
54 match iota {
55 0..=1 => Ok(self * ISOMORPHIC_ALPHAS[iota]),
56 2 => Ok(self.multiply_alpha()),
57 _ => Err(Error::ExtensionDegreeMismatch),
58 }
59 }
60}
61
/// Lifts a transformation between packed subfield values to one between packed
/// extension-field values.
///
/// `IF` and `OF` are the input and output subfields; `T` is the wrapped transformation that
/// operates on the subfield views of the packed values.
pub struct SubfieldTransformer<IF, OF, T> {
    /// Transformation applied to the subfield view of each packed input.
    inner_transform: T,
    _ip_pd: PhantomData<IF>,
    _op_pd: PhantomData<OF>,
}

impl<IF, OF, T> SubfieldTransformer<IF, OF, T> {
    /// Wraps `inner_transform` so it can be applied to packed extension-field values.
    ///
    /// The struct's fields are private, so this constructor is what makes the type (and its
    /// `Transformation` impl) usable from outside this module.
    pub const fn new(inner_transform: T) -> Self {
        Self {
            inner_transform,
            _ip_pd: PhantomData,
            _op_pd: PhantomData,
        }
    }
}
71
72impl<IF, OF, IEP, OEP, T> Transformation<IEP, OEP> for SubfieldTransformer<IF, OF, T>
73where
74 IF: Field,
75 OF: Field,
76 IEP: PackedExtension<IF>,
77 OEP: PackedExtension<OF>,
78 T: Transformation<PackedSubfield<IEP, IF>, PackedSubfield<OEP, OF>>,
79{
80 fn transform(&self, input: &IEP) -> OEP {
81 OEP::cast_ext(self.inner_transform.transform(IEP::cast_base_ref(input)))
82 }
83}
84
/// Multiplicative constants used for the tower-level primitives of `AESTowerField8b`.
///
/// `TowerField::mul_primitive` multiplies by entries 0 and 1 for `iota` 0 and 1. Entry 2 is
/// not read by `mul_primitive`, which calls `multiply_alpha` for `iota == 2`. The
/// `test_mul_primitive` check uses all three entries as the expected factors for `iota` 0..=2.
// NOTE(review): the name suggests these are the images of the canonical binary-tower alphas
// under a field isomorphism into the AES representation — confirm the derivation before
// editing any value.
const ISOMORPHIC_ALPHAS: [AESTowerField8b; 3] = [
    AESTowerField8b(0xBC),
    AESTowerField8b(0xB0),
    AESTowerField8b(0xD3),
];
91
92impl SerializeBytes for AESTowerField8b {
93 fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
94 self.0.serialize(write_buf)
95 }
96}
97
98impl DeserializeBytes for AESTowerField8b {
99 fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
100 where
101 Self: Sized,
102 {
103 Ok(Self(DeserializeBytes::deserialize(read_buf)?))
104 }
105}
106
#[cfg(test)]
mod tests {
    use binius_utils::{SerializeBytes, bytes::BytesMut};
    use proptest::{arbitrary::any, proptest};
    use rand::prelude::*;

    use super::*;
    use crate::{
        Random, binary_field::tests::is_binary_field_valid_generator, underlier::WithUnderlier,
    };

    /// Asserts that `square` agrees with multiplying the element by itself.
    fn check_square(f: impl Field) {
        assert_eq!(f.square(), f * f);
    }

    proptest! {
        #[test]
        fn test_square_8(a in any::<u8>()) {
            check_square(AESTowerField8b::from(a))
        }
    }

    /// Asserts that `invert` yields `None` for zero and a genuine inverse otherwise.
    fn check_invert(f: impl Field) {
        let inverse = f.invert();
        if f.is_zero() {
            assert!(inverse.is_none());
        } else {
            assert_eq!(inverse.unwrap() * f, Field::ONE);
        }
    }

    proptest! {
        #[test]
        fn test_invert_8(a in any::<u8>()) {
            check_invert(AESTowerField8b::from(a))
        }
    }

    /// Asserts that `ONE` is a two-sided multiplicative identity for `f`.
    fn check_mul_by_one<F: Field>(f: F) {
        assert_eq!(F::ONE * f, f);
        assert_eq!(f * F::ONE, f);
    }

    /// Asserts that multiplying `f_1` and `f_2` commutes.
    fn check_commutative<F: Field>(f_1: F, f_2: F) {
        assert_eq!(f_1 * f_2, f_2 * f_1);
    }

    /// Asserts associativity of multiplication and left distributivity over addition.
    fn check_associativity_and_distributivity<F: Field>(f_1: F, f_2: F, f_3: F) {
        assert_eq!(f_1 * (f_2 * f_3), (f_1 * f_2) * f_3);
        assert_eq!(f_1 * (f_2 + f_3), f_1 * f_2 + f_1 * f_3);
    }

    /// Runs the multiplicative axiom checks on the operands in every ordering.
    fn check_mul<F: Field>(f_1: F, f_2: F, f_3: F) {
        for f in [f_1, f_2, f_3] {
            check_mul_by_one(f);
        }

        check_commutative(f_1, f_2);
        check_commutative(f_1, f_3);
        check_commutative(f_2, f_3);

        // All six permutations, so both associativity and distributivity are checked with
        // every operand in every position.
        for (a, b, c) in [
            (f_1, f_2, f_3),
            (f_1, f_3, f_2),
            (f_2, f_1, f_3),
            (f_2, f_3, f_1),
            (f_3, f_1, f_2),
            (f_3, f_2, f_1),
        ] {
            check_associativity_and_distributivity(a, b, c);
        }
    }

    proptest! {
        #[test]
        fn test_mul_8(a in any::<u8>(), b in any::<u8>(), c in any::<u8>()) {
            check_mul(AESTowerField8b::from(a), AESTowerField8b::from(b), AESTowerField8b::from(c))
        }
    }

    #[test]
    fn test_multiplicative_generators() {
        assert!(is_binary_field_valid_generator::<AESTowerField8b>());
    }

    /// Checks `val.mul_primitive(iota)` against an independently computed expectation.
    ///
    /// For `iota` in `0..=2` the expectation is multiplication by `ISOMORPHIC_ALPHAS[iota]`.
    /// Otherwise it is multiplication by the `1 << iota`-th basis element over
    /// `BinaryField1b`, which fails when that index is out of range. Both sides must agree
    /// on success or failure. On failure, `mul_primitive` must report
    /// `Error::ExtensionDegreeMismatch`.
    fn test_mul_primitive<F: TowerField + WithUnderlier<Underlier: From<u8>>>(val: F, iota: usize) {
        let result = val.mul_primitive(iota);
        let expected = match iota {
            0..=2 => {
                Ok(val
                    * F::from_underlier(F::Underlier::from(ISOMORPHIC_ALPHAS[iota].to_underlier())))
            }
            _ => <F as ExtensionField<BinaryField1b>>::basis_checked(1 << iota).map(|b| val * b),
        };
        match (result, expected) {
            (Ok(actual), Ok(expected)) => assert_eq!(actual, expected),
            (Err(err), Err(_)) => assert!(matches!(err, Error::ExtensionDegreeMismatch)),
            (actual, expected) => panic!(
                "mul_primitive({iota}) returned {actual:?}, but the expectation was {expected:?}"
            ),
        }
    }

    proptest! {
        #[test]
        fn test_mul_primitive_8b(val in 0u8.., iota in 3usize..8) {
            test_mul_primitive::<AESTowerField8b>(val.into(), iota)
        }
    }

    #[test]
    fn test_serialization() {
        let mut buffer = BytesMut::new();
        let mut rng = StdRng::seed_from_u64(0);
        let aes8 = AESTowerField8b::random(&mut rng);

        SerializeBytes::serialize(&aes8, &mut buffer).unwrap();

        let mut read_buffer = buffer.freeze();

        assert_eq!(AESTowerField8b::deserialize(&mut read_buffer).unwrap(), aes8);
    }
}