Skip to main content

binius_hash/
parallel_digest.rs

1// Copyright 2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4use std::{array, marker::PhantomData, mem::MaybeUninit};
5
6use binius_utils::{
7	FixedSizeSerializeBytes, SerializeBytes,
8	rayon::{
9		iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
10		slice::ParallelSliceMut,
11	},
12};
13use bytes::BytesMut;
14use digest::{Digest, FixedOutputReset, Output, block_api::BlockSizeUser};
15
16use crate::HashBuffer;
17
18/// An object that efficiently computes `N` instances of a cryptographic hash function
19/// in parallel.
20///
21/// This trait is useful when there is a more efficient way of calculating multiple digests at once,
22/// e.g. using SIMD instructions. It is supposed that this trait is implemented directly for some
23/// digest and some fixed `N` and passed as an implementation of the `ParallelDigest` trait which
24/// hides the `N` value.
25pub trait MultiDigest<const N: usize>: Clone {
26	/// The corresponding non-parallelized hash function.
27	type Digest: Digest;
28
29	/// Create new hasher instance with empty state.
30	fn new() -> Self;
31
32	/// Create new hasher instance which has processed the provided data.
33	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
34		let mut hasher = Self::new();
35		hasher.update([data.as_ref(); N]);
36		hasher
37	}
38
39	/// Process data, updating the internal state.
40	/// The number of rows in `data` must be equal to `parallel_instances()`.
41	fn update(&mut self, data: [&[u8]; N]);
42
43	/// Process input data in a chained manner.
44	#[must_use]
45	fn chain_update(self, data: [&[u8]; N]) -> Self {
46		let mut hasher = self;
47		hasher.update(data);
48		hasher
49	}
50
51	/// Write result into provided array and consume the hasher instance.
52	fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
53
54	/// Write result into provided array and reset the hasher instance.
55	fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
56
57	/// Reset hasher instance to its initial state.
58	fn reset(&mut self);
59
60	/// Compute hash of `data`.
61	/// All slices in the `data` must have the same length.
62	///
63	/// # Panics
64	/// Panics if data contains slices of different lengths.
65	fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
66}
67
68pub trait ParallelDigest: Send {
69	/// The corresponding non-parallelized hash function.
70	type Digest: Digest;
71
72	/// Create new hasher instance with empty state.
73	fn new() -> Self;
74
75	/// Calculate the digest of multiple hashes by processing a parallel iterator of iterators.
76	///
77	/// The source parameter provides a parallel iterator where:
78	/// - Each element of the outer iterator maps to one leaf/digest in the output
79	/// - Each element contains an inner iterator of items that will be serialized and concatenated
80	///   to form that leaf's content
81	///
82	/// # Panics
83	/// All items must be able to serialize with SerializationMode::Native without error, or this
84	/// method will panic.
85	fn digest<I: IntoIterator<Item: SerializeBytes>>(
86		&self,
87		source: impl IndexedParallelIterator<Item = I>,
88		out: &mut [MaybeUninit<Output<Self::Digest>>],
89	);
90
91	/// Like [`digest`](Self::digest), but specialized for the case where every leaf is built from
92	/// exactly `n_items_per_input` items of a [`FixedSizeSerializeBytes`] type, so that each leaf
93	/// has the same, compile-time-derivable byte length.
94	///
95	/// This extra structure lets implementations skip per-leaf length bookkeeping (and, for short
96	/// leaves, the message padding) that [`digest`](Self::digest) must redo every time. The default
97	/// implementation simply forwards to [`digest`](Self::digest).
98	///
99	/// # Panics
100	/// Each iterator in `source` must yield exactly `n_items_per_input` items, and all items must
101	/// serialize without error, or this method may panic.
102	fn digest_with_const_len<I: IntoIterator<Item: FixedSizeSerializeBytes>>(
103		&self,
104		n_items_per_input: usize,
105		source: impl IndexedParallelIterator<Item = I>,
106		out: &mut [MaybeUninit<Output<Self::Digest>>],
107	) {
108		let _ = n_items_per_input;
109		self.digest(source, out);
110	}
111}
112
113/// A wrapper that implements the `ParallelDigest` trait for a `MultiDigest` implementation.
114#[derive(Clone)]
115pub struct ParallelMultidigestImpl<D: MultiDigest<N>, const N: usize>(D);
116
117impl<D: MultiDigest<N> + Default, const N: usize> Default for ParallelMultidigestImpl<D, N> {
118	fn default() -> Self {
119		Self(D::default())
120	}
121}
122
123impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
124	for ParallelMultidigestImpl<D, N>
125{
126	type Digest = D::Digest;
127
128	fn new() -> Self {
129		Self(D::new())
130	}
131
132	fn digest<I: IntoIterator<Item: SerializeBytes>>(
133		&self,
134		source: impl IndexedParallelIterator<Item = I>,
135		out: &mut [MaybeUninit<Output<Self::Digest>>],
136	) {
137		let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
138		source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
139			buffers,
140			|buffers, (data, out_chunk)| {
141				let mut hasher = self.0.clone();
142				for (mut buf, chunk) in buffers.iter_mut().zip(data) {
143					buf.clear();
144					for item in chunk {
145						item.serialize(&mut buf)
146							.expect("pre-condition: items must serialize without error")
147					}
148				}
149				let data = array::from_fn(|i| buffers[i].as_ref());
150				hasher.update(data);
151
152				if out_chunk.len() == N {
153					hasher
154						.finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
155				} else {
156					let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
157					hasher.finalize_into(&mut result);
158					for (out, res) in out_chunk.iter_mut().zip(result) {
159						out.write(unsafe { res.assume_init() });
160					}
161				}
162			},
163		);
164	}
165}
166
/// Bridges a sequential [`Digest`] to the [`ParallelDigest`] interface, hashing one leaf per
/// element of a parallel iterator.
///
/// A single hasher is handed to each Rayon work-item (via `for_each_with`) and recycled in place
/// with `finalize_reset` from one leaf to the next, instead of cloning a fresh hasher for every
/// leaf. This is why `D: FixedOutputReset` is required.
pub struct ParallelDigestAdapter<D>(PhantomData<D>);

