binius_hash/multi_digest.rs

// Copyright 2025 Irreducible Inc.

use std::{array, mem::MaybeUninit};

use binius_field::TowerField;
use binius_maybe_rayon::{
	iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
	slice::ParallelSliceMut,
};
use binius_utils::{SerializationMode, SerializeBytes};
use bytes::{BufMut, BytesMut};
use digest::{Digest, Output};

/// An object that efficiently computes `N` instances of a cryptographic hash function
/// in parallel.
///
/// This trait is useful when multiple digests can be computed more efficiently at once
/// than one at a time, e.g. using SIMD instructions. It is intended to be implemented
/// directly for some digest and some fixed `N`, and then wrapped in an implementation of
/// the `ParallelDigest` trait, which hides the `N` value.
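///
/// # Example
///
/// A minimal sketch of driving an arbitrary implementation over `N = 4` equal-length
/// messages; `H` is a placeholder for a concrete implementor, which this trait does not
/// pin down:
///
/// ```ignore
/// fn hash4<H: MultiDigest<4>>(msgs: [&[u8]; 4]) -> [Output<H::Digest>; 4] {
///     let mut out = std::array::from_fn(|_| MaybeUninit::uninit());
///     H::digest(msgs, &mut out);
///     // SAFETY: `digest` is required to initialize every element of `out`.
///     out.map(|digest| unsafe { digest.assume_init() })
/// }
/// ```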
pub trait MultiDigest<const N: usize>: Clone {
	/// The corresponding non-parallelized hash function.
	type Digest: Digest;

	/// Create new hasher instance with empty state.
	fn new() -> Self;

	/// Create new hasher instance which has processed the provided data.
	/// The same prefix is fed to all `N` instances.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		let mut hasher = Self::new();
		hasher.update([data.as_ref(); N]);
		hasher
	}
	/// Process data, updating the internal state.
	/// `data` contains one input slice per hasher instance.
	fn update(&mut self, data: [&[u8]; N]);

	/// Process input data in a chained manner.
	#[must_use]
	fn chain_update(self, data: [&[u8]; N]) -> Self {
		let mut hasher = self;
		hasher.update(data);
		hasher
	}

	/// Write result into provided array and consume the hasher instance.
	fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Write result into provided array and reset the hasher instance.
	fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Reset hasher instance to its initial state.
	fn reset(&mut self);

	/// Compute hash of `data`.
	/// All slices in `data` must have the same length.
	///
	/// # Panics
	/// Panics if `data` contains slices of different lengths.
	fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
}

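/// A value that can be serialized into a byte buffer for hashing.
///
/// A blanket implementation is provided for any iterator of tower field elements, which
/// serializes each element in canonical tower form.
///
/// # Example
///
/// A minimal sketch, assuming `BinaryField8b` from `binius_field` and its `new`
/// constructor (any `TowerField` element type works the same way):
///
/// ```ignore
/// use binius_field::BinaryField8b;
/// use bytes::BytesMut;
///
/// let mut buf = BytesMut::new();
/// (0u8..4).map(BinaryField8b::new).serialize(&mut buf);
/// // Canonical tower serialization of an 8-bit field element takes one byte.
/// assert_eq!(buf.len(), 4);
/// ```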
pub trait Serializable {
	fn serialize(self, buffer: impl BufMut);
}

impl<F: TowerField, I: IntoIterator<Item = F>> Serializable for I {
	fn serialize(self, mut buffer: impl BufMut) {
		let mode = SerializationMode::CanonicalTower;
		for elem in self {
			SerializeBytes::serialize(&elem, &mut buffer, mode)
				.expect("buffer must have enough capacity");
		}
	}
}

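/// A hasher that computes the digests of many serializable items in parallel.
///
/// # Example
///
/// A minimal sketch using the blanket implementation below for ordinary `Digest` types;
/// `Sha256` (from the `sha2` crate) and `MyItem` are placeholders: any
/// `Digest + Send + Sync + Clone` hasher and any item whose reference implements
/// `Serializable` work the same way:
///
/// ```ignore
/// use binius_maybe_rayon::iter::IntoParallelRefIterator;
/// use sha2::Sha256;
///
/// let items: Vec<MyItem> = todo!();
/// let hasher = <Sha256 as ParallelDigest>::new();
/// let mut out: Vec<_> = (0..items.len()).map(|_| MaybeUninit::uninit()).collect();
/// hasher.digest(items.par_iter(), &mut out);
/// let digests: Vec<_> = out.into_iter().map(|d| unsafe { d.assume_init() }).collect();
/// ```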
pub trait ParallelDigest: Send {
	/// The corresponding non-parallelized hash function.
	type Digest: digest::Digest + Send;

	/// Create new hasher instance with empty state.
	fn new() -> Self;

	/// Create new hasher instance which has processed the provided data.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self;

	/// Calculate the digests of multiple serializable items, where each item
	/// serializes to the same number of bytes.
	fn digest(
		&self,
		source: impl IndexedParallelIterator<Item: Serializable>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	);
}

/// A wrapper that implements the `ParallelDigest` trait for a `MultiDigest` implementation.
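///
/// Inputs are processed in chunks of `N`: a trailing chunk shorter than `N` still runs all
/// `N` lanes, and the outputs of the unused lanes are discarded.
///
/// # Example
///
/// A minimal sketch; `MyMultiDigest` is a placeholder for any `MultiDigest<4>`
/// implementation:
///
/// ```ignore
/// type MyParallelDigest = ParallelMulidigestImpl<MyMultiDigest, 4>;
/// // `MyParallelDigest` now exposes the `ParallelDigest` API over 4-way batches.
/// let hasher = MyParallelDigest::new();
/// ```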
#[derive(Clone)]
pub struct ParallelMulidigestImpl<D: MultiDigest<N>, const N: usize>(D);

impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
	for ParallelMulidigestImpl<D, N>
{
	type Digest = D::Digest;

	fn new() -> Self {
		Self(D::new())
	}

	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		Self(D::new_with_prefix(data.as_ref()))
	}

	fn digest(
		&self,
		source: impl IndexedParallelIterator<Item: Serializable>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	) {
		// One scratch buffer per lane; `for_each_with` clones them for each
		// parallel job, and a job reuses its buffers across chunks.
		let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
		source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
			buffers,
			|buffers, (data, out_chunk)| {
				let mut hasher = self.0.clone();
				for (buf, chunk) in buffers.iter_mut().zip(data.into_iter()) {
					buf.clear();
					chunk.serialize(buf);
				}
				let data = array::from_fn(|i| buffers[i].as_ref());
				hasher.update(data);

				if out_chunk.len() == N {
					hasher
						.finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
				} else {
					// Trailing chunk with fewer than `N` items: finalize all `N`
					// lanes and copy out only the digests that were requested.
					let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
					hasher.finalize_into(&mut result);
					for (out, res) in out_chunk.iter_mut().zip(result.into_iter()) {
						out.write(unsafe { res.assume_init() });
					}
				}
			},
		);
	}
}

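/// Any ordinary `Digest` can act as a (non-SIMD) `ParallelDigest`: items are hashed
/// one per parallel task, with a scratch buffer that is reused within each task.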
impl<D: Digest + Send + Sync + Clone> ParallelDigest for D {
	type Digest = D;

	fn new() -> Self {
		Digest::new()
	}

	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		Digest::new_with_prefix(data)
	}

	fn digest(
		&self,
		source: impl IndexedParallelIterator<Item: Serializable>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	) {
		source
			.zip(out.par_iter_mut())
			.for_each_with(BytesMut::new(), |mut buffer, (data, out)| {
				buffer.clear();
				data.serialize(&mut buffer);

				let mut hasher = self.clone();
				hasher.update(buffer.as_ref());
				out.write(hasher.finalize());
			});
	}
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_maybe_rayon::iter::IntoParallelRefIterator;
	use digest::{consts::U32, Digest, FixedOutput, HashMarker, OutputSizeUser, Update};
	use itertools::izip;
	use rand::{rngs::StdRng, RngCore, SeedableRng};

	use super::*;

	#[derive(Clone, Default)]
	struct MockDigest {
		state: u8,
	}

	impl HashMarker for MockDigest {}

	impl Update for MockDigest {
		fn update(&mut self, data: &[u8]) {
			for &byte in data {
				self.state ^= byte;
			}
		}
	}

	impl OutputSizeUser for MockDigest {
		type OutputSize = U32;
	}

	impl FixedOutput for MockDigest {
		fn finalize_into(self, out: &mut digest::Output<Self>) {
			out[0] = self.state;
			for byte in &mut out[1..] {
				*byte = 0;
			}
		}
	}

	#[derive(Clone, Default)]
	struct MockMultiDigest {
		digests: [MockDigest; 4],
	}

	impl MultiDigest<4> for MockMultiDigest {
		type Digest = MockDigest;

		fn new() -> Self {
			Self::default()
		}

		fn update(&mut self, data: [&[u8]; 4]) {
			for (digest, &chunk) in self.digests.iter_mut().zip(data.iter()) {
				digest::Digest::update(digest, chunk);
			}
		}

		fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (digest, out) in self.digests.into_iter().zip(out.iter_mut()) {
				let mut output = digest::Output::<Self::Digest>::default();
				digest::Digest::finalize_into(digest, &mut output);
				*out = MaybeUninit::new(output);
			}
		}

		fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (digest, out) in self.digests.iter_mut().zip(out.iter_mut()) {
				let mut digest_copy = MockDigest::default();
				std::mem::swap(digest, &mut digest_copy);
				*out = MaybeUninit::new(digest_copy.finalize());
			}
			self.reset();
		}

		fn reset(&mut self) {
			for digest in &mut self.digests {
				*digest = MockDigest::default();
			}
		}

		fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			let mut hasher = Self::default();
			hasher.update(data);
			hasher.finalize_into(out);
		}
	}

	struct DataWrapper(Vec<u8>);

	impl Serializable for &DataWrapper {
		fn serialize(self, mut buffer: impl BufMut) {
			buffer.put_slice(&self.0);
		}
	}

	fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<DataWrapper> {
		let mut rng = StdRng::seed_from_u64(0);

		(0..n_hashes)
			.map(|_| {
				let mut chunk = vec![0; chunk_size];
				rng.fill_bytes(&mut chunk);

				DataWrapper(chunk)
			})
			.collect()
	}

	fn check_parallel_digest_consistency<D: ParallelDigest<Digest: Send + Sync + Clone>>(
		data: Vec<DataWrapper>,
	) {
		let parallel_digest = D::new();
		let mut parallel_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		parallel_digest.digest(data.par_iter(), &mut parallel_results);

		let single_digest_as_parallel = <D::Digest as ParallelDigest>::new();
		let mut single_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		single_digest_as_parallel.digest(data.par_iter(), &mut single_results);

		let serial_results = data
			.iter()
			.map(|data| <D::Digest as digest::Digest>::digest(&data.0));

		for (parallel, single, serial) in izip!(parallel_results, single_results, serial_results) {
			assert_eq!(unsafe { parallel.assume_init() }, serial);
			assert_eq!(unsafe { single.assume_init() }, serial);
		}
	}

	#[test]
	fn test_empty_data() {
		let data = generate_mock_data(0, 16);
		check_parallel_digest_consistency::<ParallelMulidigestImpl<MockMultiDigest, 4>>(data);
	}

	#[test]
	fn test_non_empty_data() {
		for n_hashes in [1, 2, 4, 8, 9] {
			let data = generate_mock_data(n_hashes, 16);
			check_parallel_digest_consistency::<ParallelMulidigestImpl<MockMultiDigest, 4>>(data);
		}
	}
}