1use std::{array, mem::MaybeUninit};
4
5use binius_field::TowerField;
6use binius_maybe_rayon::{
7 iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
8 slice::ParallelSliceMut,
9};
10use binius_utils::{SerializationMode, SerializeBytes};
11use bytes::{BufMut, BytesMut};
12use digest::{Digest, Output, core_api::BlockSizeUser};
13
14use crate::HashBuffer;
15
/// A hasher that computes `N` digests of `N` independent byte streams in lockstep.
///
/// Each of the `N` lanes receives its own input slice on every [`Self::update`]
/// call and produces its own output on finalization. Lane outputs have the same
/// type as the associated single-stream [`Self::Digest`].
pub trait MultiDigest<const N: usize>: Clone {
    /// The corresponding single-stream hasher; lane outputs are `Output<Self::Digest>`.
    type Digest: Digest;

    /// Creates a new hasher with empty state in all `N` lanes.
    fn new() -> Self;

    /// Creates a new hasher and feeds the same `data` prefix to every lane.
    fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
        let mut hasher = Self::new();
        hasher.update([data.as_ref(); N]);
        hasher
    }

    /// Absorbs one byte slice per lane: `data[i]` goes to lane `i`.
    fn update(&mut self, data: [&[u8]; N]);

    /// Builder-style variant of [`Self::update`]: absorbs `data` and returns the
    /// updated hasher.
    #[must_use]
    fn chain_update(self, data: [&[u8]; N]) -> Self {
        let mut hasher = self;
        hasher.update(data);
        hasher
    }

    /// Consumes the hasher and writes lane `i`'s digest into `out[i]`.
    ///
    /// Contract: implementations must initialize every element of `out`
    /// (callers in this module `assume_init` all `N` slots afterwards).
    fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

    /// Writes lane `i`'s digest into `out[i]`, then resets the hasher to its
    /// initial empty state so it can be reused. Same initialization contract
    /// as [`Self::finalize_into`].
    fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

    /// Resets every lane to the empty state.
    fn reset(&mut self);

    /// One-shot convenience: hashes `data[i]` into `out[i]` using a fresh hasher.
    fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
}
65
/// Types that can write themselves into a byte buffer prior to hashing.
pub trait Serializable {
    /// Serializes `self` into `buffer`, consuming `self`.
    fn serialize(self, buffer: impl BufMut);
}
69
70impl<F: TowerField, I: IntoIterator<Item = F>> Serializable for I {
71 fn serialize(self, mut buffer: impl BufMut) {
72 let mode = SerializationMode::CanonicalTower;
73 for elem in self {
74 SerializeBytes::serialize(&elem, &mut buffer, mode)
75 .expect("buffer must have enough capacity");
76 }
77 }
78}
79
/// A hasher that digests many independent messages, potentially in parallel.
pub trait ParallelDigest: Send {
    /// The underlying single-message digest type.
    type Digest: digest::Digest + Send;

    /// Creates a new hasher instance.
    fn new() -> Self;

    /// Creates a new hasher that starts every message with the given `data` prefix.
    fn new_with_prefix(data: impl AsRef<[u8]>) -> Self;

