binius_hash/vision_4/
permutation.rs1use binius_field::{BinaryField128bGhash as Ghash, Divisible, WithUnderlier};
9use binius_math::batch_invert::BatchInversion;
10
11use super::{
12 constants::{B_FWD_COEFFS, B_INV_COEFFS, BYTES_PER_GHASH, M, NUM_ROUNDS, ROUND_CONSTANTS},
13 linear_tables::{LINEAR_B_FWD_TABLE, LINEAR_B_INV_TABLE},
14};
15
16pub fn linearized_b_inv_transform_scalar(x: &mut Ghash) {
18 linearized_transform_scalar(x, &LINEAR_B_INV_TABLE);
19 *x += B_INV_COEFFS[0];
20}
21
22pub fn linearized_transform_scalar(x: &mut Ghash, table: &'static [[Ghash; 256]; BYTES_PER_GHASH]) {
24 *x = <u128 as Divisible<u8>>::ref_iter(x.to_underlier_ref())
25 .zip(table)
26 .map(|(byte, lookup)| lookup[byte as usize])
27 .sum();
28}
29
30pub fn b_fwd_transform<const N: usize>(state: &mut [Ghash; N]) {
32 (0..N).for_each(|i| {
33 linearized_transform_scalar(&mut state[i], &LINEAR_B_FWD_TABLE);
34 state[i] += B_FWD_COEFFS[0];
35 });
36}
37
38pub fn b_inv_transform<const N: usize>(state: &mut [Ghash; N]) {
40 (0..N).for_each(|i| {
41 linearized_transform_scalar(&mut state[i], &LINEAR_B_INV_TABLE);
42 state[i] += B_INV_COEFFS[0];
43 });
44}
45
46pub fn sbox(
48 state: &mut [Ghash; M],
49 transform: impl Fn(&mut [Ghash; M]),
50 inverter: &mut BatchInversion<Ghash>,
51) {
52 inverter.invert_or_zero(state);
53 transform(state);
54}
55
56pub fn mds_mul(a: &mut [Ghash]) {
58 let sum = a[0] + a[1] + a[2] + a[3];
60 let a0 = a[0];
61
62 a[0] += sum + (a[0] + a[1]).mul_x();
64
65 a[1] += sum + (a[1] + a[2]).mul_x();
67
68 a[2] += sum + (a[2] + a[3]).mul_x();
70
71 a[3] += sum + (a[3] + a0).mul_x();
73}
74
75pub fn constants_add(state: &mut [Ghash], constants: &[Ghash]) {
77 for i in 0..M {
78 state[i] += constants[i];
79 }
80}
81
82fn round(state: &mut [Ghash; M], round_constants_idx: usize, inverter: &mut BatchInversion<Ghash>) {
84 sbox(state, b_inv_transform, inverter);
86 mds_mul(state);
87 constants_add(state, &ROUND_CONSTANTS[round_constants_idx]);
88 sbox(state, b_fwd_transform, inverter);
90 mds_mul(state);
91 constants_add(state, &ROUND_CONSTANTS[round_constants_idx + 1]);
92}
93
94pub fn permutation(state: &mut [Ghash; M]) {
96 constants_add(state, &ROUND_CONSTANTS[0]);
97 let mut inverter = BatchInversion::<Ghash>::new(M);
98 for round_num in 0..NUM_ROUNDS {
99 round(state, 1 + 2 * round_num, &mut inverter);
100 }
101}
102
#[cfg(test)]
mod tests {
    use std::array;

    use binius_field::Random;
    use rand::{SeedableRng, rngs::StdRng};

    use super::{super::constants::tests::matrix_mul, *};

    /// `mds_mul` must agree with a dense multiplication by the circulant MDS matrix.
    #[test]
    fn test_mds() {
        let mut rng = StdRng::seed_from_u64(0);
        let input: [Ghash; M] = array::from_fn(|_| Ghash::random(&mut rng));

        // Rows of the circulant matrix, first row (2, 3, 1, 1).
        let matrix: [Ghash; M * M] = [
            2, 3, 1, 1, //
            1, 2, 3, 1, //
            1, 1, 2, 3, //
            3, 1, 1, 2, //
        ]
        .map(Ghash::new);
        let expected = matrix_mul(&matrix, &input);

        let mut actual = input;
        mds_mul(&mut actual);

        assert_eq!(actual, expected);
    }

    /// Known-answer vectors for the full permutation.
    #[test]
    fn test_permutation() {
        let mut rng = StdRng::seed_from_u64(0);
        let cases: [([Ghash; M], [u128; M]); 3] = [
            // All-zero input.
            (
                [0; M].map(Ghash::new),
                [
                    0x5e9a7b63d8d1a93953d56ceb6dcf6a35,
                    0xa3262c57f6cdd8c368639c1a4f01ab5a,
                    0x1dc99e37723063c4f178826d2a6802e3,
                    0xfdf935c9d9fae3d560a75026a049bf7c,
                ],
            ),
            // Sparse, repeated pattern.
            (
                [0xdeadbeef, 0x0, 0xdeadbeef, 0x0].map(Ghash::new),
                [
                    0x1d02eaf6cf48c108a2ae1d9e27812364,
                    0xc9bae4f4c782d46ed28245525f04fb3c,
                    0xf4fea518a1e62f97748266e86acac536,
                    0x22b25c68a52fef4b855f8862bdd418c4,
                ],
            ),
            // Pseudo-random input from the seeded RNG.
            (
                array::from_fn(|_| Ghash::random(&mut rng)),
                [
                    0xdd1c99b8f9f2ec20abf21f082a56c9f3,
                    0x3f5ec0a548673b571ba93d7751c98624,
                    0xe1c5c8fc8f4c80cfa8841cfd0ae0fbbb,
                    0xa054cc0d7379b474df8726cb448ca22b,
                ],
            ),
        ];

        for (i, &(input, expected)) in cases.iter().enumerate() {
            let mut state = input;
            permutation(&mut state);
            assert_eq!(state, expected.map(Ghash::new), "known-answer case {i}");
        }
    }
}