1use std::{array, marker::PhantomData, mem::MaybeUninit};
5
6use binius_utils::{
7 FixedSizeSerializeBytes, SerializeBytes,
8 rayon::{
9 iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
10 slice::ParallelSliceMut,
11 },
12};
13use bytes::BytesMut;
14use digest::{Digest, FixedOutputReset, Output, block_api::BlockSizeUser};
15
16use crate::HashBuffer;
17
18pub trait MultiDigest<const N: usize>: Clone {
26 type Digest: Digest;
28
29 fn new() -> Self;
31
32 fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
34 let mut hasher = Self::new();
35 hasher.update([data.as_ref(); N]);
36 hasher
37 }
38
39 fn update(&mut self, data: [&[u8]; N]);
42
43 #[must_use]
45 fn chain_update(self, data: [&[u8]; N]) -> Self {
46 let mut hasher = self;
47 hasher.update(data);
48 hasher
49 }
50
51 fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
53
54 fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
56
57 fn reset(&mut self);
59
60 fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
66}
67
68pub trait ParallelDigest: Send {
69 type Digest: Digest;
71
72 fn new() -> Self;
74
75 fn digest<I: IntoIterator<Item: SerializeBytes>>(
86 &self,
87 source: impl IndexedParallelIterator<Item = I>,
88 out: &mut [MaybeUninit<Output<Self::Digest>>],
89 );
90
91 fn digest_with_const_len<I: IntoIterator<Item: FixedSizeSerializeBytes>>(
103 &self,
104 n_items_per_input: usize,
105 source: impl IndexedParallelIterator<Item = I>,
106 out: &mut [MaybeUninit<Output<Self::Digest>>],
107 ) {
108 let _ = n_items_per_input;
109 self.digest(source, out);
110 }
111}
112
113#[derive(Clone)]
115pub struct ParallelMultidigestImpl<D: MultiDigest<N>, const N: usize>(D);
116
117impl<D: MultiDigest<N> + Default, const N: usize> Default for ParallelMultidigestImpl<D, N> {
118 fn default() -> Self {
119 Self(D::default())
120 }
121}
122
123impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
124 for ParallelMultidigestImpl<D, N>
125{
126 type Digest = D::Digest;
127
128 fn new() -> Self {
129 Self(D::new())
130 }
131
132 fn digest<I: IntoIterator<Item: SerializeBytes>>(
133 &self,
134 source: impl IndexedParallelIterator<Item = I>,
135 out: &mut [MaybeUninit<Output<Self::Digest>>],
136 ) {
137 let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
138 source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
139 buffers,
140 |buffers, (data, out_chunk)| {
141 let mut hasher = self.0.clone();
142 for (mut buf, chunk) in buffers.iter_mut().zip(data) {
143 buf.clear();
144 for item in chunk {
145 item.serialize(&mut buf)
146 .expect("pre-condition: items must serialize without error")
147 }
148 }
149 let data = array::from_fn(|i| buffers[i].as_ref());
150 hasher.update(data);
151
152 if out_chunk.len() == N {
153 hasher
154 .finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
155 } else {
156 let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
157 hasher.finalize_into(&mut result);
158 for (out, res) in out_chunk.iter_mut().zip(result) {
159 out.write(unsafe { res.assume_init() });
160 }
161 }
162 },
163 );
164 }
165}
166
/// Zero-sized adapter that exposes a single-input hash function `D` through the
/// [`ParallelDigest`] interface.
///
/// It stores no hasher state; `D` is carried only as a type parameter.
pub struct ParallelDigestAdapter<D>(PhantomData<D>);

