// binius_hash/vision.rs

1// Copyright 2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4use std::{array, mem::MaybeUninit};
5
6use binius_field::{BinaryField128bGhash as Ghash, Field};
7use binius_utils::SerializeBytes;
8use digest::Output;
9use itertools::izip;
10
11use crate::{
12	parallel_digest::MultiDigest,
13	vision_4::{
14		M,
15		digest::{PADDING_BLOCK, RATE_AS_U8, VisionHasherDigest, fill_padding},
16	},
17};
18
/// A Vision hasher suited for parallelization.
///
/// Without using packed fields, there is only one advantage of an explicit parallelized
/// Vision hasher over invoking the Vision hasher multiple times in parallel:
/// we can amortize the cost of inversion in the sbox using
/// [Montgomery's trick](https://medium.com/eryxcoop/montgomerys-trick-for-batch-galois-field-inversion-9b6d0f399da2).
/// We slightly modify Montgomery's trick to use a binary tree structure,
/// maximizing independence of multiplications for better instruction pipelining.
#[derive(Clone)]
pub struct VisionHasherMultiDigest<const N: usize> {
	// One sponge state per parallel hasher lane.
	states: [[Ghash; M]; N],
	// Per-lane byte buffer that accumulates input until a full rate block is available.
	buffers: [[u8; RATE_AS_U8]; N],
	// Number of bytes currently buffered; all N buffers are filled in lockstep,
	// so a single count suffices.
	filled_bytes: usize,
}
33
34impl<const N: usize> Default for VisionHasherMultiDigest<N> {
35	fn default() -> Self {
36		Self {
37			states: array::from_fn(|_| [Ghash::ZERO; M]),
38			buffers: array::from_fn(|_| [0; RATE_AS_U8]),
39			filled_bytes: 0,
40		}
41	}
42}
43
44impl<const N: usize> VisionHasherMultiDigest<N> {
45	#[inline]
46	fn advance_data(data: &mut [&[u8]; N], bytes: usize) {
47		for i in 0..N {
48			data[i] = &data[i][bytes..];
49		}
50	}
51
52	fn permute(states: &mut [[Ghash; M]; N], data: [&[u8]; N]) {
53		// Trivial independence for now. Optimization with amortization later.
54		izip!(states, data).for_each(|(state, data)| {
55			VisionHasherDigest::permute(state, data);
56		});
57	}
58	fn finalize(&mut self, out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>; N]) {
59		if self.filled_bytes != 0 {
60			for i in 0..N {
61				fill_padding(&mut self.buffers[i][self.filled_bytes..]);
62			}
63			Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
64		} else {
65			Self::permute(&mut self.states, array::from_fn(|_| &PADDING_BLOCK[..]));
66		}
67
68		// Serialize first two state elements for each digest (32 bytes total per digest)
69		for i in 0..N {
70			let output_slice = out[i].as_mut_ptr() as *mut u8;
71			let output_bytes = unsafe { std::slice::from_raw_parts_mut(output_slice, 32) };
72			let (state0, state1) = output_bytes.split_at_mut(16);
73			self.states[i][0]
74				.serialize(state0)
75				.expect("fits in 16 bytes");
76			self.states[i][1]
77				.serialize(state1)
78				.expect("fits in 16 bytes");
79		}
80	}
81}
82
83impl<const N: usize> MultiDigest<N> for VisionHasherMultiDigest<N> {
84	type Digest = VisionHasherDigest;
85
86	fn new() -> Self {
87		Self::default()
88	}
89
90	fn update(&mut self, mut data: [&[u8]; N]) {
91		data[1..].iter().for_each(|row| {
92			assert_eq!(row.len(), data[0].len());
93		});
94
95		if self.filled_bytes != 0 {
96			let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
97			data.iter().enumerate().for_each(|(row_i, row)| {
98				self.buffers[row_i][self.filled_bytes..self.filled_bytes + to_copy]
99					.copy_from_slice(&row[..to_copy]);
100			});
101			Self::advance_data(&mut data, to_copy);
102			self.filled_bytes += to_copy;
103
104			if self.filled_bytes == RATE_AS_U8 {
105				Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
106				self.filled_bytes = 0;
107			}
108		}
109
110		while data[0].len() >= RATE_AS_U8 {
111			let chunks = array::from_fn(|i| &data[i][..RATE_AS_U8]);
112			Self::permute(&mut self.states, chunks);
113			Self::advance_data(&mut data, RATE_AS_U8);
114		}
115
116		if !data[0].is_empty() {
117			data.iter().enumerate().for_each(|(row_i, row)| {
118				self.buffers[row_i][..row.len()].copy_from_slice(row);
119			});
120			self.filled_bytes = data[0].len();
121		}
122	}
123
124	fn finalize_into(mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
125		self.finalize(out);
126	}
127
128	fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
129		self.finalize(out);
130		self.reset();
131	}
132
133	fn reset(&mut self) {
134		self.states = array::from_fn(|_| [Ghash::ZERO; M]);
135		self.buffers = array::from_fn(|_| [0; RATE_AS_U8]);
136		self.filled_bytes = 0;
137	}
138
139	fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
140		let mut digest = Self::default();
141		digest.update(data);
142		digest.finalize_into(out);
143	}
144}
145
146#[cfg(test)]
147mod tests {
148	use std::mem::MaybeUninit;
149
150	use digest::Digest;
151	use rand::{Rng, SeedableRng, rngs::StdRng};
152
153	use super::*;
154
155	// Helper function to generate random data vectors
156	fn generate_random_data<const N: usize>(length: usize, seed: u64) -> Vec<Vec<u8>> {
157		let mut rng = StdRng::seed_from_u64(seed);
158		let mut data_vecs = Vec::new();
159		for _ in 0..N {
160			let mut vec = Vec::with_capacity(length);
161			for _ in 0..length {
162				vec.push(rng.random());
163			}
164			data_vecs.push(vec);
165		}
166		data_vecs
167	}
168
169	// Generic test function that compares parallel vs sequential execution
170	fn test_parallel_vs_sequential<const N: usize>(data: [&[u8]; N], description: &str) {
171		// Parallel computation
172		let mut parallel_outputs = [MaybeUninit::uninit(); N];
173		VisionHasherMultiDigest::<N>::digest(data, &mut parallel_outputs);
174		let parallel_results: [Output<VisionHasherDigest>; N] =
175			array::from_fn(|i| unsafe { parallel_outputs[i].assume_init() });
176
177		// Sequential computation
178		let sequential_results: [_; N] = array::from_fn(|i| {
179			let mut hasher = VisionHasherDigest::new();
180			hasher.update(data[i]);
181			hasher.finalize()
182		});
183
184		// Compare results
185		for i in 0..N {
186			assert_eq!(
187				parallel_results[i], sequential_results[i],
188				"Mismatch at index {i} for {description}"
189			);
190		}
191	}
192
193	#[test]
194	fn test_empty_inputs() {
195		const N: usize = 4;
196		let data: [&[u8]; N] = [&[], &[], &[], &[]];
197		test_parallel_vs_sequential(data, "empty inputs");
198	}
199
200	#[test]
201	fn test_small_inputs() {
202		const N: usize = 3;
203		let data: [&[u8]; N] = [b"Hello... World!", b"Rust is awesome", b"Parallel hash!!"];
204		test_parallel_vs_sequential(data, "small inputs");
205	}
206
207	#[test]
208	fn test_multi_block() {
209		const N: usize = 4;
210		// Multiple blocks with different random patterns
211		let target_len = RATE_AS_U8 * 2 + 10;
212		let data_vecs = generate_random_data::<N>(target_len, 42);
213		let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());
214
215		test_parallel_vs_sequential(data, "multi-block inputs");
216	}
217
218	#[test]
219	fn test_various_sizes() {
220		// Test different sizes separately since parallel requires same length per batch
221		let sizes = [
222			1,
223			RATE_AS_U8 - 7,
224			RATE_AS_U8,
225			RATE_AS_U8 + 5,
226			RATE_AS_U8 * 2 - 3,
227		];
228
229		for &size in &sizes {
230			const N: usize = 3;
231			let data_vecs = generate_random_data::<N>(size, 123);
232			let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());
233			test_parallel_vs_sequential(data, &format!("size {size}"));
234		}
235	}
236}