binius_field/
aes_field.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{
4	any::TypeId,
5	fmt::{Debug, Display, Formatter},
6	iter::{Product, Sum},
7	marker::PhantomData,
8	ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
9};
10
11use binius_utils::{
12	DeserializeBytes, SerializationError, SerializeBytes,
13	bytes::{Buf, BufMut},
14};
15use bytemuck::{Pod, Zeroable};
16
17use super::{
18	Error, PackedExtension, PackedSubfield,
19	arithmetic_traits::InvertOrZero,
20	binary_field::{BinaryField, BinaryField1b, binary_field, impl_field_extension},
21	binary_field_arithmetic::TowerFieldArithmetic,
22	mul_by_binary_field_1b,
23};
24use crate::{
25	ExtensionField, Field, TowerField, binary_field_arithmetic::impl_arithmetic_using_packed,
26	linear_transformation::Transformation, underlier::U1,
27};
28
// These fields form a tower built on the AES field
// GF(2^8) = GF(2)[x] / (x^8 + x^4 + x^3 + x + 1), which is isomorphically embedded into the
// binary tower, i.e.:
//  - AESTowerField16b is AESTowerField8b[x] / (x^2 + x * x_2 + 1), where `x_2` is 0x10 from
//    BinaryField8b isomorphically projected into AESTowerField8b (0xD3, see ISOMORPHIC_ALPHAS).
//  - AESTowerField32b is AESTowerField16b[x] / (x^2 + x * x_3 + 1), where `x_3` is 0x1000 from
//    AESTowerField16b.
//  ...
// Only the 8-bit field, AESTowerField8b, is defined in this file.
// `AESTowerField8b` is a newtype over `u8` holding an element of the AES field GF(2^8).
// The second argument (0xD0) is validated by `test_multiplicative_generators` via
// `is_binary_field_valid_generator` — NOTE(review): presumably the multiplicative generator;
// confirm against the `binary_field!` macro docs.
binary_field!(pub AESTowerField8b(u8), 0xD0);

// SAFETY: assumes `binary_field!` emits a `Copy`, `#[repr(transparent)]` (or `repr(C)`) wrapper
// around a single `u8` that also implements `Zeroable`, so every bit pattern is a valid value and
// there is no padding — TODO confirm against the macro definition.
unsafe impl Pod for AESTowerField8b {}

// Declares `AESTowerField8b` as an extension of `BinaryField1b`; `@3` is presumably the tower
// level (degree 2^3 = 8), matching the level 3 reported by `min_tower_level` — TODO confirm.
impl_field_extension!(BinaryField1b(U1) < @3 => AESTowerField8b(u8));

// Multiplication of `AESTowerField8b` by `BinaryField1b` scalars (implementation lives in the
// macro, not visible here).
mul_by_binary_field_1b!(AESTowerField8b);

// Scalar arithmetic (e.g. the `multiply_alpha` used by `mul_primitive`) — NOTE(review): judging
// by the name, implemented by delegating to the packed-field arithmetic; confirm.
impl_arithmetic_using_packed!(AESTowerField8b);
45
46impl TowerField for AESTowerField8b {
47	fn min_tower_level(self) -> usize {
48		match self {
49			Self::ZERO | Self::ONE => 0,
50			_ => 3,
51		}
52	}
53
54	fn mul_primitive(self, iota: usize) -> Result<Self, Error> {
55		match iota {
56			0..=1 => Ok(self * ISOMORPHIC_ALPHAS[iota]),
57			2 => Ok(self.multiply_alpha()),
58			_ => Err(Error::ExtensionDegreeMismatch),
59		}
60	}
61}
62
63/// Returns true if `F`` is AES tower field.
64#[inline(always)]
65pub fn is_aes_tower<F: TowerField>() -> bool {
66	TypeId::of::<F>() == TypeId::of::<AESTowerField8b>()
67}
68
/// A three-step transformation between packed extension fields that acts through subfields:
///
/// 1. View the input packed value as packed elements of the subfield `IF`
///    (`PackedExtension::cast_base_ref`).
/// 2. Apply the wrapped transformation `T`, which maps packed `IF` elements to packed `OF`
///    elements (typically a linear map between the AES and binary 8-bit tower fields).
/// 3. Reinterpret the result as the output packed extension type (`PackedExtension::cast_ext`).
///
/// Construct it with [`SubfieldTransformer::new`].
pub struct SubfieldTransformer<IF, OF, T> {
	inner_transform: T,
	_ip_pd: PhantomData<IF>,
	_op_pd: PhantomData<OF>,
}

