1use std::{array, mem::MaybeUninit};
4
5use binius_field::TowerField;
6use binius_maybe_rayon::{
7 iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
8 slice::ParallelSliceMut,
9};
10use binius_utils::{SerializationMode, SerializeBytes};
11use bytes::{BufMut, BytesMut};
12use digest::{Digest, Output};
13
/// An object that efficiently computes `N` instances of a cryptographic hash function
/// in parallel (e.g. using SIMD lanes).
pub trait MultiDigest<const N: usize>: Clone {
    /// The corresponding single-instance digest that each of the `N` lanes computes.
    type Digest: Digest;

    /// Creates a new hasher instance with empty state for all `N` lanes.
    fn new() -> Self;

    /// Creates a new hasher that has already absorbed `data`.
    ///
    /// The same prefix is absorbed into all `N` lanes.
    fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
        let mut hasher = Self::new();
        hasher.update([data.as_ref(); N]);
        hasher
    }

    /// Absorbs input data into each lane: `data[i]` is processed by lane `i`.
    fn update(&mut self, data: [&[u8]; N]);

    /// Absorbs input data in a chained (builder) manner, returning the updated hasher.
    #[must_use]
    fn chain_update(self, data: [&[u8]; N]) -> Self {
        let mut hasher = self;
        hasher.update(data);
        hasher
    }

    /// Consumes the hasher and writes the `N` lane digests into `out`.
    ///
    /// Implementations must initialize every element of `out`.
    fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

    /// Writes the `N` lane digests into `out` and resets the hasher to its empty state.
    ///
    /// Implementations must initialize every element of `out`.
    fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]);

    /// Resets all `N` lanes to the empty state.
    fn reset(&mut self);

    /// Convenience: hashes `data` with a fresh hasher, writing the digests into `out`.
    fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]);
}
62
/// A value that can write its byte representation into a buffer for hashing.
pub trait Serializable {
    /// Writes `self` into `buffer`. Consumes `self`.
    fn serialize(self, buffer: impl BufMut);
}
66
67impl<F: TowerField, I: IntoIterator<Item = F>> Serializable for I {
68 fn serialize(self, mut buffer: impl BufMut) {
69 let mode = SerializationMode::CanonicalTower;
70 for elem in self {
71 SerializeBytes::serialize(&elem, &mut buffer, mode)
72 .expect("buffer must have enough capacity");
73 }
74 }
75}
76
/// An object that computes many hash digests, potentially in parallel.
pub trait ParallelDigest: Send {
    /// The corresponding single-instance digest.
    type Digest: digest::Digest + Send;

    /// Creates a new hasher instance with empty state.
    fn new() -> Self;

    /// Creates a new hasher that has already absorbed the byte prefix `data`.
    fn new_with_prefix(data: impl AsRef<[u8]>) -> Self;

