binius_hash/vision_6/
parallel_permutation.rs

1// Copyright 2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4//! Parallel Vision-6 hash permutation using flattened state arrays.
5//!
6//! Processes N Vision-6 states simultaneously by flattening them into a single N×6 array.
//! The key optimization is **batch inversion**: rather than inverting each of the `N × 6`
//! state elements separately in every half-round, all of them are inverted together with a
//! single field inversion using Montgomery's trick.
9//!
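//! # Layout
//! The `N` states are stored back to back in one array of `MN = N * M` elements (`M = 6`):
//! `[s0[0], s0[1], ..., s0[5], s1[0], s1[1], ...]`, so state `i` occupies indices
//! `i * M..(i + 1) * M`.
//!
//! # Round Structure
//! The permutation first adds `ROUND_CONSTANTS[0]`; each of the `NUM_ROUNDS` rounds then runs:
//! inversion → inverse B-transform → MDS → constants → inversion → forward B-transform → MDS →
//! constants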
15
16use binius_field::{BinaryField128bGhash as Ghash, arithmetic_traits::Square};
17use binius_math::batch_invert::BatchInversion;
18
19use super::{
20	constants::{B_FWD_COEFFS, M, NUM_ROUNDS, ROUND_CONSTANTS},
21	permutation::{linearized_b_inv_transform_scalar, mds_mul},
22};
23
24/// Applies forward B-polynomial transformation: B(x) = c₀ + c₁x + c₂x² + c₃x⁴.
25#[inline]
26fn batch_forward_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
27	for i in 0..MN {
28		let scalar = states[i];
29		let square = scalar.square();
30		let quartic = square.square();
31
32		states[i] = B_FWD_COEFFS[0]
33			+ B_FWD_COEFFS[1] * scalar
34			+ B_FWD_COEFFS[2] * square
35			+ B_FWD_COEFFS[3] * quartic;
36	}
37}
38
39/// Applies inverse B-polynomial transformation using lookups.
40#[inline]
41fn batch_inverse_transform<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
42	for i in 0..MN {
43		linearized_b_inv_transform_scalar(&mut states[i]);
44	}
45}
46
47/// Applies MDS matrix multiplication to each of the N batch states.
48#[inline]
49fn batch_mds_mul<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
50	for i in 0..N {
51		let state = &mut states[i * M..];
52		mds_mul(state);
53	}
54}
55
56/// Adds round constants to each of the N batch states.
57#[inline]
58fn batch_constants_add<const N: usize, const MN: usize>(
59	states: &mut [Ghash; MN],
60	constants: &[Ghash; M],
61) {
62	for i in 0..N {
63		let state_start = i * M;
64		for j in 0..M {
65			states[state_start + j] += constants[j];
66		}
67	}
68}
69
70/// Executes a complete Vision-6 round on all batch states.
71#[inline]
72fn batch_round<const N: usize, const MN: usize>(
73	states: &mut [Ghash; MN],
74	inverter: &mut BatchInversion<Ghash>,
75	round_constants_idx: usize,
76) {
77	// First half-round: inversion → inverse transform → MDS → constants
78	inverter.invert_or_zero(states);
79	batch_inverse_transform::<N, MN>(states);
80	batch_mds_mul::<N, MN>(states);
81	batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx]);
82
83	// Second half-round: inversion → forward transform → MDS → constants
84	inverter.invert_or_zero(states);
85	batch_forward_transform::<N, MN>(states);
86	batch_mds_mul::<N, MN>(states);
87	batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[round_constants_idx + 1]);
88}
89
90/// Executes the complete Vision-6 permutation on N batch states.
91///
92/// Main entry point for batch Vision-6 hashing.
93#[inline]
94pub fn batch_permutation<const N: usize, const MN: usize>(states: &mut [Ghash; MN]) {
95	// Initial round constant addition
96	batch_constants_add::<N, MN>(states, &ROUND_CONSTANTS[0]);
97
98	let mut inverter = BatchInversion::<Ghash>::new(MN);
99
100	// Execute all rounds of the permutation
101	for round_num in 0..NUM_ROUNDS {
102		batch_round::<N, MN>(states, &mut inverter, 1 + 2 * round_num);
103	}
104}
105
#[cfg(test)]
mod tests {
	use std::array;

	use binius_field::Random;
	use rand::{SeedableRng, rngs::StdRng};

	use super::*;
	use crate::vision_6::permutation::permutation;

	/// Defines a test named `$name` that checks, over four random inputs, that
	/// `batch_permutation` on `$n` packed states matches running the scalar `permutation`
	/// on each `M`-element state on its own.
	macro_rules! test_batch_permutation {
		($name:ident, $n:expr) => {
			#[test]
			fn $name() {
				const N: usize = $n;
				const MN: usize = M * N;
				let mut rng = StdRng::seed_from_u64(0);

				for _trial in 0..4 {
					let input: [Ghash; MN] = array::from_fn(|_| Ghash::random(&mut rng));

					// Reference: permute each packed state independently, in place.
					let mut expected = input;
					for chunk in expected.chunks_exact_mut(M) {
						let state: &mut [Ghash; M] = chunk
							.try_into()
							.expect("chunks_exact_mut yields exactly M elements");
						permutation(state);
					}

					let mut actual = input;
					batch_permutation::<N, MN>(&mut actual);

					assert_eq!(actual, expected);
				}
			}
		};
	}

	test_batch_permutation!(test_batch_permutation_1, 1);
	test_batch_permutation!(test_batch_permutation_2, 2);
	test_batch_permutation!(test_batch_permutation_4, 4);
	test_batch_permutation!(test_batch_permutation_8, 8);
	test_batch_permutation!(test_batch_permutation_16, 16);
	test_batch_permutation!(test_batch_permutation_32, 32);
	test_batch_permutation!(test_batch_permutation_64, 64);
}