1use binius_field::{BinaryField128bGhash as Ghash, Divisible, WithUnderlier};
9use binius_math::batch_invert::BatchInversion;
10
11use super::{
12 constants::{B_FWD_COEFFS, B_INV_COEFFS, BYTES_PER_GHASH, M, NUM_ROUNDS, ROUND_CONSTANTS},
13 linear_tables::{LINEAR_B_FWD_TABLE, LINEAR_B_INV_TABLE},
14};
15
16pub fn linearized_b_inv_transform_scalar(x: &mut Ghash) {
18 linearized_transform_scalar(x, &LINEAR_B_INV_TABLE);
19 *x += B_INV_COEFFS[0];
20}
21
22pub fn linearized_transform_scalar(x: &mut Ghash, table: &'static [[Ghash; 256]; BYTES_PER_GHASH]) {
24 *x = <u128 as Divisible<u8>>::ref_iter(x.to_underlier_ref())
25 .zip(table)
26 .map(|(byte, lookup)| lookup[byte as usize])
27 .sum();
28}
29
30pub fn b_fwd_transform<const N: usize>(state: &mut [Ghash; N]) {
32 (0..N).for_each(|i| {
33 linearized_transform_scalar(&mut state[i], &LINEAR_B_FWD_TABLE);
34 state[i] += B_FWD_COEFFS[0];
35 });
36}
37
38pub fn b_inv_transform<const N: usize>(state: &mut [Ghash; N]) {
40 (0..N).for_each(|i| {
41 linearized_transform_scalar(&mut state[i], &LINEAR_B_INV_TABLE);
42 state[i] += B_INV_COEFFS[0];
43 });
44}
45
46pub fn sbox(
48 state: &mut [Ghash; M],
49 transform: impl Fn(&mut [Ghash; M]),
50 inverter: &mut BatchInversion<Ghash>,
51) {
52 inverter.invert_or_zero(state);
53 transform(state);
54}
55
56pub fn mds_mul(a: &mut [Ghash]) {
58 let a0 = a[0];
64 let a1 = a[1];
65 let a2 = a[2];
66 let a3 = a[3];
67 let a4 = a[4];
68 let a5 = a[5];
69
70 a[0] = (a0 + a1).mul_inv_x() + (a2 + a4).mul_x() + a3 + (a5.mul_x() + a5);
73
74 a[1] = (a0.mul_x() + a0) + (a1 + a2).mul_inv_x() + a3.mul_x() + a4 + a5.mul_x();
77
78 a[2] = a0.mul_x() + (a1.mul_x() + a1) + (a2 + a3).mul_inv_x() + a4.mul_x() + a5;
81
82 a[3] = a0 + a1.mul_x() + (a2.mul_x() + a2) + (a3 + a4).mul_inv_x() + a5.mul_x();
85
86 a[4] = a0.mul_x() + a1 + a2.mul_x() + (a3.mul_x() + a3) + (a4 + a5).mul_inv_x();
89
90 a[5] = (a0 + a5).mul_inv_x() + a1.mul_x() + a2 + a3.mul_x() + (a4.mul_x() + a4);
93}
94
95pub fn constants_add(state: &mut [Ghash], constants: &[Ghash]) {
97 for i in 0..M {
98 state[i] += constants[i];
99 }
100}
101
102fn round(state: &mut [Ghash; M], round_constants_idx: usize, inverter: &mut BatchInversion<Ghash>) {
104 sbox(state, b_inv_transform, inverter);
106 mds_mul(state);
107 constants_add(state, &ROUND_CONSTANTS[round_constants_idx]);
108 sbox(state, b_fwd_transform, inverter);
110 mds_mul(state);
111 constants_add(state, &ROUND_CONSTANTS[round_constants_idx + 1]);
112}
113
114pub fn permutation(state: &mut [Ghash; M]) {
116 constants_add(state, &ROUND_CONSTANTS[0]);
117 let mut inverter = BatchInversion::<Ghash>::new(M);
118 for round_num in 0..NUM_ROUNDS {
119 round(state, 1 + 2 * round_num, &mut inverter);
120 }
121}
122
#[cfg(test)]
mod tests {
    use std::array;

    use binius_field::{Field, Random};
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;

