binius_hash/vision_4/
parallel_permutation.rs1use binius_field::{BinaryField128bGhash as Ghash, arithmetic_traits::Square};
17use binius_math::batch_invert::BatchInversion;
18
19use super::{
20 constants::{B_FWD_COEFFS, M, NUM_ROUNDS, ROUND_CONSTANTS},
21 permutation::{constants_add, linearized_b_inv_transform_scalar, mds_mul},
22};
23
24#[inline]
26fn batch_forward_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
27 for i in 0..MN {
28 let scalar = states[i];
29 let square = scalar.square();
30 let quartic = square.square();
31
32 states[i] = B_FWD_COEFFS[0]
33 + B_FWD_COEFFS[1] * scalar
34 + B_FWD_COEFFS[2] * square
35 + B_FWD_COEFFS[3] * quartic;
36 }
37}
38
39#[inline]
41fn batch_inverse_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
42 for i in 0..MN {
43 linearized_b_inv_transform_scalar(&mut states[i]);
44 }
45}
46
47#[inline]
49fn batch_mds_mul<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
50 for i in 0..N {
51 let state = &mut states[i * M..];
52 mds_mul(state);
53 }
54}
55
56#[inline]
58fn batch_constants_add<const N: usize, const MN: usize>(
59 states: &mut [Ghash; MN],
60 constants: &[Ghash; M],
61) {
62 for i in 0..N {
63 let state = &mut states[i * M..];
64 constants_add(state, constants);
65 }
66}
67
68#[inline]
70fn batch_round<const N: usize, const MN: usize>(
71 states: &mut [Ghash; MN],
72 inverter: &mut BatchInversion<Ghash>,
73 round_constants_idx: usize,
74) {
75 inverter.invert_or_zero(states);
77 batch_inverse_transform::<N, MN>(states);
78 batch_mds_mul::<N, MN>(states);
79 batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx]);
80
81 inverter.invert_or_zero(states);
83 batch_forward_transform::<N, MN>(states);
84 batch_mds_mul::<N, MN>(states);
85 batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx + 1]);
86}
87
88#[inline]
92pub fn batch_permutation<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
93 batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[0]);
95
96 let mut inverter = BatchInversion::<Ghash>::new(MN);
97
98 for round_num in 0..NUM_ROUNDS {
100 batch_round::<N, MN>(states, &mut inverter, 1 + 2 * round_num);
101 }
102}
103
#[cfg(test)]
mod tests {
    use std::array;

    use binius_field::Random;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::vision_4::permutation::permutation;

    /// Checks `batch_permutation` against the single-state `permutation`.
    ///
    /// `batch_permutation` is run over `N` packed states, and `permutation` is applied to the
    /// same states one at a time. The two results must agree for four random batches drawn
    /// from a fixed seed.
    fn assert_batch_matches_single<const N: usize, const MN: usize>() {
        let mut rng = StdRng::seed_from_u64(0);

        for _ in 0..4 {
            let mut packed: [Ghash; MN] = array::from_fn(|_| Ghash::random(&mut rng));

            // Unpack the same random input into one array per state.
            let mut unpacked: [[Ghash; M]; N] =
                array::from_fn(|s| array::from_fn(|k| packed[s * M + k]));

            batch_permutation::<N, MN>(&mut packed);

            for one_state in unpacked.iter_mut() {
                permutation(one_state);
            }

            // Re-pack the individually permuted states for a single whole-buffer comparison.
            let repacked: [Ghash; MN] = array::from_fn(|idx| unpacked[idx / M][idx % M]);

            assert_eq!(packed, repacked);
        }
    }

    macro_rules! test_parallel_permutation {
        ($name:ident, $n:expr) => {
            #[test]
            fn $name() {
                const N: usize = $n;
                const MN: usize = M * N;
                assert_batch_matches_single::<N, MN>();
            }
        };
    }

    test_parallel_permutation!(test_parallel_permutation_1, 1);
    test_parallel_permutation!(test_parallel_permutation_2, 2);
    test_parallel_permutation!(test_parallel_permutation_4, 4);
    test_parallel_permutation!(test_parallel_permutation_8, 8);
    test_parallel_permutation!(test_parallel_permutation_16, 16);
    test_parallel_permutation!(test_parallel_permutation_32, 32);
    test_parallel_permutation!(test_parallel_permutation_64, 64);
}