binius_hash/multi_digest.rs

// Copyright 2025 Irreducible Inc.

use std::{array, mem::MaybeUninit};

use binius_field::TowerField;
use binius_maybe_rayon::{
	iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
	slice::ParallelSliceMut,
};
use binius_utils::{SerializationMode, SerializeBytes};
use bytes::{BufMut, BytesMut};
use digest::{Digest, Output, core_api::BlockSizeUser};

use crate::HashBuffer;

/// An object that efficiently computes `N` instances of a cryptographic hash function
/// in parallel.
///
/// This trait is useful when there is a more efficient way of computing multiple digests at once,
/// e.g. using SIMD instructions. The intended pattern is to implement this trait directly for some
/// digest and some fixed `N`, and then wrap it as an implementation of the [`ParallelDigest`]
/// trait, which hides the `N` value.
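///
/// # Example
///
/// A minimal usage sketch, assuming a hypothetical 4-way implementation `MyMultiDigest4`
/// (not a type provided by this crate):
///
/// ```ignore
/// let mut hasher = MyMultiDigest4::new();
/// // One input slice per lane; all four slices must have the same length.
/// hasher.update([a, b, c, d]);
/// let mut out = std::array::from_fn(|_| MaybeUninit::uninit());
/// hasher.finalize_into(&mut out);
/// ```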
pub trait MultiDigest<const N: usize>: Clone {
	/// The corresponding non-parallelized hash function.
	type Digest: Digest;

	/// Create new hasher instance with empty state.
	fn new() -> Self;

	/// Create new hasher instance which has processed the provided data.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		let mut hasher = Self::new();
		hasher.update([data.as_ref(); N]);
		hasher
	}

	/// Process data, updating the internal state.
	/// Each of the `N` slices is the next chunk of input for the corresponding hash instance.
	fn update(&mut self, data: [&[u8]; N]);

	/// Process input data in a chained manner.
	#[must_use]
	fn chain_update(self, data: [&[u8]; N]) -> Self {
		let mut hasher = self;
		hasher.update(data);
		hasher
	}

	/// Write result into provided array and consume the hasher instance.
	fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Write result into provided array and reset the hasher instance.
	fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Reset hasher instance to its initial state.
	fn reset(&mut self);

	/// Compute the hash of `data`.
	/// All slices in `data` must have the same length.
	///
	/// # Panics
	/// Panics if `data` contains slices of different lengths.
	fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
}

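/// An object that can serialize itself into a byte buffer.
///
/// A usage sketch of the blanket impl for field-element iterators below, assuming some
/// `TowerField` type `F` and `elems: Vec<F>`:
///
/// ```ignore
/// let mut buf = BytesMut::new();
/// elems.serialize(&mut buf); // writes each element in canonical tower serialization
/// ```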
pub trait Serializable {
	fn serialize(self, buffer: impl BufMut);
}

impl<F: TowerField, I: IntoIterator<Item = F>> Serializable for I {
	fn serialize(self, mut buffer: impl BufMut) {
		let mode = SerializationMode::CanonicalTower;
		for elem in self {
			SerializeBytes::serialize(&elem, &mut buffer, mode)
				.expect("buffer must have enough capacity");
		}
	}
}

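/// An object that computes many instances of a cryptographic hash function in parallel.
///
/// Unlike [`MultiDigest`], the degree of parallelism is an implementation detail hidden
/// from the caller.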
pub trait ParallelDigest: Send {
	/// The corresponding non-parallelized hash function.
	type Digest: digest::Digest + Send;

	/// Create new hasher instance with empty state.
	fn new() -> Self;

	/// Create new hasher instance which has processed the provided data.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self;

	/// Calculate the digests of multiple items, each of which is serialized into
	/// the same number of bytes.
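	///
	/// `out` must have exactly one slot per source item. A calling sketch, assuming
	/// `hasher` is any `ParallelDigest` and `items` is an `IndexedParallelIterator`
	/// over `Serializable` items:
	///
	/// ```ignore
	/// let mut out = Vec::new();
	/// out.resize_with(items.len(), MaybeUninit::uninit);
	/// hasher.digest(items, &mut out);
	/// // Every slot is initialized once `digest` returns.
	/// let digests: Vec<_> = out.into_iter().map(|d| unsafe { d.assume_init() }).collect();
	/// ```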
	fn digest(
		&self,
		source: impl IndexedParallelIterator<Item: Serializable>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	);
}

/// A wrapper that implements the `ParallelDigest` trait for a `MultiDigest` implementation.
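///
/// A construction sketch, assuming a hypothetical `MyMultiDigest4: MultiDigest<4>`:
///
/// ```ignore
/// let hasher = ParallelMultidigestImpl::<MyMultiDigest4, 4>::new();
/// ```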
#[derive(Clone)]
pub struct ParallelMultidigestImpl<D: MultiDigest<N>, const N: usize>(D);

impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
	for ParallelMultidigestImpl<D, N>
{
	type Digest = D::Digest;

	fn new() -> Self {
		Self(D::new())
	}

	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		Self(D::new_with_prefix(data.as_ref()))
	}

	fn digest(
		&self,
		source: impl IndexedParallelIterator<Item: Serializable>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	) {
		let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
		source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
			buffers,
			|buffers, (data, out_chunk)| {
				let mut hasher = self.0.clone();
				// Serialize each item in the chunk into its own reusable buffer.
				for (buf, chunk) in buffers.iter_mut().zip(data.into_iter()) {
					buf.clear();
					chunk.serialize(buf);
				}
				let data = array::from_fn(|i| buffers[i].as_ref());
				hasher.update(data);

				if out_chunk.len() == N {
					hasher
						.finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
				} else {
					// Tail chunk with fewer than `N` items: finalize all `N` lanes into a
					// scratch array and copy out only the lanes that have an output slot.
					// Lanes past `out_chunk.len()` may have hashed stale buffer contents;
					// their results are simply discarded.
					let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
					hasher.finalize_into(&mut result);
					for (out, res) in out_chunk.iter_mut().zip(result.into_iter()) {
						out.write(unsafe { res.assume_init() });
					}
				}
			},
		);
	}
}

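/// Any ordinary [`Digest`] is also a [`ParallelDigest`] that hashes each item
/// independently. A usage sketch, assuming some hasher crate such as `sha2` is
/// available (it is not a dependency of this module):
///
/// ```ignore
/// let hasher = <sha2::Sha256 as ParallelDigest>::new();
/// hasher.digest(items.par_iter(), &mut out);
/// ```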
impl<D: Digest + BlockSizeUser + Send + Sync + Clone> ParallelDigest for D {
	type Digest = D;

	fn new() -> Self {
		Digest::new()
	}

	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		Digest::new_with_prefix(data)
	}

	fn digest(
		&self,
		source: impl IndexedParallelIterator<Item: Serializable>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	) {
		source.zip(out.par_iter_mut()).for_each(|(data, out)| {
			let mut hasher = self.clone();
			// Scope the buffer so its mutable borrow of `hasher` ends before finalization.
			{
				let mut buffer = HashBuffer::new(&mut hasher);
				data.serialize(&mut buffer);
			}
			out.write(hasher.finalize());
		});
	}
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_maybe_rayon::iter::IntoParallelRefIterator;
	use digest::{
		FixedOutput, HashMarker, OutputSizeUser, Reset, Update,
		consts::{U1, U32},
	};
	use itertools::izip;
	use rand::{RngCore, SeedableRng, rngs::StdRng};

	use super::*;

	/// A toy digest whose "hash" is the XOR of all input bytes, zero-padded to 32 bytes.
	#[derive(Clone, Default)]
	struct MockDigest {
		state: u8,
	}

	impl HashMarker for MockDigest {}

	impl Update for MockDigest {
		fn update(&mut self, data: &[u8]) {
			for &byte in data {
				self.state ^= byte;
			}
		}
	}

	impl Reset for MockDigest {
		fn reset(&mut self) {
			self.state = 0;
		}
	}

	impl OutputSizeUser for MockDigest {
		type OutputSize = U32;
	}

	impl BlockSizeUser for MockDigest {
		type BlockSize = U1;
	}

	impl FixedOutput for MockDigest {
		fn finalize_into(self, out: &mut digest::Output<Self>) {
			out[0] = self.state;
			for byte in &mut out[1..] {
				*byte = 0;
			}
		}
	}

	#[derive(Clone, Default)]
	struct MockMultiDigest {
		digests: [MockDigest; 4],
	}

	impl MultiDigest<4> for MockMultiDigest {
		type Digest = MockDigest;

		fn new() -> Self {
			Self::default()
		}

		fn update(&mut self, data: [&[u8]; 4]) {
			for (digest, &chunk) in self.digests.iter_mut().zip(data.iter()) {
				digest::Digest::update(digest, chunk);
			}
		}

		fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (digest, out) in self.digests.into_iter().zip(out.iter_mut()) {
				let mut output = digest::Output::<Self::Digest>::default();
				digest::Digest::finalize_into(digest, &mut output);
				*out = MaybeUninit::new(output);
			}
		}

		fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (digest, out) in self.digests.iter_mut().zip(out.iter_mut()) {
				// Take each digest, leaving a fresh default in its place, so no
				// separate reset pass is needed afterwards.
				let digest = std::mem::take(digest);
				*out = MaybeUninit::new(digest.finalize());
			}
		}

		fn reset(&mut self) {
			for digest in &mut self.digests {
				*digest = MockDigest::default();
			}
		}

		fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			let mut hasher = Self::default();
			hasher.update(data);
			hasher.finalize_into(out);
		}
	}

	struct DataWrapper(Vec<u8>);

	impl Serializable for &DataWrapper {
		fn serialize(self, mut buffer: impl BufMut) {
			buffer.put_slice(&self.0);
		}
	}

	fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<DataWrapper> {
		let mut rng = StdRng::seed_from_u64(0);

		(0..n_hashes)
			.map(|_| {
				let mut chunk = vec![0; chunk_size];
				rng.fill_bytes(&mut chunk);

				DataWrapper(chunk)
			})
			.collect()
	}

	/// Checks that the parallel digest, the single-digest fallback, and plain serial
	/// hashing all produce identical results on the same data.
	fn check_parallel_digest_consistency<
		D: ParallelDigest<Digest: BlockSizeUser + Send + Sync + Clone>,
	>(
		data: Vec<DataWrapper>,
	) {
		let parallel_digest = D::new();
		let mut parallel_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		parallel_digest.digest(data.par_iter(), &mut parallel_results);

		let single_digest_as_parallel = <D::Digest as ParallelDigest>::new();
		let mut single_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		single_digest_as_parallel.digest(data.par_iter(), &mut single_results);

		let serial_results = data
			.iter()
			.map(|data| <D::Digest as digest::Digest>::digest(&data.0));

		for (parallel, single, serial) in izip!(parallel_results, single_results, serial_results) {
			assert_eq!(unsafe { parallel.assume_init() }, serial);
			assert_eq!(unsafe { single.assume_init() }, serial);
		}
	}

	#[test]
	fn test_empty_data() {
		let data = generate_mock_data(0, 16);
		check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
	}

	#[test]
	fn test_non_empty_data() {
		for n_hashes in [1, 2, 4, 8, 9] {
			let data = generate_mock_data(n_hashes, 16);
			check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
		}
	}
}