impl<IF, OF, T> SubfieldTransformer<IF, OF, T> {
	/// Wraps `inner_transform`, the transformation applied to the packed subfield elements.
	///
	/// The fields of `SubfieldTransformer` are private, so this is the only way to build one
	/// outside this module.
	pub const fn new(inner_transform: T) -> Self {
		Self {
			inner_transform,
			_ip_pd: PhantomData,
			_op_pd: PhantomData,
		}
	}
}
78
79impl<IF, OF, IEP, OEP, T> Transformation<IEP, OEP> for SubfieldTransformer<IF, OF, T>
80where
81	IF: Field,
82	OF: Field,
83	IEP: PackedExtension<IF>,
84	OEP: PackedExtension<OF>,
85	T: Transformation<PackedSubfield<IEP, IF>, PackedSubfield<OEP, OF>>,
86{
87	fn transform(&self, input: &IEP) -> OEP {
88		OEP::cast_ext(self.inner_transform.transform(IEP::cast_base_ref(input)))
89	}
90}
91
/// Images in `AESTowerField8b` of the `BinaryField8b` elements `0x02`, `0x04` and `0x10`
/// under the isomorphism between the binary tower and the AES field.
///
/// `mul_primitive` multiplies by entry `iota` for `iota` in `0..=1`. Entry 2 is not read by
/// `mul_primitive`, which delegates level 2 to `multiply_alpha` instead.
const ISOMORPHIC_ALPHAS: [AESTowerField8b; 3] = [
	AESTowerField8b(0xBC), // image of BinaryField8b 0x02
	AESTowerField8b(0xB0), // image of BinaryField8b 0x04
	AESTowerField8b(0xD3), // image of BinaryField8b 0x10
];
98
99macro_rules! serialize_deserialize_non_canonical {
100	($field:ident, canonical=$canonical:ident) => {
101		impl SerializeBytes for $field {
102			fn serialize(&self, write_buf: impl BufMut) -> Result<(), SerializationError> {
103				self.0.serialize(write_buf)
104			}
105		}
106
107		impl DeserializeBytes for $field {
108			fn deserialize(read_buf: impl Buf) -> Result<Self, SerializationError>
109			where
110				Self: Sized,
111			{
112				Ok(Self(DeserializeBytes::deserialize(read_buf)?))
113			}
114		}
115	};
116}
117
118serialize_deserialize_non_canonical!(AESTowerField8b, canonical = BinaryField8b);
119
#[cfg(test)]
mod tests {
	use binius_utils::{SerializeBytes, bytes::BytesMut};
	use proptest::{arbitrary::any, proptest};
	use rand::prelude::*;

	use super::*;
	use crate::{
		Random, binary_field::tests::is_binary_field_valid_generator, underlier::WithUnderlier,
	};

	/// Asserts that `square` agrees with multiplying the element by itself.
	fn check_square(f: impl Field) {
		assert_eq!(f.square(), f * f);
	}

	proptest! {
		#[test]
		fn test_square_8(a in any::<u8>()) {
			check_square(AESTowerField8b::from(a))
		}
	}

	/// Asserts that `invert` yields `None` for zero and a true multiplicative inverse otherwise.
	fn check_invert(f: impl Field) {
		let inverse = f.invert();
		if f.is_zero() {
			assert!(inverse.is_none());
		} else {
			assert_eq!(inverse.unwrap() * f, Field::ONE);
		}
	}

	proptest! {
		#[test]
		fn test_invert_8(a in any::<u8>()) {
			check_invert(AESTowerField8b::from(a))
		}
	}

	/// Asserts that `ONE` is a two-sided multiplicative identity for `f`.
	fn check_mul_by_one<F: Field>(f: F) {
		assert_eq!(F::ONE * f, f);
		assert_eq!(f * F::ONE, f);
	}

	/// Asserts that multiplying `f_1` and `f_2` commutes.
	fn check_commutative<F: Field>(f_1: F, f_2: F) {
		assert_eq!(f_1 * f_2, f_2 * f_1);
	}

