binius_hash/vision_4/
parallel_permutation.rs

1// Copyright 2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
//! Parallel Vision-4 hash permutation using flattened state arrays.
//!
//! Processes `N` Vision-4 states simultaneously by flattening them into a single array of
//! `N × 4` field elements. The key optimization is **batch inversion**: instead of performing
//! `N × 4` separate field inversions at every inversion step, all elements are inverted together
//! using Montgomery's batch-inversion algorithm, which needs only a single field inversion.
9//!
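//! # Layout
//! States are stored back to back, `M = 4` elements each:
//! `[s0[0], s0[1], s0[2], s0[3], s1[0], s1[1], ...]`. State `i` occupies
//! `states[i * M..(i + 1) * M]`, and the array length is `MN = N * M`, where `N` is the
//! number of states.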
12//!
//! # Round Structure
//! Each round: inversion → inverse B-transform → MDS → constants → inversion → forward
//! B-transform → MDS → constants
15
16use binius_field::{BinaryField128bGhash as Ghash, arithmetic_traits::Square};
17use binius_math::batch_invert::BatchInversion;
18
19use super::{
20	constants::{B_FWD_COEFFS, M, NUM_ROUNDS, ROUND_CONSTANTS},
21	permutation::{constants_add, linearized_b_inv_transform_scalar, mds_mul},
22};
23
24/// Applies forward B-polynomial transformation: B(x) = c₀ + c₁x + c₂x² + c₃x⁴.
25#[inline]
26fn batch_forward_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
27	for i in 0..MN {
28		let scalar = states[i];
29		let square = scalar.square();
30		let quartic = square.square();
31
32		states[i] = B_FWD_COEFFS[0]
33			+ B_FWD_COEFFS[1] * scalar
34			+ B_FWD_COEFFS[2] * square
35			+ B_FWD_COEFFS[3] * quartic;
36	}
37}
38
39/// Applies inverse B-polynomial transformation using lookups.
40#[inline]
41fn batch_inverse_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
42	for i in 0..MN {
43		linearized_b_inv_transform_scalar(&mut states[i]);
44	}
45}
46
47/// Applies MDS matrix multiplication to each of the N parallel states.
48#[inline]
49fn batch_mds_mul<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
50	for i in 0..N {
51		let state = &mut states[i * M..];
52		mds_mul(state);
53	}
54}
55
56/// Adds round constants to each of the N parallel states.
57#[inline]
58fn batch_constants_add<const N: usize, const MN: usize>(
59	states: &mut [Ghash; MN],
60	constants: &[Ghash; M],
61) {
62	for i in 0..N {
63		let state = &mut states[i * M..];
64		constants_add(state, constants);
65	}
66}
67
68/// Executes a complete Vision-4 round on all parallel states.
69#[inline]
70fn batch_round<const N: usize, const MN: usize>(
71	states: &mut [Ghash; MN],
72	inverter: &mut BatchInversion<Ghash>,
73	round_constants_idx: usize,
74) {
75	// First half-round: inversion → inverse transform → MDS → constants
76	inverter.invert_or_zero(states);
77	batch_inverse_transform::<N, MN>(states);
78	batch_mds_mul::<N, MN>(states);
79	batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx]);
80
81	// Second half-round: inversion → forward transform → MDS → constants
82	inverter.invert_or_zero(states);
83	batch_forward_transform::<N, MN>(states);
84	batch_mds_mul::<N, MN>(states);
85	batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx + 1]);
86}
87
88/// Executes the complete Vision-4 permutation on N parallel states.
89///
90/// Main entry point for parallel Vision-4 hashing.
91#[inline]
92pub fn batch_permutation<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
93	// Initial round constant addition
94	batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[0]);
95
96	let mut inverter = BatchInversion::<Ghash>::new(MN);
97
98	// Execute all rounds of the permutation
99	for round_num in 0..NUM_ROUNDS {
100		batch_round::<N, MN>(states, &mut inverter, 1 + 2 * round_num);
101	}
102}
103
#[cfg(test)]
mod tests {
	use std::array;

	use binius_field::Random;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::vision_4::permutation::permutation;

	/// Runs `batch_permutation` on `parallel_states` and asserts that every state equals the
	/// scalar `permutation` applied to that state on its own.
	fn assert_matches_scalar<const N: usize, const MN: usize>(mut parallel_states: [Ghash; MN]) {
		let mut single_states: [[Ghash; M]; N] =
			array::from_fn(|i| array::from_fn(|j| parallel_states[i * M + j]));

		batch_permutation::<N, MN>(&mut parallel_states);

		for state in single_states.iter_mut() {
			permutation(state);
		}

		let expected_parallel: [Ghash; MN] = array::from_fn(|i| single_states[i / M][i % M]);

		assert_eq!(parallel_states, expected_parallel);
	}

	macro_rules! test_parallel_permutation {
		($name:ident, $n:expr) => {
			#[test]
			fn $name() {
				const N: usize = $n;
				const MN: usize = M * N;
				let mut rng = StdRng::seed_from_u64(0);

				for _ in 0..4 {
					let parallel_states: [Ghash; MN] =
						array::from_fn(|_| Ghash::random(&mut rng));
					assert_matches_scalar::<N, MN>(parallel_states);
				}
			}
		};
	}

	test_parallel_permutation!(test_parallel_permutation_1, 1);
	test_parallel_permutation!(test_parallel_permutation_2, 2);
	test_parallel_permutation!(test_parallel_permutation_4, 4);
	test_parallel_permutation!(test_parallel_permutation_8, 8);
	test_parallel_permutation!(test_parallel_permutation_16, 16);
	test_parallel_permutation!(test_parallel_permutation_32, 32);
	test_parallel_permutation!(test_parallel_permutation_64, 64);

	/// Random inputs essentially never produce a zero element, so the zero handling of the
	/// batched inversion is exercised explicitly here. Inputs equal to `ROUND_CONSTANTS[0]`
	/// cancel during the initial constant addition (`x + x = 0` in a binary field). The
	/// corresponding states are therefore all-zero when the first batched inversion runs.
	#[test]
	fn test_parallel_permutation_zero_elements() {
		let cancelling = |i: usize| ROUND_CONSTANTS[0][i % M];

		// A batch that is entirely zero at the first inversion.
		const ONE_STATE: usize = M;
		assert_matches_scalar::<1, ONE_STATE>(array::from_fn(cancelling));

		// Zero states interleaved with random states in the same batch.
		const N: usize = 4;
		const MN: usize = M * N;
		let mut rng = StdRng::seed_from_u64(1);
		let mixed: [Ghash; MN] = array::from_fn(|i| {
			if (i / M) % 2 == 0 {
				cancelling(i)
			} else {
				Ghash::random(&mut rng)
			}
		});
		assert_matches_scalar::<N, MN>(mixed);
	}
}