impl<D> Default for ParallelDigestAdapter<D> {
    /// Builds the adapter; no bound on `D` is required because nothing is stored.
    fn default() -> Self {
        ParallelDigestAdapter(PhantomData)
    }
}
180
181impl<D> ParallelDigest for ParallelDigestAdapter<D>
182where
183 D: Digest + FixedOutputReset + BlockSizeUser + Send + Sync + Clone,
184{
185 type Digest = D;
186
187 fn new() -> Self {
188 Self(PhantomData)
189 }
190
191 fn digest<I: IntoIterator<Item: SerializeBytes>>(
192 &self,
193 source: impl IndexedParallelIterator<Item = I>,
194 out: &mut [MaybeUninit<Output<Self::Digest>>],
195 ) {
196 source
197 .zip(out.par_iter_mut())
198 .for_each_with(D::new(), |hasher, (items, out)| {
199 {
200 let mut buffer = HashBuffer::new(hasher);
201 for item in items {
202 item.serialize(&mut buffer)
203 .expect("pre-condition: items must serialize without error")
204 }
205 }
206 out.write(hasher.finalize_reset());
207 });
208 }
209}
210
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_utils::rayon::iter::IntoParallelRefIterator;
    use digest::{
        FixedOutput, HashMarker, OutputSizeUser, Reset, Update,
        consts::{U1, U32},
    };
    use itertools::izip;
    use rand::prelude::*;

    use super::*;

    /// Parallel digest under test: the four-lane mock wrapped in the batching implementation.
    type MockParallelDigest = ParallelMultidigestImpl<MockMultiDigest, 4>;

    /// Toy hash function: byte 0 of the output is the XOR of all input bytes, the remaining
    /// bytes are zero.
    #[derive(Clone, Default)]
    struct MockDigest {
        state: u8,
    }

    impl HashMarker for MockDigest {}

    impl Update for MockDigest {
        fn update(&mut self, data: &[u8]) {
            self.state = data.iter().fold(self.state, |acc, &byte| acc ^ byte);
        }
    }

    impl Reset for MockDigest {
        fn reset(&mut self) {
            self.state = 0;
        }
    }

    impl OutputSizeUser for MockDigest {
        type OutputSize = U32;
    }

    impl BlockSizeUser for MockDigest {
        type BlockSize = U1;
    }

    impl FixedOutput for MockDigest {
        fn finalize_into(self, out: &mut Output<Self>) {
            out[0] = self.state;
            out[1..].iter_mut().for_each(|byte| *byte = 0);
        }
    }

    /// Four independent `MockDigest` lanes behind the `MultiDigest` interface.
    #[derive(Clone, Default)]
    struct MockMultiDigest {
        digests: [MockDigest; 4],
    }

    impl MultiDigest<4> for MockMultiDigest {
        type Digest = MockDigest;

        fn new() -> Self {
            Self::default()
        }

        fn update(&mut self, data: [&[u8]; 4]) {
            for (lane, bytes) in self.digests.iter_mut().zip(data) {
                digest::Digest::update(lane, bytes);
            }
        }

        fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            for (lane, slot) in self.digests.into_iter().zip(out.iter_mut()) {
                let mut output = digest::Output::<Self::Digest>::default();
                digest::Digest::finalize_into(lane, &mut output);
                slot.write(output);
            }
        }

        fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            for (lane, slot) in self.digests.iter_mut().zip(out.iter_mut()) {
                // Move the lane's state out, leaving a fresh default lane behind.
                let finished = std::mem::take(lane);
                slot.write(finished.finalize());
            }
            self.reset();
        }

        fn reset(&mut self) {
            self.digests
                .iter_mut()
                .for_each(|lane| *lane = MockDigest::default());
        }

        fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            Self::default().chain_update(data).finalize_into(out);
        }
    }

    /// Produces `n_hashes` pseudo-random inputs of `chunk_size` bytes each from a fixed seed.
    fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<Vec<u8>> {
        let mut rng = StdRng::seed_from_u64(0);

        let mut inputs = Vec::with_capacity(n_hashes);
        for _ in 0..n_hashes {
            let mut input = vec![0u8; chunk_size];
            rng.fill_bytes(&mut input);
            inputs.push(input);
        }
        inputs
    }

    /// Asserts that the parallel digest `D` yields, for every input, the same output as hashing
    /// that input serially with `D::Digest`.
    fn check_parallel_digest_consistency<
        D: ParallelDigest<Digest: BlockSizeUser + Send + Sync + Clone>,
    >(
        data: Vec<Vec<u8>>,
    ) {
        let mut parallel_results: Vec<_> =
            repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
                .take(data.len())
                .collect();
        D::new().digest(data.par_iter(), &mut parallel_results);

        for (parallel, input) in izip!(parallel_results, &data) {
            let expected = <D::Digest as Digest>::digest(input);
            // SAFETY: relies on `digest` having written one output per input; the output vector
            // was sized to `data.len()`.
            let actual = unsafe { parallel.assume_init() };
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn test_empty_data() {
        check_parallel_digest_consistency::<MockParallelDigest>(generate_mock_data(0, 16));
    }

    #[test]
    fn test_non_empty_data() {
        // Covers a partial batch, exact multiples of the lane count, and a trailing partial
        // batch.
        for n_hashes in [1, 2, 4, 8, 9] {
            check_parallel_digest_consistency::<MockParallelDigest>(generate_mock_data(
                n_hashes, 16,
            ));
        }
    }

    #[test]
    fn test_adapter_matches_serial_sha256() {
        use sha2::Sha256;

        for n_hashes in [0, 1, 2, 4, 8, 9, 100] {
            let leaves = generate_mock_data(n_hashes, 16);

            let mut digests: Vec<_> = repeat_with(MaybeUninit::<Output<Sha256>>::uninit)
                .take(leaves.len())
                .collect();
            ParallelDigestAdapter::<Sha256>::new().digest(leaves.par_iter(), &mut digests);

            for (digest, leaf) in digests.into_iter().zip(&leaves) {
                let expected = <Sha256 as Digest>::digest(leaf);
                // SAFETY: relies on the adapter having written one output per input; the output
                // vector was sized to `leaves.len()`.
                let actual = unsafe { digest.assume_init() };
                assert_eq!(actual, expected);
            }
        }
    }
}