    /// Hashes each item of `source` into the corresponding slot of `out`.
    ///
    /// Items and output slots are paired element-wise. Implementations here pair
    /// them with `zip`, so if the lengths differ the surplus tail of either side
    /// is ignored and any unpaired `out` slots are left uninitialized — callers
    /// should supply equal lengths.
    fn digest(
        &self,
        source: impl IndexedParallelIterator<Item: Serializable>,
        out: &mut [MaybeUninit<Output<Self::Digest>>],
    );
}
98
/// Adapter implementing [`ParallelDigest`] on top of a [`MultiDigest`],
/// hashing messages in batches of `N` lanes at a time.
#[derive(Clone)]
pub struct ParallelMultidigestImpl<D: MultiDigest<N>, const N: usize>(D);
102
impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
    for ParallelMultidigestImpl<D, N>
{
    type Digest = D::Digest;

    fn new() -> Self {
        Self(D::new())
    }

    fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
        Self(D::new_with_prefix(data.as_ref()))
    }

    /// Hashes `source` in batches of `N`: each batch is serialized into `N`
    /// scratch buffers and digested by one clone of the inner [`MultiDigest`].
    fn digest(
        &self,
        source: impl IndexedParallelIterator<Item: Serializable>,
        out: &mut [MaybeUninit<Output<Self::Digest>>],
    ) {
        // One scratch buffer per lane. `for_each_with` clones this set per rayon
        // task, so buffers are reused (cleared, not reallocated) across the
        // batches a task processes.
        let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
        source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
            buffers,
            |buffers, (data, out_chunk)| {
                let mut hasher = self.0.clone();
                // Serialize each item of this batch into its lane's buffer.
                // On a short final batch (data.len() < N) the unpaired buffers
                // keep stale bytes from a previous batch, but the digests of
                // those lanes are discarded below, so this is harmless.
                for (buf, chunk) in buffers.iter_mut().zip(data.into_iter()) {
                    buf.clear();
                    chunk.serialize(buf);
                }
                let data = array::from_fn(|i| buffers[i].as_ref());
                hasher.update(data);

                if out_chunk.len() == N {
                    // Full batch: finalize all N lanes straight into the output
                    // slice (the slice-to-array conversion cannot fail here).
                    hasher
                        .finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
                } else {
                    // Short final batch: finalize into a temporary array and copy
                    // only the digests that have a matching output slot.
                    let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
                    hasher.finalize_into(&mut result);
                    for (out, res) in out_chunk.iter_mut().zip(result.into_iter()) {
                        // SAFETY: the `MultiDigest::finalize_into` contract is that
                        // all N elements of `result` are initialized on return.
                        out.write(unsafe { res.assume_init() });
                    }
                }
            },
        );
    }
}
147
148impl<D: Digest + BlockSizeUser + Send + Sync + Clone> ParallelDigest for D {
149 type Digest = D;
150
151 fn new() -> Self {
152 Digest::new()
153 }
154
155 fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
156 Digest::new_with_prefix(data)
157 }
158
159 fn digest(
160 &self,
161 source: impl IndexedParallelIterator<Item: Serializable>,
162 out: &mut [MaybeUninit<Output<Self::Digest>>],
163 ) {
164 source.zip(out.par_iter_mut()).for_each(|(data, out)| {
165 let mut hasher = self.clone();
166 {
167 let mut buffer = HashBuffer::new(&mut hasher);
168 data.serialize(&mut buffer);
169 }
170 out.write(hasher.finalize());
171 });
172 }
173}
174
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_maybe_rayon::iter::IntoParallelRefIterator;
    use digest::{
        FixedOutput, HashMarker, OutputSizeUser, Reset, Update,
        consts::{U1, U32},
    };
    use itertools::izip;
    use rand::{RngCore, SeedableRng, rngs::StdRng};

    use super::*;

    /// Toy digest: folds all input bytes into a single XOR accumulator.
    /// Order-insensitive and collision-prone, but deterministic — enough to
    /// check that the parallel plumbing routes bytes to the right lane.
    #[derive(Clone, Default)]
    struct MockDigest {
        state: u8,
    }

    impl HashMarker for MockDigest {}

    impl Update for MockDigest {
        fn update(&mut self, data: &[u8]) {
            for &byte in data {
                self.state ^= byte;
            }
        }
    }

    impl Reset for MockDigest {
        fn reset(&mut self) {
            self.state = 0;
        }
    }

    impl OutputSizeUser for MockDigest {
        // 32-byte output: the XOR state in byte 0, zero padding after.
        type OutputSize = U32;
    }

    impl BlockSizeUser for MockDigest {
        type BlockSize = U1;
    }

    impl FixedOutput for MockDigest {
        fn finalize_into(self, out: &mut digest::Output<Self>) {
            out[0] = self.state;
            for byte in &mut out[1..] {
                *byte = 0;
            }
        }
    }

    /// 4-lane [`MultiDigest`] built from four independent [`MockDigest`]s.
    #[derive(Clone, Default)]
    struct MockMultiDigest {
        digests: [MockDigest; 4],
    }

    impl MultiDigest<4> for MockMultiDigest {
        type Digest = MockDigest;

        fn new() -> Self {
            Self::default()
        }

        fn update(&mut self, data: [&[u8]; 4]) {
            for (digest, &chunk) in self.digests.iter_mut().zip(data.iter()) {
                digest::Digest::update(digest, chunk);
            }
        }

        fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            for (digest, out) in self.digests.into_iter().zip(out.iter_mut()) {
                let mut output = digest::Output::<Self::Digest>::default();
                digest::Digest::finalize_into(digest, &mut output);
                *out = MaybeUninit::new(output);
            }
        }

        fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            for (digest, out) in self.digests.iter_mut().zip(out.iter_mut()) {
                // Swap each lane out for a fresh default, then finalize the copy.
                let mut digest_copy = MockDigest::default();
                std::mem::swap(digest, &mut digest_copy);
                *out = MaybeUninit::new(digest_copy.finalize());
            }
            // NOTE(review): the swap loop above has already left every lane in
            // its default state, so this reset is redundant (harmless).
            self.reset();
        }

        fn reset(&mut self) {
            for digest in &mut self.digests {
                *digest = MockDigest::default();
            }
        }

        fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            let mut hasher = Self::default();
            hasher.update(data);
            hasher.finalize_into(out);
        }
    }

    /// Owned byte payload with a [`Serializable`] impl that just copies the bytes,
    /// so the parallel digest of a wrapper equals the plain digest of its bytes.
    struct DataWrapper(Vec<u8>);

    impl Serializable for &DataWrapper {
        fn serialize(self, mut buffer: impl BufMut) {
            buffer.put_slice(&self.0);
        }
    }

    /// Generates `n_hashes` random messages of `chunk_size` bytes from a fixed seed.
    fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<DataWrapper> {
        let mut rng = StdRng::seed_from_u64(0);

        (0..n_hashes)
            .map(|_| {
                let mut chunk = vec![0; chunk_size];
                rng.fill_bytes(&mut chunk);

                DataWrapper(chunk)
            })
            .collect()
    }

    /// Asserts that three paths agree on every message: the multi-lane parallel
    /// digest `D`, the blanket single-digest-as-parallel impl, and a plain
    /// serial `Digest::digest` of the raw bytes.
    fn check_parallel_digest_consistency<
        D: ParallelDigest<Digest: BlockSizeUser + Send + Sync + Clone>,
    >(
        data: Vec<DataWrapper>,
    ) {
        let parallel_digest = D::new();
        let mut parallel_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
            .take(data.len())
            .collect::<Vec<_>>();
        parallel_digest.digest(data.par_iter(), &mut parallel_results);

        let single_digest_as_parallel = <D::Digest as ParallelDigest>::new();
        let mut single_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
            .take(data.len())
            .collect::<Vec<_>>();
        single_digest_as_parallel.digest(data.par_iter(), &mut single_results);

        let serial_results = data
            .iter()
            .map(move |data| <D::Digest as digest::Digest>::digest(&data.0));

        for (parallel, single, serial) in izip!(parallel_results, single_results, serial_results) {
            // SAFETY: `digest` initialized one output slot per input message,
            // and both result vectors have exactly `data.len()` slots.
            assert_eq!(unsafe { parallel.assume_init() }, serial);
            assert_eq!(unsafe { single.assume_init() }, serial);
        }
    }

    #[test]
    fn test_empty_data() {
        let data = generate_mock_data(0, 16);
        check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
    }

    #[test]
    fn test_non_empty_data() {
        // 9 exercises the short-final-batch path (9 % 4 != 0).
        for n_hashes in [1, 2, 4, 8, 9] {
            let data = generate_mock_data(n_hashes, 16);
            check_parallel_digest_consistency::<ParallelMultidigestImpl<MockMultiDigest, 4>>(data);
        }
    }
}