binius_hash/parallel_digest.rs

// Copyright 2025 Irreducible Inc.
// Copyright 2026 The Binius Developers

use std::{array, mem::MaybeUninit};

use binius_utils::{
	SerializeBytes,
	rayon::{
		iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
		slice::ParallelSliceMut,
	},
};
use bytes::BytesMut;
use digest::{Digest, Output, core_api::BlockSizeUser};

use crate::HashBuffer;

/// An object that efficiently computes `N` instances of a cryptographic hash function
/// in parallel.
///
/// This trait is useful when multiple digests can be computed more efficiently at once
/// than one at a time, e.g. using SIMD instructions. It is intended to be implemented
/// directly for a specific digest and a fixed `N`, then wrapped (via
/// [`ParallelMultidigestImpl`]) into an implementation of the [`ParallelDigest`] trait,
/// which hides the `N` value.
pub trait MultiDigest<const N: usize>: Clone {
	/// The corresponding non-parallelized hash function.
	type Digest: Digest;

	/// Create new hasher instance with empty state.
	fn new() -> Self;

	/// Create new hasher instance which has processed the provided data.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		let mut hasher = Self::new();
		hasher.update([data.as_ref(); N]);
		hasher
	}

	/// Process data, updating the internal state.
	/// The `i`-th slice in `data` is absorbed by the `i`-th of the `N` parallel hash
	/// instances.
	fn update(&mut self, data: [&[u8]; N]);

	/// Process input data in a chained manner.
	#[must_use]
	fn chain_update(self, data: [&[u8]; N]) -> Self {
		let mut hasher = self;
		hasher.update(data);
		hasher
	}

	/// Write the results into the provided array and consume the hasher instance.
	fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Write the results into the provided array and reset the hasher instance.
	fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Reset the hasher instance to its initial state.
	fn reset(&mut self);

	/// Compute the hash of `data`.
	/// All slices in `data` must have the same length.
	///
	/// # Panics
	/// Panics if `data` contains slices of different lengths.
	fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
}
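
// A minimal usage sketch for `MultiDigest`, assuming a 4-way implementation such as
// the test-only `MockMultiDigest` below (a real implementation would typically use
// SIMD lanes instead):
//
//     let data: [&[u8]; 4] = [b"a", b"b", b"c", b"d"];
//     let mut out = array::from_fn::<_, 4, _>(|_| MaybeUninit::uninit());
//     MockMultiDigest::digest(data, &mut out);
//     // Each lane of `out` now holds the digest of the corresponding input slice.
//     let digests = out.map(|d| unsafe { d.assume_init() });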

/// An object that computes multiple digests of a cryptographic hash function in
/// parallel, hiding the degree of parallelism from the caller.
pub trait ParallelDigest: Send {
	/// The corresponding non-parallelized hash function.
	type Digest: Digest + Send;

	/// Create new hasher instance with empty state.
	fn new() -> Self;

	/// Create new hasher instance which has processed the provided data.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self;

	/// Calculate the digests of multiple leaves by processing a parallel iterator of iterators.
	///
	/// The `source` parameter provides a parallel iterator where:
	/// - Each element of the outer iterator maps to one leaf/digest in the output
	/// - Each element contains an inner iterator of items that are serialized and concatenated
	///   to form that leaf's content
	///
	/// # Panics
	/// Panics if any item fails to serialize with `SerializationMode::Native`.
	fn digest<I: IntoIterator<Item: SerializeBytes>>(
		&self,
		source: impl IndexedParallelIterator<Item = I>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	);
}
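
// Sketch of a caller, mirroring the tests below: `leaves` is a hypothetical
// `Vec<Vec<u8>>` and `H` stands for any `ParallelDigest` implementation. Any
// `IndexedParallelIterator` whose items implement `IntoIterator<Item: SerializeBytes>`
// works as the source:
//
//     let hasher = <H as ParallelDigest>::new();
//     let mut out = Vec::new();
//     out.resize_with(leaves.len(), MaybeUninit::uninit);
//     hasher.digest(leaves.par_iter(), &mut out);
//     let digests: Vec<_> = out.into_iter().map(|d| unsafe { d.assume_init() }).collect();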

/// A wrapper that implements the `ParallelDigest` trait for a `MultiDigest` implementation.
#[derive(Clone)]
pub struct ParallelMultidigestImpl<D: MultiDigest<N>, const N: usize>(D);

impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
	for ParallelMultidigestImpl<D, N>
{
	type Digest = D::Digest;

	fn new() -> Self {
		Self(D::new())
	}

	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		Self(D::new_with_prefix(data.as_ref()))
	}

	fn digest<I: IntoIterator<Item: SerializeBytes>>(
		&self,
		source: impl IndexedParallelIterator<Item = I>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	) {
		let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
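		// Process leaves in groups of `N`: each group's items are serialized into `N`
		// reusable scratch buffers, which one multi-digest instance then absorbs.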
		source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
			buffers,
			|buffers, (data, out_chunk)| {
				let mut hasher = self.0.clone();
				for (mut buf, chunk) in buffers.iter_mut().zip(data.into_iter()) {
					buf.clear();
					for item in chunk {
						item.serialize(&mut buf)
							.expect("pre-condition: items must serialize without error")
					}
				}
				let data = array::from_fn(|i| buffers[i].as_ref());
				hasher.update(data);

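				// A full chunk finalizes all `N` lanes directly into the output slice;
				// the final chunk may be partial, so finalize into scratch space and
				// copy out only the lanes that have a corresponding output slot.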
				if out_chunk.len() == N {
					hasher
						.finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
				} else {
					let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
					hasher.finalize_into(&mut result);
					for (out, res) in out_chunk.iter_mut().zip(result.into_iter()) {
						out.write(unsafe { res.assume_init() });
					}
				}
			},
		);
	}
}
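
// Usage sketch: wrapping a 4-way multi-digest so callers only see `ParallelDigest`
// (here with the test-only `MockMultiDigest` as a stand-in):
//
//     let hasher = ParallelMultidigestImpl::<MockMultiDigest, 4>::new();
//     // `hasher.digest(...)` accepts any number of leaves, not only multiples of 4.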
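// Fallback: any serial `Digest` is itself a `ParallelDigest` hashing one leaf per
// instance, so parallelism comes only from rayon distributing leaves across threads.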
impl<D: Digest + BlockSizeUser + Send + Sync + Clone> ParallelDigest for D {
	type Digest = D;

	fn new() -> Self {
		Digest::new()
	}

	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		Digest::new_with_prefix(data)
	}

	fn digest<I: IntoIterator<Item: SerializeBytes>>(
		&self,
		source: impl IndexedParallelIterator<Item = I>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	) {
		source.zip(out.par_iter_mut()).for_each(|(items, out)| {
			let mut hasher = self.clone();
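			// Scope the `HashBuffer` so it is dropped, flushing any buffered bytes into
			// the hasher, before `finalize` is called below.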
			{
				let mut buffer = HashBuffer::new(&mut hasher);
				for item in items {
					item.serialize(&mut buffer)
						.expect("pre-condition: items must serialize without error")
				}
			}
			out.write(hasher.finalize());
		});
	}
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_utils::rayon::iter::IntoParallelRefIterator;
	use digest::{
		FixedOutput, HashMarker, OutputSizeUser, Reset, Update,
		consts::{U1, U32},
	};
	use itertools::izip;
	use rand::{RngCore, SeedableRng, rngs::StdRng};

	use super::*;

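	// A toy digest for testing: the "hash" is the XOR of all input bytes, stored in
	// the first output byte; the remaining 31 bytes are zero.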
	#[derive(Clone, Default)]
	struct MockDigest {
		state: u8,
	}

	impl HashMarker for MockDigest {}

	impl Update for MockDigest {
		fn update(&mut self, data: &[u8]) {
			for &byte in data {
				self.state ^= byte;
			}
		}
	}

	impl Reset for MockDigest {
		fn reset(&mut self) {
			self.state = 0;
		}
	}

	impl OutputSizeUser for MockDigest {
		type OutputSize = U32;
	}

	impl BlockSizeUser for MockDigest {
		type BlockSize = U1;
	}

	impl FixedOutput for MockDigest {
		fn finalize_into(self, out: &mut Output<Self>) {
			out[0] = self.state;
			for byte in &mut out[1..] {
				*byte = 0;
			}
		}
	}

	#[derive(Clone, Default)]
	struct MockMultiDigest {
		digests: [MockDigest; 4],
	}

	impl MultiDigest<4> for MockMultiDigest {
		type Digest = MockDigest;

		fn new() -> Self {
			Self::default()
		}

		fn update(&mut self, data: [&[u8]; 4]) {
			for (digest, &chunk) in self.digests.iter_mut().zip(data.iter()) {
				digest::Digest::update(digest, chunk);
			}
		}

		fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (digest, out) in self.digests.into_iter().zip(out.iter_mut()) {
				let mut output = digest::Output::<Self::Digest>::default();
				digest::Digest::finalize_into(digest, &mut output);
				*out = MaybeUninit::new(output);
			}
		}

		fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
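			// `Digest::finalize` consumes its receiver, so swap each lane out with a
			// fresh default before finalizing; this also leaves the lane reset.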
			for (digest, out) in self.digests.iter_mut().zip(out.iter_mut()) {
				let mut digest_copy = MockDigest::default();
				std::mem::swap(digest, &mut digest_copy);
				*out = MaybeUninit::new(digest_copy.finalize());
			}
			self.reset();
		}

		fn reset(&mut self) {
			for digest in &mut self.digests {
				*digest = MockDigest::default();
			}
		}

		fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			let mut hasher = Self::default();
			hasher.update(data);
			hasher.finalize_into(out);
		}
	}

	fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<Vec<u8>> {
		let mut rng = StdRng::seed_from_u64(0);

		(0..n_hashes)
			.map(|_| {
				let mut chunk = vec![0; chunk_size];
				rng.fill_bytes(&mut chunk);
				chunk
			})
			.collect()
	}

	fn check_parallel_digest_consistency<
		D: ParallelDigest<Digest: BlockSizeUser + Send + Sync + Clone>,
	>(
		data: Vec<Vec<u8>>,
	) {
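		// Hash the same inputs three ways: the parallel implementation under test, its
		// serial digest driven through the blanket `ParallelDigest` impl, and plain
		// serial `Digest::digest`. All three must agree.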
		let parallel_digest = D::new();
		let mut parallel_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		parallel_digest.digest(data.par_iter(), &mut parallel_results);

		let single_digest_as_parallel = <D::Digest as ParallelDigest>::new();
		let mut single_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		single_digest_as_parallel.digest(data.par_iter(), &mut single_results);

		let serial_results = data.iter().map(<D::Digest as Digest>::digest);

		for (parallel, single, serial) in izip!(parallel_results, single_results, serial_results) {
			assert_eq!(unsafe { parallel.assume_init() }, serial);
			assert_eq!(unsafe { single.assume_init() }, serial);
		}
	}

	#[test]
	fn test_empty_data() {
		let data = generate_mock_data(0, 16);
		check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
	}

	#[test]
	fn test_non_empty_data() {
		for n_hashes in [1, 2, 4, 8, 9] {
			let data = generate_mock_data(n_hashes, 16);
			check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
		}
	}
}
327}