binius_hash/vision_6/
parallel_permutation.rs1use binius_field::{BinaryField128bGhash as Ghash, arithmetic_traits::Square};
17use binius_math::batch_invert::BatchInversion;
18
19use super::{
20 constants::{B_FWD_COEFFS, M, NUM_ROUNDS, ROUND_CONSTANTS},
21 permutation::{linearized_b_inv_transform_scalar, mds_mul},
22};
23
24#[inline]
26fn batch_forward_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
27 for i in 0..MN {
28 let scalar = states[i];
29 let square = scalar.square();
30 let quartic = square.square();
31
32 states[i] = B_FWD_COEFFS[0]
33 + B_FWD_COEFFS[1] * scalar
34 + B_FWD_COEFFS[2] * square
35 + B_FWD_COEFFS[3] * quartic;
36 }
37}
38
39#[inline]
41fn batch_inverse_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
42 for i in 0..MN {
43 linearized_b_inv_transform_scalar(&mut states[i]);
44 }
45}
46
47#[inline]
49fn batch_mds_mul<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
50 for i in 0..N {
51 let state = &mut states[i * M..];
52 mds_mul(state);
53 }
54}
55
56#[inline]
58fn batch_constants_add<const N: usize, const MN: usize>(
59 states: &mut [Ghash; MN],
60 constants: &[Ghash; M],
61) {
62 for i in 0..N {
63 let state_start = i * M;
64 for j in 0..M {
65 states[state_start + j] += constants[j];
66 }
67 }
68}
69
70#[inline]
72fn batch_round<const N: usize, const MN: usize>(
73 states: &mut [Ghash; MN],
74 inverter: &mut BatchInversion<Ghash>,
75 round_constants_idx: usize,
76) {
77 inverter.invert_or_zero(states);
79 batch_inverse_transform::<N, MN>(states);
80 batch_mds_mul::<N, MN>(states);
81 batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx]);
82
83 inverter.invert_or_zero(states);
85 batch_forward_transform::<N, MN>(states);
86 batch_mds_mul::<N, MN>(states);
87 batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx + 1]);
88}
89
90#[inline]
94pub fn batch_permutation<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
95 batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[0]);
97
98 let mut inverter = BatchInversion::<Ghash>::new(MN);
99
100 for round_num in 0..NUM_ROUNDS {
102 batch_round::<N, MN>(states, &mut inverter, 1 + 2 * round_num);
103 }
104}
105
#[cfg(test)]
mod tests {
    use std::array;

    use binius_field::Random;
    use rand::{SeedableRng, rngs::StdRng};

    use super::*;
    use crate::vision_6::permutation::permutation;

    /// Generates a test checking that `batch_permutation::<$n, M * $n>` produces the same
    /// output as applying the scalar `permutation` to each of the `$n` packed states.
    macro_rules! test_batch_permutation {
        ($name:ident, $n:expr) => {
            #[test]
            fn $name() {
                const N: usize = $n;
                const MN: usize = M * N;
                let mut rng = StdRng::seed_from_u64(0);

                for trial in 0..4 {
                    // Random packed input, plus an unpacked copy for the reference path.
                    let mut packed: [Ghash; MN] = array::from_fn(|_| Ghash::random(&mut rng));
                    let mut reference: [[Ghash; M]; N] =
                        array::from_fn(|s| array::from_fn(|k| packed[s * M + k]));

                    batch_permutation::<N, MN>(&mut packed);
                    reference.iter_mut().for_each(|state| permutation(state));

                    // Flatten the reference states back into the packed layout.
                    let expected = reference.concat();
                    assert_eq!(packed.as_slice(), expected.as_slice(), "trial {trial}");
                }
            }
        };
    }

    test_batch_permutation!(test_batch_permutation_1, 1);
    test_batch_permutation!(test_batch_permutation_2, 2);
    test_batch_permutation!(test_batch_permutation_4, 4);
    test_batch_permutation!(test_batch_permutation_8, 8);
    test_batch_permutation!(test_batch_permutation_16, 16);
    test_batch_permutation!(test_batch_permutation_32, 32);
    test_batch_permutation!(test_batch_permutation_64, 64);
}