// binius_hash/parallel_digest.rs
use std::{array, mem::MaybeUninit};
5
6use binius_utils::{
7 SerializeBytes,
8 rayon::{
9 iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
10 slice::ParallelSliceMut,
11 },
12};
13use bytes::BytesMut;
14use digest::{Digest, Output, core_api::BlockSizeUser};
15
16use crate::HashBuffer;
17
/// A hasher that computes `N` instances of a cryptographic hash function in
/// lock-step (e.g. one message per SIMD lane).
///
/// All `N` lanes advance together: every [`Self::update`] call feeds one slice
/// to each lane, and finalization produces all `N` digests at once.
pub trait MultiDigest<const N: usize>: Clone {
	/// The equivalent single-message digest; each lane's output is an
	/// `Output<Self::Digest>`.
	type Digest: Digest;

	/// Creates a new hasher with all `N` lanes in their initial state.
	fn new() -> Self;

	/// Creates a new hasher that has already processed `data`.
	///
	/// The same `data` prefix is absorbed into every one of the `N` lanes.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		let mut hasher = Self::new();
		hasher.update([data.as_ref(); N]);
		hasher
	}

	/// Absorbs one input slice into each of the `N` lanes (`data[i]` into lane `i`).
	fn update(&mut self, data: [&[u8]; N]);

	/// Absorbs input in a chained (builder) style, returning the updated hasher.
	#[must_use]
	fn chain_update(self, data: [&[u8]; N]) -> Self {
		let mut hasher = self;
		hasher.update(data);
		hasher
	}

	/// Writes the `N` lane digests into `out`, consuming the hasher.
	///
	/// Implementations must initialize all `N` entries of `out`.
	fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Writes the `N` lane digests into `out` and resets the hasher to its
	/// initial state, allowing it to be reused for new messages.
	fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

	/// Resets all `N` lanes back to their initial (empty) state.
	fn reset(&mut self);

	/// One-shot convenience: hashes `N` messages and writes the `N` digests to `out`.
	fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
}
67
/// A digest that can be computed over many independent messages in parallel.
pub trait ParallelDigest: Send {
	/// The underlying single-message digest type produced for each message.
	type Digest: Digest + Send;

	/// Creates a new hasher.
	fn new() -> Self;

	/// Creates a new hasher instance that has already processed the provided data
	/// as a shared prefix of every message.
	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self;

