binius_hash/vision_6/
parallel_digest.rs1use std::{array, mem::MaybeUninit};
5
6use binius_field::{BinaryField128bGhash as Ghash, Field};
7use binius_utils::{DeserializeBytes, SerializeBytes};
8use digest::Output;
9
10use super::{
11 constants::M,
12 digest::{PADDING_BLOCK, RATE_AS_U8, RATE_AS_U128, VisionHasherDigest, fill_padding},
13 parallel_permutation::batch_permutation,
14};
15use crate::parallel_digest::MultiDigest;
16
17#[derive(Clone)]
26pub struct VisionHasherMultiDigest<const N: usize, const MN: usize> {
27 states: [Ghash; MN],
28 buffers: [[u8; RATE_AS_U8]; N],
29 filled_bytes: usize,
30}
31
32impl<const N: usize, const MN: usize> Default for VisionHasherMultiDigest<N, MN> {
33 fn default() -> Self {
34 assert!(N.is_power_of_two() && N >= 2, "N must be a power of 2 and >= 2");
35 assert_eq!(MN, M * N);
36 Self {
37 states: array::from_fn(|_| Ghash::ZERO),
38 buffers: array::from_fn(|_| [0; RATE_AS_U8]),
39 filled_bytes: 0,
40 }
41 }
42}
43
44impl<const N: usize, const MN: usize> VisionHasherMultiDigest<N, MN> {
45 #[inline]
46 fn advance_data(data: &mut [&[u8]; N], bytes: usize) {
47 for i in 0..N {
48 data[i] = &data[i][bytes..];
49 }
50 }
51
52 fn permute(states: &mut [Ghash; MN], data: [&[u8]; N]) {
53 for (i, data) in data.iter().enumerate() {
54 debug_assert_eq!(data.len(), RATE_AS_U8);
55
56 let state_start = i * M;
58 for j in 0..RATE_AS_U128 {
59 let element_bytes = &data[j * (128 / 8)..];
60 states[state_start + j] =
61 Ghash::deserialize(element_bytes).expect("data len checked");
62 }
63 }
64
65 batch_permutation::<N, MN>(states);
66 }
67 fn finalize(&mut self, out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>; N]) {
68 if self.filled_bytes != 0 {
69 for i in 0..N {
70 fill_padding(&mut self.buffers[i][self.filled_bytes..]);
71 }
72 Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
73 } else {
74 Self::permute(&mut self.states, array::from_fn(|_| &PADDING_BLOCK[..]));
75 }
76
77 for i in 0..N {
79 let output_slice = out[i].as_mut_ptr() as *mut u8;
80 let output_bytes = unsafe { std::slice::from_raw_parts_mut(output_slice, 32) };
81 let (state0, state1) = output_bytes.split_at_mut(16);
82 self.states[i * M]
83 .serialize(state0)
84 .expect("fits in 16 bytes");
85 self.states[i * M + 1]
86 .serialize(state1)
87 .expect("fits in 16 bytes");
88 }
89 }
90}
91
92impl<const N: usize, const MN: usize> MultiDigest<N> for VisionHasherMultiDigest<N, MN> {
93 type Digest = VisionHasherDigest;
94
95 fn new() -> Self {
96 Self::default()
97 }
98
99 fn update(&mut self, mut data: [&[u8]; N]) {
100 data[1..].iter().for_each(|row| {
101 assert_eq!(row.len(), data[0].len());
102 });
103
104 if self.filled_bytes != 0 {
105 let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
106 data.iter().enumerate().for_each(|(row_i, row)| {
107 self.buffers[row_i][self.filled_bytes..self.filled_bytes + to_copy]
108 .copy_from_slice(&row[..to_copy]);
109 });
110 Self::advance_data(&mut data, to_copy);
111 self.filled_bytes += to_copy;
112
113 if self.filled_bytes == RATE_AS_U8 {
114 Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
115 self.filled_bytes = 0;
116 }
117 }
118
119 while data[0].len() >= RATE_AS_U8 {
120 let chunks = array::from_fn(|i| &data[i][..RATE_AS_U8]);
121 Self::permute(&mut self.states, chunks);
122 Self::advance_data(&mut data, RATE_AS_U8);
123 }
124
125 if !data[0].is_empty() {
126 data.iter().enumerate().for_each(|(row_i, row)| {
127 self.buffers[row_i][..row.len()].copy_from_slice(row);
128 });
129 self.filled_bytes = data[0].len();
130 }
131 }
132
133 fn finalize_into(mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
134 self.finalize(out);
135 }
136
137 fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
138 self.finalize(out);
139 self.reset();
140 }
141
142 fn reset(&mut self) {
143 self.states = array::from_fn(|_| Ghash::ZERO);
144 self.buffers = array::from_fn(|_| [0; RATE_AS_U8]);
145 self.filled_bytes = 0;
146 }
147
148 fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
149 let mut digest = Self::default();
150 digest.update(data);
151 digest.finalize_into(out);
152 }
153}
154
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use digest::Digest;
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    /// Produces `N` vectors of `length` pseudo-random bytes each, all drawn in
    /// order from a single `StdRng` seeded with `seed`.
    fn generate_random_data<const N: usize>(length: usize, seed: u64) -> Vec<Vec<u8>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..N)
            .map(|_| {
                (0..length)
                    .map(|_| rng.random::<u8>())
                    .collect::<Vec<u8>>()
            })
            .collect()
    }

    /// Hashes every row of `data` with the multi-lane hasher and with the
    /// single-stream `VisionHasherDigest`, and asserts the digests agree.
    fn test_parallel_vs_sequential<const N: usize, const MN: usize>(
        data: [&[u8]; N],
        description: &str,
    ) {
        let mut parallel_outputs = [MaybeUninit::uninit(); N];
        VisionHasherMultiDigest::<N, MN>::digest(data, &mut parallel_outputs);

        for (i, input) in data.iter().enumerate() {
            // SAFETY: the `digest` call above finalizes the hasher, which
            // writes a digest into each of the `N` output slots.
            let parallel: Output<VisionHasherDigest> =
                unsafe { parallel_outputs[i].assume_init() };

            let mut hasher = VisionHasherDigest::new();
            hasher.update(*input);
            let sequential = hasher.finalize();

            assert_eq!(
                parallel, sequential,
                "Mismatch at index {i} for {description}"
            );
        }
    }

    /// Empty rows exercise the padding-only finalization path.
    #[test]
    fn test_empty_inputs() {
        const N: usize = 4;
        let data: [&[u8]; N] = [&[], &[], &[], &[]];
        test_parallel_vs_sequential::<N, { N * M }>(data, "empty inputs");
    }

    /// Short rows of equal length.
    #[test]
    fn test_small_inputs() {
        const N: usize = 2;
        let data: [&[u8]; N] = [b"Hello... World!", b"Rust is awesome"];
        test_parallel_vs_sequential::<N, { N * M }>(data, "small inputs");
    }

    /// Rows spanning two full blocks plus a partial one.
    #[test]
    fn test_multi_block() {
        const N: usize = 4;
        let data_vecs = generate_random_data::<N>(RATE_AS_U8 * 2 + 10, 42);
        let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());

        test_parallel_vs_sequential::<N, { N * M }>(data, "multi-block inputs");
    }

    /// Lengths around the block boundaries.
    #[test]
    fn test_various_sizes() {
        const N: usize = 2;
        let sizes = [
            1,
            RATE_AS_U8 - 7,
            RATE_AS_U8,
            RATE_AS_U8 + 5,
            RATE_AS_U8 * 2 - 3,
        ];

        for size in sizes {
            let data_vecs = generate_random_data::<N>(size, 123);
            let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());
            test_parallel_vs_sequential::<N, { N * M }>(data, &format!("size {size}"));
        }
    }
}