binius_hash/vision_4/parallel_digest.rs

// Copyright 2025 Irreducible Inc.
// Copyright 2026 The Binius Developers

use std::{array, mem::MaybeUninit};

use binius_field::{BinaryField128bGhash as Ghash, Field};
use binius_utils::{DeserializeBytes, SerializeBytes};
use digest::Output;

use super::{
	constants::M,
	digest::{PADDING_BLOCK, RATE_AS_U8, RATE_AS_U128, VisionHasherDigest, fill_padding},
	parallel_permutation::batch_permutation,
};
use crate::parallel_digest::MultiDigest;

/// A Vision hasher with state size M=4 suited for parallelization.
///
/// Without using packed fields, there is only one advantage of an explicit parallelized
/// Vision hasher over invoking the Vision hasher multiple times in parallel:
/// we can amortize the cost of inversion in the sbox using
/// [Montgomery's trick](https://medium.com/eryxcoop/montgomerys-trick-for-batch-galois-field-inversion-9b6d0f399da2).
/// We slightly modify Montgomery's trick to use a binary tree structure,
/// maximizing independence of multiplications for better instruction pipelining.
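///
/// For reference, the sketch below shows the classic prefix-product form of the trick;
/// the batch permutation in this module uses the tree-shaped variant described above.
/// `montgomery_batch_invert` is illustrative only and not part of this crate's API, and
/// inversion is passed in as a closure so the example does not depend on any particular
/// field type.
///
/// ```
/// fn montgomery_batch_invert<F, I>(values: &[F], invert: I) -> Vec<F>
/// where
/// 	F: Copy + std::ops::Mul<Output = F>,
/// 	I: Fn(F) -> F,
/// {
/// 	// Sketch only: assumes a non-empty batch.
/// 	// Prefix products: prefix[i] = values[0] * ... * values[i].
/// 	let mut prefix = Vec::with_capacity(values.len());
/// 	let mut acc = values[0];
/// 	prefix.push(acc);
/// 	for &v in &values[1..] {
/// 		acc = acc * v;
/// 		prefix.push(acc);
/// 	}
/// 	// A single inversion pays for the whole batch.
/// 	let mut inv_acc = invert(acc);
/// 	// Walk backwards, peeling one factor off the accumulator per element.
/// 	let mut out = vec![values[0]; values.len()];
/// 	for i in (1..values.len()).rev() {
/// 		out[i] = inv_acc * prefix[i - 1];
/// 		inv_acc = inv_acc * values[i];
/// 	}
/// 	out[0] = inv_acc;
/// 	out
/// }
///
/// // Three inverses for the price of one inversion and a handful of multiplications.
/// let inv = montgomery_batch_invert(&[2.0_f64, 4.0, 8.0], |x| 1.0 / x);
/// assert_eq!(inv, vec![0.5, 0.25, 0.125]);
/// ```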
#[derive(Clone)]
pub struct VisionHasherMultiDigest<const N: usize, const MN: usize> {
	states: [Ghash; MN],
	buffers: [[u8; RATE_AS_U8]; N],
	filled_bytes: usize,
}

impl<const N: usize, const MN: usize> Default for VisionHasherMultiDigest<N, MN> {
	fn default() -> Self {
		assert!(N.is_power_of_two() && N >= 2, "N must be a power of 2 and >= 2");
		assert_eq!(MN, M * N);
		Self {
			states: array::from_fn(|_| Ghash::ZERO),
			buffers: array::from_fn(|_| [0; RATE_AS_U8]),
			filled_bytes: 0,
		}
	}
}

impl<const N: usize, const MN: usize> VisionHasherMultiDigest<N, MN> {
	#[inline]
	fn advance_data(data: &mut [&[u8]; N], bytes: usize) {
		for i in 0..N {
			data[i] = &data[i][bytes..];
		}
	}

	fn permute(states: &mut [Ghash; MN], data: [&[u8]; N]) {
		for (i, data) in data.iter().enumerate() {
			debug_assert_eq!(data.len(), RATE_AS_U8);

			// Overwrite first RATE_AS_U128 elements of state i with data
			let state_start = i * M;
			for j in 0..RATE_AS_U128 {
				let element_bytes = &data[j * (128 / 8)..];
				states[state_start + j] =
					Ghash::deserialize(element_bytes).expect("data len checked");
			}
		}

		batch_permutation::<N, MN>(states);
	}

	fn finalize(&mut self, out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>; N]) {
		if self.filled_bytes != 0 {
			for i in 0..N {
				fill_padding(&mut self.buffers[i][self.filled_bytes..]);
			}
			Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
		} else {
			Self::permute(&mut self.states, array::from_fn(|_| &PADDING_BLOCK[..]));
		}

		// Serialize first two state elements for each digest (32 bytes total per digest)
		for i in 0..N {
			let output_slice = out[i].as_mut_ptr() as *mut u8;
			// SAFETY: `Output<VisionHasherDigest>` is 32 bytes, so the pointer is valid
			// for writes of 32 bytes, and the two `serialize` calls below fill all of them.
			let output_bytes = unsafe { std::slice::from_raw_parts_mut(output_slice, 32) };
			let (state0, state1) = output_bytes.split_at_mut(16);
			self.states[i * M]
				.serialize(state0)
				.expect("fits in 16 bytes");
			self.states[i * M + 1]
				.serialize(state1)
				.expect("fits in 16 bytes");
		}
	}
}

impl<const N: usize, const MN: usize> MultiDigest<N> for VisionHasherMultiDigest<N, MN> {
	type Digest = VisionHasherDigest;

	fn new() -> Self {
		Self::default()
	}

	fn update(&mut self, mut data: [&[u8]; N]) {
		// Every lane must be updated with the same number of bytes.
		data[1..].iter().for_each(|row| {
			assert_eq!(row.len(), data[0].len());
		});

		// Top up a partially filled buffer first, permuting once it reaches the rate.
		if self.filled_bytes != 0 {
			let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
			data.iter().enumerate().for_each(|(row_i, row)| {
				self.buffers[row_i][self.filled_bytes..self.filled_bytes + to_copy]
					.copy_from_slice(&row[..to_copy]);
			});
			Self::advance_data(&mut data, to_copy);
			self.filled_bytes += to_copy;

			if self.filled_bytes == RATE_AS_U8 {
				Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
				self.filled_bytes = 0;
			}
		}

		// Absorb full rate-sized blocks directly from the input.
		while data[0].len() >= RATE_AS_U8 {
			let chunks = array::from_fn(|i| &data[i][..RATE_AS_U8]);
			Self::permute(&mut self.states, chunks);
			Self::advance_data(&mut data, RATE_AS_U8);
		}

		// Buffer any remaining partial block for the next update or finalization.
		if !data[0].is_empty() {
			data.iter().enumerate().for_each(|(row_i, row)| {
				self.buffers[row_i][..row.len()].copy_from_slice(row);
			});
			self.filled_bytes = data[0].len();
		}
	}

	fn finalize_into(mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
		self.finalize(out);
	}

	fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
		self.finalize(out);
		self.reset();
	}

	fn reset(&mut self) {
		self.states = array::from_fn(|_| Ghash::ZERO);
		self.buffers = array::from_fn(|_| [0; RATE_AS_U8]);
		self.filled_bytes = 0;
	}

	fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
		let mut digest = Self::default();
		digest.update(data);
		digest.finalize_into(out);
	}
}

#[cfg(test)]
mod tests {
	use std::mem::MaybeUninit;

	use digest::Digest;
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;

	// Helper function to generate random data vectors
	fn generate_random_data<const N: usize>(length: usize, seed: u64) -> Vec<Vec<u8>> {
		let mut rng = StdRng::seed_from_u64(seed);
		let mut data_vecs = Vec::new();
		for _ in 0..N {
			let mut vec = Vec::with_capacity(length);
			for _ in 0..length {
				vec.push(rng.random());
			}
			data_vecs.push(vec);
		}
		data_vecs
	}

	// Generic test function that compares parallel vs sequential execution
	fn test_parallel_vs_sequential<const N: usize, const MN: usize>(
		data: [&[u8]; N],
		description: &str,
	) {
		// Parallel computation
		let mut parallel_outputs = [MaybeUninit::uninit(); N];
		VisionHasherMultiDigest::<N, MN>::digest(data, &mut parallel_outputs);
		let parallel_results: [Output<VisionHasherDigest>; N] =
			array::from_fn(|i| unsafe { parallel_outputs[i].assume_init() });

		// Sequential computation
		let sequential_results: [_; N] = array::from_fn(|i| {
			let mut hasher = VisionHasherDigest::new();
			hasher.update(data[i]);
			hasher.finalize()
		});

		// Compare results
		for i in 0..N {
			assert_eq!(
				parallel_results[i], sequential_results[i],
				"Mismatch at index {i} for {description}"
			);
		}
	}

	#[test]
	fn test_empty_inputs() {
		const N: usize = 4;
		let data: [&[u8]; N] = [&[], &[], &[], &[]];
		test_parallel_vs_sequential::<N, { N * M }>(data, "empty inputs");
	}

	#[test]
	fn test_small_inputs() {
		const N: usize = 2;
		let data: [&[u8]; N] = [b"Hello... World!", b"Rust is awesome"];
		test_parallel_vs_sequential::<N, { N * M }>(data, "small inputs");
	}

	#[test]
	fn test_multi_block() {
		const N: usize = 4;
		// Multiple blocks with different random patterns
		let target_len = RATE_AS_U8 * 2 + 10;
		let data_vecs = generate_random_data::<N>(target_len, 42);
		let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());

		test_parallel_vs_sequential::<N, { N * M }>(data, "multi-block inputs");
	}

	#[test]
	fn test_various_sizes() {
		// Test different sizes separately since parallel requires same length per batch
		let sizes = [
			1,
			RATE_AS_U8 - 7,
			RATE_AS_U8,
			RATE_AS_U8 + 5,
			RATE_AS_U8 * 2 - 3,
		];

		for &size in &sizes {
			const N: usize = 2;
			let data_vecs = generate_random_data::<N>(size, 123);
			let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());
			test_parallel_vs_sequential::<N, { N * M }>(data, &format!("size {size}"));
		}
	}
}