	/// Computes the digests of multiple messages in parallel.
	///
	/// Each item of `source` is an iterator of serializable pieces forming one
	/// message; the digest of message `i` is written to `out[i]`.
	///
	/// NOTE(review): callers appear to be expected to pass `out` with the same
	/// length as `source` — confirm; extra entries would be left uninitialized.
	fn digest<I: IntoIterator<Item: SerializeBytes>>(
		&self,
		source: impl IndexedParallelIterator<Item = I>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	);
}
94
/// Adapter exposing an `N`-lane [`MultiDigest`] as a [`ParallelDigest`]:
/// messages are hashed in groups of `N`, with groups distributed across
/// parallel tasks.
#[derive(Clone)]
pub struct ParallelMultidigestImpl<D: MultiDigest<N>, const N: usize>(D);
98
impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
	for ParallelMultidigestImpl<D, N>
{
	type Digest = D::Digest;

	fn new() -> Self {
		Self(D::new())
	}

	fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
		// The prefix is absorbed into all N lanes of the inner multi-digest, so
		// every message hashed later starts from the same prefixed state.
		Self(D::new_with_prefix(data.as_ref()))
	}

	fn digest<I: IntoIterator<Item: SerializeBytes>>(
		&self,
		source: impl IndexedParallelIterator<Item = I>,
		out: &mut [MaybeUninit<Output<Self::Digest>>],
	) {
		// One serialization buffer per lane. `for_each_with` clones this array
		// once per rayon task, so buffers are reused across chunks within a task
		// instead of being reallocated per chunk.
		let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
		source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
			buffers,
			|buffers, (data, out_chunk)| {
				// Fresh hasher per chunk, cloned from the (possibly prefixed) template.
				let mut hasher = self.0.clone();
				for (mut buf, chunk) in buffers.iter_mut().zip(data.into_iter()) {
					buf.clear();
					for item in chunk {
						item.serialize(&mut buf)
							.expect("pre-condition: items must serialize without error")
					}
				}
				// NOTE(review): when the final chunk holds fewer than N messages,
				// lanes past `data.len()` hash whatever stale bytes remain in their
				// buffers; those lane outputs are discarded below (the zip with
				// `out_chunk` stops early), so the written results stay correct.
				let data = array::from_fn(|i| buffers[i].as_ref());
				hasher.update(data);

				if out_chunk.len() == N {
					// Full chunk: all N lane outputs go straight into `out`.
					hasher
						.finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
				} else {
					// Partial (final) chunk: finalize into a temporary array and
					// copy only the lanes that correspond to real messages.
					let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
					hasher.finalize_into(&mut result);
					for (out, res) in out_chunk.iter_mut().zip(result.into_iter()) {
						// SAFETY: `finalize_into` initializes all N entries of `result`.
						out.write(unsafe { res.assume_init() });
					}
				}
			},
		);
	}
}
146
147impl<D: Digest + BlockSizeUser + Send + Sync + Clone> ParallelDigest for D {
148 type Digest = D;
149
150 fn new() -> Self {
151 Digest::new()
152 }
153
154 fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
155 Digest::new_with_prefix(data)
156 }
157
158 fn digest<I: IntoIterator<Item: SerializeBytes>>(
159 &self,
160 source: impl IndexedParallelIterator<Item = I>,
161 out: &mut [MaybeUninit<Output<Self::Digest>>],
162 ) {
163 source.zip(out.par_iter_mut()).for_each(|(items, out)| {
164 let mut hasher = self.clone();
165 {
166 let mut buffer = HashBuffer::new(&mut hasher);
167 for item in items {
168 item.serialize(&mut buffer)
169 .expect("pre-condition: items must serialize without error")
170 }
171 }
172 out.write(hasher.finalize());
173 });
174 }
175}
176
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_utils::rayon::iter::IntoParallelRefIterator;
	use digest::{
		FixedOutput, HashMarker, OutputSizeUser, Reset, Update,
		consts::{U1, U32},
	};
	use itertools::izip;
	use rand::{RngCore, SeedableRng, rngs::StdRng};

	use super::*;

	// Toy digest for testing: the state is the XOR of all absorbed bytes.
	#[derive(Clone, Default)]
	struct MockDigest {
		state: u8,
	}

	impl HashMarker for MockDigest {}

	impl Update for MockDigest {
		fn update(&mut self, data: &[u8]) {
			// XOR-fold every input byte into the single-byte state.
			for &byte in data {
				self.state ^= byte;
			}
		}
	}

	impl Reset for MockDigest {
		fn reset(&mut self) {
			self.state = 0;
		}
	}

	impl OutputSizeUser for MockDigest {
		// 32-byte output, mimicking a real 256-bit hash.
		type OutputSize = U32;
	}

	impl BlockSizeUser for MockDigest {
		// 1-byte blocks: any input length is a whole number of blocks.
		type BlockSize = U1;
	}

	impl FixedOutput for MockDigest {
		fn finalize_into(self, out: &mut Output<Self>) {
			// First byte carries the XOR state; remaining bytes are zero padding.
			out[0] = self.state;
			for byte in &mut out[1..] {
				*byte = 0;
			}
		}
	}

	// 4-lane multi-digest built from four independent MockDigest instances.
	#[derive(Clone, Default)]
	struct MockMultiDigest {
		digests: [MockDigest; 4],
	}

	impl MultiDigest<4> for MockMultiDigest {
		type Digest = MockDigest;

		fn new() -> Self {
			Self::default()
		}

		fn update(&mut self, data: [&[u8]; 4]) {
			// Feed each lane its own message slice.
			for (digest, &chunk) in self.digests.iter_mut().zip(data.iter()) {
				digest::Digest::update(digest, chunk);
			}
		}

		fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			// Finalize each lane into its slot; all 4 entries end up initialized,
			// as the MultiDigest contract requires.
			for (digest, out) in self.digests.into_iter().zip(out.iter_mut()) {
				let mut output = digest::Output::<Self::Digest>::default();
				digest::Digest::finalize_into(digest, &mut output);
				*out = MaybeUninit::new(output);
			}
		}

		fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			for (digest, out) in self.digests.iter_mut().zip(out.iter_mut()) {
				// Swap each lane out so it can be consumed by `finalize`, leaving
				// a fresh default state in its place.
				let mut digest_copy = MockDigest::default();
				std::mem::swap(digest, &mut digest_copy);
				*out = MaybeUninit::new(digest_copy.finalize());
			}
			// NOTE(review): the swap loop above already left every lane in its
			// default state, so this reset looks redundant (though harmless).
			self.reset();
		}

		fn reset(&mut self) {
			for digest in &mut self.digests {
				*digest = MockDigest::default();
			}
		}

		fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
			let mut hasher = Self::default();
			hasher.update(data);
			hasher.finalize_into(out);
		}
	}

	// Deterministic test data (fixed RNG seed): `n_hashes` messages of
	// `chunk_size` random bytes each.
	fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<Vec<u8>> {
		let mut rng = StdRng::seed_from_u64(0);

		(0..n_hashes)
			.map(|_| {
				let mut chunk = vec![0; chunk_size];
				rng.fill_bytes(&mut chunk);
				chunk
			})
			.collect()
	}

	// Checks that three ways of hashing `data` agree: the multi-lane parallel
	// digest `D`, the blanket single-digest ParallelDigest impl, and plain
	// serial `Digest::digest`.
	fn check_parallel_digest_consistency<
		D: ParallelDigest<Digest: BlockSizeUser + Send + Sync + Clone>,
	>(
		data: Vec<Vec<u8>>,
	) {
		let parallel_digest = D::new();
		let mut parallel_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		parallel_digest.digest(data.par_iter(), &mut parallel_results);

		let single_digest_as_parallel = <D::Digest as ParallelDigest>::new();
		let mut single_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
			.take(data.len())
			.collect::<Vec<_>>();
		single_digest_as_parallel.digest(data.par_iter(), &mut single_results);

		let serial_results = data.iter().map(<D::Digest as Digest>::digest);

		for (parallel, single, serial) in izip!(parallel_results, single_results, serial_results) {
			// SAFETY: `digest` initialized one output entry per input message,
			// and `izip!` only visits `data.len()` entries.
			assert_eq!(unsafe { parallel.assume_init() }, serial);
			assert_eq!(unsafe { single.assume_init() }, serial);
		}
	}

	#[test]
	fn test_empty_data() {
		let data = generate_mock_data(0, 16);
		check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
	}

	#[test]
	fn test_non_empty_data() {
		// Covers exact multiples of the lane count (4, 8) and partial final
		// chunks (1, 2, 9) to exercise both finalize paths.
		for n_hashes in [1, 2, 4, 8, 9] {
			let data = generate_mock_data(n_hashes, 16);
			check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
		}
	}
}