    /// Reference dense product `matrix · input`, with `matrix` stored row-major.
    fn matrix_mul(matrix: &[Ghash; M * M], input: &[Ghash; M]) -> [Ghash; M] {
        let mut result = [Ghash::ZERO; M];
        for i in 0..M {
            for j in 0..M {
                result[i] += matrix[i * M + j] * input[j];
            }
        }
        result
    }

    /// `mds_mul` must agree with an explicit multiplication by the circulant MDS matrix.
    #[test]
    fn test_mds() {
        let mut rng = StdRng::seed_from_u64(0);
        let input: [Ghash; M] = array::from_fn(|_| Ghash::random(&mut rng));

        // Matrix entries: x, its inverse, x + 1, and 1.
        let x = Ghash::new(2);
        let x_inv = x.invert().expect("2 is invertible");
        let x_plus_1 = Ghash::new(3);
        let one = Ghash::ONE;

        // Row-major; each row is the previous one rotated right by one position.
        #[rustfmt::skip]
        let matrix: [Ghash; M * M] = [
            x_inv,    x_inv,    x,        one,      x,        x_plus_1,
            x_plus_1, x_inv,    x_inv,    x,        one,      x,
            x,        x_plus_1, x_inv,    x_inv,    x,        one,
            one,      x,        x_plus_1, x_inv,    x_inv,    x,
            x,        one,      x,        x_plus_1, x_inv,    x_inv,
            x_inv,    x,        one,      x,        x_plus_1, x_inv,
        ];
        let expected = matrix_mul(&matrix, &input);

        let mut actual = input;
        mds_mul(&mut actual);

        assert_eq!(actual, expected);
    }

    /// Known-answer tests for the full permutation.
    #[test]
    fn test_permutation() {
        let mut rng = StdRng::seed_from_u64(0);
        let cases = [
            // All-zero input.
            (
                array::from_fn(|_| Ghash::new(0x0)),
                [
                    Ghash::new(0xd41c58ea75c2e3a8e5004834f122d650),
                    Ghash::new(0xb1a6fb890a3a7520384c2e21f6dcc18d),
                    Ghash::new(0xe4ef3f5c84fe8bef518f57ee1f38dc05),
                    Ghash::new(0x9cb1081fc97c17719c9527727f991bc1),
                    Ghash::new(0x42ae4487ccb1a4af24ad33acf8f9a8cd),
                    Ghash::new(0x53faafdae4007e9983ec18971b8ce524),
                ],
            ),
            // Alternating 0xdeadbeef / zero input.
            (
                [
                    Ghash::new(0xdeadbeef),
                    Ghash::new(0x0),
                    Ghash::new(0xdeadbeef),
                    Ghash::new(0x0),
                    Ghash::new(0xdeadbeef),
                    Ghash::new(0x0),
                ],
                [
                    Ghash::new(0x6cef13e30578bbc055e541b8daae5525),
                    Ghash::new(0xe19d52ae54a01f3aacdb3dd2b8968a5f),
                    Ghash::new(0xca3289530c76d0c696e313ed5c1b7727),
                    Ghash::new(0x14dc021e84aa3ce6e7bb3a9452f61adc),
                    Ghash::new(0xde1d940f6c8b7d4869f02157f2f939df),
                    Ghash::new(0xd19a101d6d736dacaedad738dd35596a),
                ],
            ),
            // Random input from the seed-0 RNG (expected values depend on this exact sequence).
            (
                array::from_fn(|_| Ghash::random(&mut rng)),
                [
                    Ghash::new(0x54cf96e0da8f01e8b8a4688cd0f8b881),
                    Ghash::new(0xdbbc6cb3d5a96cee7d5ad99fe00f7874),
                    Ghash::new(0xaa44b45bed826f5baa02979e91593a7b),
                    Ghash::new(0x679fd76c310f55de3d216c5b4572597d),
                    Ghash::new(0x535121f744503928e021d9c4b8a56c46),
                    Ghash::new(0xc7de0f11beaf12aed3cdd8d3c4d1b4b8),
                ],
            ),
        ];

        for (input, expected) in cases {
            let mut state = input;
            permutation(&mut state);
            assert_eq!(state, expected);
        }
    }
}