	/// Asserts associativity of multiplication and left distributivity over addition.
	fn check_associativity_and_distributivity<F: Field>(f_1: F, f_2: F, f_3: F) {
		assert_eq!(f_1 * (f_2 * f_3), (f_1 * f_2) * f_3);
		assert_eq!(f_1 * (f_2 + f_3), f_1 * f_2 + f_1 * f_3);
	}

	/// Runs the identity, commutativity and associativity/distributivity checks over every
	/// ordering of the three inputs.
	fn check_mul<F: Field>(f_1: F, f_2: F, f_3: F) {
		check_mul_by_one(f_1);
		check_mul_by_one(f_2);
		check_mul_by_one(f_3);

		check_commutative(f_1, f_2);
		check_commutative(f_1, f_3);
		check_commutative(f_2, f_3);

		check_associativity_and_distributivity(f_1, f_2, f_3);
		check_associativity_and_distributivity(f_1, f_3, f_2);
		check_associativity_and_distributivity(f_2, f_1, f_3);
		check_associativity_and_distributivity(f_2, f_3, f_1);
		check_associativity_and_distributivity(f_3, f_1, f_2);
		check_associativity_and_distributivity(f_3, f_2, f_1);
	}

	proptest! {
		#[test]
		fn test_mul_8(a in any::<u8>(), b in any::<u8>(), c in any::<u8>()) {
			check_mul(AESTowerField8b::from(a), AESTowerField8b::from(b), AESTowerField8b::from(c))
		}
	}

	#[test]
	fn test_multiplicative_generators() {
		assert!(is_binary_field_valid_generator::<AESTowerField8b>());
	}

	/// Checks `mul_primitive(iota)` against an independently computed expectation.
	///
	/// For `iota` in `0..=2` the expectation is multiplication by `ISOMORPHIC_ALPHAS[iota]`.
	/// Otherwise it is multiplication by the `1 << iota`-th basis element over `BinaryField1b`,
	/// which must fail with `Error::ExtensionDegreeMismatch` whenever `mul_primitive` fails.
	fn test_mul_primitive<F: TowerField + WithUnderlier<Underlier: From<u8>>>(val: F, iota: usize) {
		let result = val.mul_primitive(iota);
		let expected = match iota {
			0..=2 => {
				Ok(val
					* F::from_underlier(F::Underlier::from(ISOMORPHIC_ALPHAS[iota].to_underlier())))
			}
			_ => <F as ExtensionField<BinaryField1b>>::basis_checked(1 << iota).map(|b| val * b),
		};
		match (result, expected) {
			(Ok(got), Ok(want)) => assert_eq!(got, want),
			(Err(err), Err(_)) => assert!(matches!(err, Error::ExtensionDegreeMismatch)),
			(got, want) => {
				panic!("mul_primitive({iota}) returned {got:?}, but expected {want:?}")
			}
		}
	}

	proptest! {
		// Covers the table-driven (0..=1) and `multiply_alpha` (2) levels as well as the
		// rejected levels above this field's tower level.
		#[test]
		fn test_mul_primitive_8b(val in 0u8.., iota in 0usize..8) {
			test_mul_primitive::<AESTowerField8b>(val.into(), iota)
		}
	}

	#[test]
	fn test_serialization() {
		let mut buffer = BytesMut::new();
		let mut rng = StdRng::seed_from_u64(0);
		let aes8 = AESTowerField8b::random(&mut rng);

		SerializeBytes::serialize(&aes8, &mut buffer).unwrap();

		let mut read_buffer = buffer.freeze();

		assert_eq!(AESTowerField8b::deserialize(&mut read_buffer).unwrap(), aes8);
	}

	#[test]
	fn test_serialization_roundtrip_all_values() {
		for byte in 0..=u8::MAX {
			let value = AESTowerField8b::from(byte);
			let mut buffer = BytesMut::new();
			SerializeBytes::serialize(&value, &mut buffer).unwrap();

			let mut read_buffer = buffer.freeze();
			assert_eq!(AESTowerField8b::deserialize(&mut read_buffer).unwrap(), value);
			// Deserialization must consume exactly what serialization wrote.
			assert!(read_buffer.is_empty());
		}
	}
}