1use std::{array, mem::MaybeUninit};
5
6use binius_field::{BinaryField128bGhash as Ghash, Field};
7use binius_utils::SerializeBytes;
8use digest::Output;
9use itertools::izip;
10
11use crate::{
12 parallel_digest::MultiDigest,
13 vision_4::{
14 M,
15 digest::{PADDING_BLOCK, RATE_AS_U8, VisionHasherDigest, fill_padding},
16 },
17};
18
19#[derive(Clone)]
28pub struct VisionHasherMultiDigest<const N: usize> {
29 states: [[Ghash; M]; N],
30 buffers: [[u8; RATE_AS_U8]; N],
31 filled_bytes: usize,
32}
33
34impl<const N: usize> Default for VisionHasherMultiDigest<N> {
35 fn default() -> Self {
36 Self {
37 states: array::from_fn(|_| [Ghash::ZERO; M]),
38 buffers: array::from_fn(|_| [0; RATE_AS_U8]),
39 filled_bytes: 0,
40 }
41 }
42}
43
44impl<const N: usize> VisionHasherMultiDigest<N> {
45 #[inline]
46 fn advance_data(data: &mut [&[u8]; N], bytes: usize) {
47 for i in 0..N {
48 data[i] = &data[i][bytes..];
49 }
50 }
51
52 fn permute(states: &mut [[Ghash; M]; N], data: [&[u8]; N]) {
53 izip!(states, data).for_each(|(state, data)| {
55 VisionHasherDigest::permute(state, data);
56 });
57 }
58 fn finalize(&mut self, out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>; N]) {
59 if self.filled_bytes != 0 {
60 for i in 0..N {
61 fill_padding(&mut self.buffers[i][self.filled_bytes..]);
62 }
63 Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
64 } else {
65 Self::permute(&mut self.states, array::from_fn(|_| &PADDING_BLOCK[..]));
66 }
67
68 for i in 0..N {
70 let output_slice = out[i].as_mut_ptr() as *mut u8;
71 let output_bytes = unsafe { std::slice::from_raw_parts_mut(output_slice, 32) };
72 let (state0, state1) = output_bytes.split_at_mut(16);
73 self.states[i][0]
74 .serialize(state0)
75 .expect("fits in 16 bytes");
76 self.states[i][1]
77 .serialize(state1)
78 .expect("fits in 16 bytes");
79 }
80 }
81}
82
83impl<const N: usize> MultiDigest<N> for VisionHasherMultiDigest<N> {
84 type Digest = VisionHasherDigest;
85
86 fn new() -> Self {
87 Self::default()
88 }
89
90 fn update(&mut self, mut data: [&[u8]; N]) {
91 data[1..].iter().for_each(|row| {
92 assert_eq!(row.len(), data[0].len());
93 });
94
95 if self.filled_bytes != 0 {
96 let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
97 data.iter().enumerate().for_each(|(row_i, row)| {
98 self.buffers[row_i][self.filled_bytes..self.filled_bytes + to_copy]
99 .copy_from_slice(&row[..to_copy]);
100 });
101 Self::advance_data(&mut data, to_copy);
102 self.filled_bytes += to_copy;
103
104 if self.filled_bytes == RATE_AS_U8 {
105 Self::permute(&mut self.states, array::from_fn(|i| &self.buffers[i][..]));
106 self.filled_bytes = 0;
107 }
108 }
109
110 while data[0].len() >= RATE_AS_U8 {
111 let chunks = array::from_fn(|i| &data[i][..RATE_AS_U8]);
112 Self::permute(&mut self.states, chunks);
113 Self::advance_data(&mut data, RATE_AS_U8);
114 }
115
116 if !data[0].is_empty() {
117 data.iter().enumerate().for_each(|(row_i, row)| {
118 self.buffers[row_i][..row.len()].copy_from_slice(row);
119 });
120 self.filled_bytes = data[0].len();
121 }
122 }
123
124 fn finalize_into(mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
125 self.finalize(out);
126 }
127
128 fn finalize_into_reset(&mut self, out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
129 self.finalize(out);
130 self.reset();
131 }
132
133 fn reset(&mut self) {
134 self.states = array::from_fn(|_| [Ghash::ZERO; M]);
135 self.buffers = array::from_fn(|_| [0; RATE_AS_U8]);
136 self.filled_bytes = 0;
137 }
138
139 fn digest(data: [&[u8]; N], out: &mut [MaybeUninit<Output<Self::Digest>>; N]) {
140 let mut digest = Self::default();
141 digest.update(data);
142 digest.finalize_into(out);
143 }
144}
145
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use digest::Digest;
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    /// Generates `N` pseudo-random byte vectors of `length` bytes each, deterministically
    /// from `seed`.
    fn generate_random_data<const N: usize>(length: usize, seed: u64) -> Vec<Vec<u8>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..N)
            .map(|_| (0..length).map(|_| rng.random::<u8>()).collect())
            .collect()
    }

    /// Hashes each lane independently with the single-message [`VisionHasherDigest`].
    fn sequential_digests<const N: usize>(data: [&[u8]; N]) -> [Output<VisionHasherDigest>; N] {
        array::from_fn(|i| {
            let mut hasher = VisionHasherDigest::new();
            hasher.update(data[i]);
            hasher.finalize()
        })
    }

    /// Converts output slots filled by a `finalize_into*` call into plain digests.
    fn assume_all_init<const N: usize>(
        outputs: [MaybeUninit<Output<VisionHasherDigest>>; N],
    ) -> [Output<VisionHasherDigest>; N] {
        // SAFETY: callers only pass arrays that a `finalize_into*` call has just written, and
        // those methods initialize all `N` slots.
        outputs.map(|slot| unsafe { slot.assume_init() })
    }

    /// Asserts that every lane of `actual` equals the sequential reference digest of `data`.
    fn assert_matches_sequential<const N: usize>(
        actual: &[Output<VisionHasherDigest>; N],
        data: [&[u8]; N],
        description: &str,
    ) {
        let expected = sequential_digests(data);
        for (i, (got, want)) in actual.iter().zip(expected.iter()).enumerate() {
            assert_eq!(got, want, "Mismatch at index {i} for {description}");
        }
    }

    /// Checks one-shot [`MultiDigest::digest`] against per-lane sequential hashing.
    fn test_parallel_vs_sequential<const N: usize>(data: [&[u8]; N], description: &str) {
        let mut outputs = [MaybeUninit::uninit(); N];
        VisionHasherMultiDigest::<N>::digest(data, &mut outputs);
        assert_matches_sequential(&assume_all_init(outputs), data, description);
    }

    /// Feeds every lane to one hasher through several `update` calls and checks the result
    /// against per-lane sequential hashing.
    ///
    /// The input is split at the offsets in `cuts`. Offsets are clamped to the input length
    /// and to the previous cut.
    fn test_incremental_vs_sequential<const N: usize>(
        data: [&[u8]; N],
        cuts: &[usize],
        description: &str,
    ) {
        let len = data[0].len();
        let mut hasher = VisionHasherMultiDigest::<N>::new();
        let mut start = 0;
        for end in cuts.iter().copied().chain(std::iter::once(len)) {
            let end = end.clamp(start, len);
            hasher.update(array::from_fn(|i| &data[i][start..end]));
            start = end;
        }
        let mut outputs = [MaybeUninit::uninit(); N];
        hasher.finalize_into(&mut outputs);
        assert_matches_sequential(&assume_all_init(outputs), data, description);
    }

    #[test]
    fn test_empty_inputs() {
        const N: usize = 4;
        let data: [&[u8]; N] = [&[], &[], &[], &[]];
        test_parallel_vs_sequential(data, "empty inputs");
    }

    #[test]
    fn test_small_inputs() {
        const N: usize = 3;
        let data: [&[u8]; N] = [b"Hello... World!", b"Rust is awesome", b"Parallel hash!!"];
        test_parallel_vs_sequential(data, "small inputs");
    }

    #[test]
    fn test_multi_block() {
        const N: usize = 4;
        let target_len = RATE_AS_U8 * 2 + 10;
        let data_vecs = generate_random_data::<N>(target_len, 42);
        let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());

        test_parallel_vs_sequential(data, "multi-block inputs");
    }

    #[test]
    fn test_various_sizes() {
        const N: usize = 3;
        let sizes = [
            1,
            RATE_AS_U8 - 7,
            RATE_AS_U8,
            RATE_AS_U8 + 5,
            RATE_AS_U8 * 2 - 3,
            RATE_AS_U8 * 2,
        ];

        for &size in &sizes {
            let data_vecs = generate_random_data::<N>(size, 123);
            let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());
            test_parallel_vs_sequential(data, &format!("size {size}"));
        }
    }

    #[test]
    fn test_incremental_updates() {
        const N: usize = 3;
        let data_vecs = generate_random_data::<N>(RATE_AS_U8 * 3 + 7, 7);
        let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());

        // The cut sets are chosen to exercise these paths through `update`:
        // - topping up a partial buffer without filling it;
        // - filling the buffer exactly;
        // - filling it and continuing into whole blocks;
        // - block-aligned chunks;
        // - empty updates.
        let cut_sets: &[&[usize]] = &[
            &[1],
            &[5, 10, 15],
            &[RATE_AS_U8 - 1, RATE_AS_U8],
            &[3, 2 * RATE_AS_U8 + 1],
            &[RATE_AS_U8, 2 * RATE_AS_U8],
            &[0, 0, 4, 4],
        ];
        for cuts in cut_sets {
            test_incremental_vs_sequential(data, cuts, &format!("cuts {cuts:?}"));
        }
    }

    #[test]
    fn test_finalize_into_reset() {
        const N: usize = 2;
        let data_vecs = generate_random_data::<N>(RATE_AS_U8 + 3, 11);
        let data: [&[u8]; N] = array::from_fn(|i| data_vecs[i].as_slice());
        let stale: [&[u8]; N] = [&b"stale input"[..]; N];

        let mut hasher = VisionHasherMultiDigest::<N>::new();
        // Leave unrelated input buffered, so any state surviving the reset would corrupt
        // the second set of digests.
        hasher.update(stale);
        let mut first = [MaybeUninit::uninit(); N];
        hasher.finalize_into_reset(&mut first);
        assert_matches_sequential(&assume_all_init(first), stale, "input before reset");

        hasher.update(data);
        let mut second = [MaybeUninit::uninit(); N];
        hasher.finalize_into(&mut second);
        assert_matches_sequential(&assume_all_init(second), data, "input after reset");
    }
}