    /// Hashes each item of `source` independently, writing the i-th digest into `out[i]`.
    ///
    /// Implementations must initialize every element of `out` that corresponds to an
    /// item of `source`.
    fn digest(
        &self,
        source: impl IndexedParallelIterator<Item: Serializable>,
        out: &mut [MaybeUninit<Output<Self::Digest>>],
    );
}
95
/// Adapter implementing [`ParallelDigest`] on top of a [`MultiDigest`] that hashes
/// `N` messages at a time.
// NOTE(review): the name is missing a 't' ("Mulidigest"); kept as-is because it is
// part of the public API and renaming would break callers.
#[derive(Clone)]
pub struct ParallelMulidigestImpl<D: MultiDigest<N>, const N: usize>(D);
99
impl<D: MultiDigest<N, Digest: Send> + Send + Sync, const N: usize> ParallelDigest
    for ParallelMulidigestImpl<D, N>
{
    type Digest = D::Digest;

    fn new() -> Self {
        Self(D::new())
    }

    fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
        Self(D::new_with_prefix(data.as_ref()))
    }

    /// Hashes items in chunks of `N`, one multi-digest invocation per chunk.
    fn digest(
        &self,
        source: impl IndexedParallelIterator<Item: Serializable>,
        out: &mut [MaybeUninit<Output<Self::Digest>>],
    ) {
        // One scratch buffer per lane; `for_each_with` reuses (or clones, per rayon
        // task) this array across chunks, avoiding a fresh allocation per item.
        let buffers = array::from_fn::<_, N, _>(|_| BytesMut::new());
        source.chunks(N).zip(out.par_chunks_mut(N)).for_each_with(
            buffers,
            |buffers, (data, out_chunk)| {
                // Cloning preserves any prefix state absorbed via `new_with_prefix`.
                let mut hasher = self.0.clone();
                for (buf, chunk) in buffers.iter_mut().zip(data.into_iter()) {
                    buf.clear();
                    chunk.serialize(buf);
                }
                // For a partial trailing chunk, lanes beyond `data.len()` may still
                // hold stale bytes from the previous chunk; their digests are simply
                // discarded below, so correctness is unaffected.
                let data = array::from_fn(|i| buffers[i].as_ref());
                hasher.update(data);

                if out_chunk.len() == N {
                    // Full chunk: finalize straight into the output slice.
                    hasher
                        .finalize_into_reset(out_chunk.try_into().expect("chunk size is correct"));
                } else {
                    // Partial chunk: finalize into a temporary array and copy only
                    // the digests that correspond to real input items.
                    let mut result = array::from_fn::<_, N, _>(|_| MaybeUninit::uninit());
                    hasher.finalize_into(&mut result);
                    for (out, res) in out_chunk.iter_mut().zip(result.into_iter()) {
                        // SAFETY: `finalize_into` initializes all `N` elements of `result`.
                        out.write(unsafe { res.assume_init() });
                    }
                }
            },
        );
    }
}
144
145impl<D: Digest + Send + Sync + Clone> ParallelDigest for D {
146 type Digest = D;
147
148 fn new() -> Self {
149 Digest::new()
150 }
151
152 fn new_with_prefix(data: impl AsRef<[u8]>) -> Self {
153 Digest::new_with_prefix(data)
154 }
155
156 fn digest(
157 &self,
158 source: impl IndexedParallelIterator<Item: Serializable>,
159 out: &mut [MaybeUninit<Output<Self::Digest>>],
160 ) {
161 source
162 .zip(out.par_iter_mut())
163 .for_each_with(BytesMut::new(), |mut buffer, (data, out)| {
164 buffer.clear();
165 data.serialize(&mut buffer);
166
167 let mut hasher = self.clone();
168 hasher.update(buffer.as_ref());
169 out.write(hasher.finalize());
170 });
171 }
172}
173
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_maybe_rayon::iter::IntoParallelRefIterator;
    use digest::{consts::U32, Digest, FixedOutput, HashMarker, OutputSizeUser, Update};
    use itertools::izip;
    use rand::{rngs::StdRng, RngCore, SeedableRng};

    use super::*;

    /// Toy hash: the state is the XOR of every input byte; the 32-byte output is the
    /// state followed by zero padding.
    #[derive(Clone, Default)]
    struct MockDigest {
        state: u8,
    }

    impl HashMarker for MockDigest {}

    impl Update for MockDigest {
        fn update(&mut self, data: &[u8]) {
            for &byte in data {
                self.state ^= byte;
            }
        }
    }

    impl OutputSizeUser for MockDigest {
        type OutputSize = U32;
    }

    impl FixedOutput for MockDigest {
        fn finalize_into(self, out: &mut digest::Output<Self>) {
            // First byte carries the XOR accumulator; the remaining 31 are zeros.
            out[0] = self.state;
            for byte in &mut out[1..] {
                *byte = 0;
            }
        }
    }

    /// A 4-lane [`MultiDigest`] built from four independent [`MockDigest`]s.
    #[derive(Clone, Default)]
    struct MockMultiDigest {
        digests: [MockDigest; 4],
    }

    impl MultiDigest<4> for MockMultiDigest {
        type Digest = MockDigest;

        fn new() -> Self {
            Self::default()
        }

        fn update(&mut self, data: [&[u8]; 4]) {
            // Each lane absorbs its own slice.
            for (digest, &chunk) in self.digests.iter_mut().zip(data.iter()) {
                digest::Digest::update(digest, chunk);
            }
        }

        fn finalize_into(self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            for (digest, out) in self.digests.into_iter().zip(out.iter_mut()) {
                let mut output = digest::Output::<Self::Digest>::default();
                digest::Digest::finalize_into(digest, &mut output);
                *out = MaybeUninit::new(output);
            }
        }

        fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            for (digest, out) in self.digests.iter_mut().zip(out.iter_mut()) {
                // Swap each lane out with a fresh one so the old state can be
                // consumed by `finalize` (which takes `self` by value).
                let mut digest_copy = MockDigest::default();
                std::mem::swap(digest, &mut digest_copy);
                *out = MaybeUninit::new(digest_copy.finalize());
            }
            self.reset();
        }

        fn reset(&mut self) {
            for digest in &mut self.digests {
                *digest = MockDigest::default();
            }
        }

        fn digest(data: [&[u8]; 4], out: &mut [MaybeUninit<Output<Self::Digest>>; 4]) {
            let mut hasher = Self::default();
            hasher.update(data);
            hasher.finalize_into(out);
        }
    }

    /// Wraps raw bytes so that `&DataWrapper` can implement [`Serializable`].
    struct DataWrapper(Vec<u8>);

    impl Serializable for &DataWrapper {
        fn serialize(self, mut buffer: impl BufMut) {
            buffer.put_slice(&self.0);
        }
    }

    /// Generates `n_hashes` random messages of `chunk_size` bytes each
    /// (fixed seed, so the data is deterministic across runs).
    fn generate_mock_data(n_hashes: usize, chunk_size: usize) -> Vec<DataWrapper> {
        let mut rng = StdRng::seed_from_u64(0);

        (0..n_hashes)
            .map(|_| {
                let mut chunk = vec![0; chunk_size];
                rng.fill_bytes(&mut chunk);

                DataWrapper(chunk)
            })
            .collect()
    }

    /// Asserts that `D` (parallel), its single-digest blanket form, and plain serial
    /// hashing all produce identical digests for `data`.
    fn check_parallel_digest_consistency<D: ParallelDigest<Digest: Send + Sync + Clone>>(
        data: Vec<DataWrapper>,
    ) {
        let parallel_digest = D::new();
        let mut parallel_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
            .take(data.len())
            .collect::<Vec<_>>();
        parallel_digest.digest(data.par_iter(), &mut parallel_results);

        let single_digest_as_parallel = <D::Digest as ParallelDigest>::new();
        let mut single_results = repeat_with(MaybeUninit::<Output<D::Digest>>::uninit)
            .take(data.len())
            .collect::<Vec<_>>();
        single_digest_as_parallel.digest(data.par_iter(), &mut single_results);

        let serial_results = data
            .iter()
            .map(move |data| <D::Digest as digest::Digest>::digest(&data.0));

        for (parallel, single, serial) in izip!(parallel_results, single_results, serial_results) {
            // SAFETY: `ParallelDigest::digest` initializes every output slot that
            // corresponds to an input item, and both result vectors have `data.len()` slots.
            assert_eq!(unsafe { parallel.assume_init() }, serial);
            assert_eq!(unsafe { single.assume_init() }, serial);
        }
    }

    #[test]
    fn test_empty_data() {
        let data = generate_mock_data(0, 16);
        check_parallel_digest_consistency::<ParallelMulidigestImpl<MockMultiDigest, 4>>(data);
    }

    #[test]
    fn test_non_empty_data() {
        // Counts chosen to cover exact multiples of the lane width (4) and a
        // partial trailing chunk (9 = 2 full chunks + 1 leftover).
        for n_hashes in [1, 2, 4, 8, 9] {
            let data = generate_mock_data(n_hashes, 16);
            check_parallel_digest_consistency::<ParallelMulidigestImpl<MockMultiDigest, 4>>(data);
        }
    }
}