impl<D> Default for ParallelDigestAdapter<D> {
	/// Build the adapter; it stores only a type marker, so no bound on `D` is needed.
	fn default() -> Self {
		Self(PhantomData::<D>)
	}
}
180
181impl<D> ParallelDigest for ParallelDigestAdapter<D>
182where
183	D: Digest + FixedOutputReset + BlockSizeUser + Send + Sync + Clone,
184{
185	type Digest = D;
186
187	fn new() -> Self {
188		Self(PhantomData)
189	}
190
191	fn digest<I: IntoIterator<Item: SerializeBytes>>(
192		&self,
193		source: impl IndexedParallelIterator<Item = I>,
194		out: &mut [MaybeUninit<Output<Self::Digest>>],
195	) {
196		source
197			.zip(out.par_iter_mut())
198			.for_each_with(D::new(), |hasher, (items, out)| {
199				{
200					let mut buffer = HashBuffer::new(hasher);
201					for item in items {
202						item.serialize(&mut buffer)
203							.expect("pre-condition: items must serialize without error")
204					}
205				}
206				out.write(hasher.finalize_reset());
207			});
208	}
209}
210
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_utils::rayon::iter::IntoParallelRefIterator;
	use digest::{
		FixedOutput, HashMarker, OutputSizeUser, Reset, Update,
		consts::{U1, U32},
	};
	use itertools::izip;
	use rand::prelude::*;

	use super::*;

	/// Toy digest whose state is the XOR of every byte fed into it.
	#[derive(Clone, Default)]
	struct MockDigest {
		state: u8,
	}

	impl HashMarker for MockDigest {}

	impl Update for MockDigest {
		fn update(&mut self, data: &[u8]) {
			self.state = data.iter().fold(self.state, |acc, &byte| acc ^ byte);
		}
	}

	impl Reset for MockDigest {
		fn reset(&mut self) {
			self.state = 0;
		}
	}

	impl OutputSizeUser for MockDigest {
		type OutputSize = U32;
	}

	impl BlockSizeUser for MockDigest {
		type BlockSize = U1;
	}

	impl FixedOutput for MockDigest {
		/// The output is the state byte followed by zeros.
		fn finalize_into(self, out: &mut Output<Self>) {
			out[0] = self.state;
			out[1..].fill(0);
		}
	}

	/// Four independent [`MockDigest`] lanes behind the [`MultiDigest`] interface.
	#[derive(Clone, Default)]
	struct MockMultiDigest {
		digests: [MockDigest; 4],
	}

	impl MultiDigest<4> for MockMultiDigest {
		type Digest = MockDigest;

		fn new() -> Self {
			Self::default()
		}

		fn update(&mut self, data: [&[u8]; 4]) {
			for (lane, bytes) in self.digests.iter_mut().zip(data) {
				Digest::update(lane, bytes);
			}
		}

		fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (slot, lane) in out.iter_mut().zip(self.digests) {
				slot.write(lane.finalize());
			}
		}

		fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (slot, lane) in out.iter_mut().zip(self.digests.iter_mut()) {
				// Move the lane out, leaving a default-initialized one behind.
				let finished = std::mem::take(lane);
				slot.write(finished.finalize());
			}
			self.reset();
		}

		fn reset(&mut self) {
			self.digests = Default::default();
		}

		fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			let mut lanes = Self::default();
			lanes.update(data);
			lanes.finalize_into(out);
		}
	}

	/// Produce `n_hashes` pseudo-random leaves of `chunk_size` bytes each from a fixed seed.
	fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<Vec<u8>> {
		let mut rng = StdRng::seed_from_u64(0);

		repeat_with(|| {
			let mut leaf = vec![0; chunk_size];
			rng.fill_bytes(&mut leaf);
			leaf
		})
		.take(n_hashes)
		.collect()
	}

	/// Assert that the parallel digest `D` yields, for every leaf, the same output as hashing
	/// that leaf with the sequential `D::Digest`.
	fn check_parallel_digest_consistency<
		D: ParallelDigest<Digest: BlockSizeUser + Send + Sync + Clone>,
	>(
		data: Vec<Vec<u8>>,
	) {
		let mut parallel_results: Vec<MaybeUninit<Output<D::Digest>>> =
			repeat_with(MaybeUninit::uninit).take(data.len()).collect();
		let hasher = D::new();
		hasher.digest(data.par_iter(), &mut parallel_results);

		let serial_results = data.iter().map(<D::Digest as Digest>::digest);

		for (parallel, serial) in izip!(parallel_results, serial_results) {
			// SAFETY: relies on `ParallelDigest::digest` having written one output per input
			// leaf, and `parallel_results` has exactly `data.len()` slots.
			let parallel = unsafe { parallel.assume_init() };
			assert_eq!(parallel, serial);
		}
	}

	#[test]
	fn test_empty_data() {
		let no_leaves = generate_mock_data(0, 16);
		check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(
			no_leaves,
		);
	}

	#[test]
	fn test_non_empty_data() {
		// Leaf counts below, equal to and above the lane count, including non-multiples of 4.
		[1, 2, 4, 8, 9].into_iter().for_each(|n_hashes| {
			let leaves = generate_mock_data(n_hashes, 16);
			check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(
				leaves,
			);
		});
	}

	#[test]
	fn test_adapter_matches_serial_sha256() {
		use sha2::Sha256;

		for n_hashes in [0, 1, 2, 4, 8, 9, 100] {
			let leaves = generate_mock_data(n_hashes, 16);
			// Same comparison as for the multi-digest wrapper: adapter output against
			// `<Sha256 as Digest>::digest` of every leaf.
			check_parallel_digest_consistency::<ParallelDigestAdapter<Sha256>>(leaves);
		}
	}
}