1use std::{array, mem::MaybeUninit};
4
5use binius_field::{
6 AesToBinaryTransformation, BinaryField8b, BinaryToAesTransformation, ByteSlicedAES32x32b,
7 PackedAESBinaryField8x32b, PackedBinaryField8x32b, PackedExtensionIndexable, PackedField,
8 PackedFieldIndexable, linear_transformation::Transformation,
9 make_aes_to_binary_packed_transformer, make_binary_to_aes_packed_transformer,
10 underlier::WithUnderlier,
11};
12use binius_utils::mem::slice_assume_init_mut;
13use digest::{
14 FixedOutput, FixedOutputReset, HashMarker, OutputSizeUser, Reset, Update,
15 consts::{U32, U96},
16 core_api::BlockSizeUser,
17};
18use lazy_static::lazy_static;
19
20use super::permutation::{HASHES_PER_BYTE_SLICED_PERMUTATION, PERMUTATION};
21use crate::{
22 multi_digest::{MultiDigest, ParallelMultidigestImpl},
23 permutation::Permutation,
24};
25
/// Number of 32-bit words consumed by a single permutation call.
const RATE_AS_U32: usize = 16;
/// The same rate expressed in bytes; input is absorbed in blocks of this size.
const RATE_AS_U8: usize = std::mem::size_of::<[u32; RATE_AS_U32]>();

/// Marker placed in the first byte of the padding.
const PADDING_START: u8 = 0b1000_0000;
/// Marker OR-ed into the last byte of the padding.
const PADDING_END: u8 = 0b0000_0001;
31
32lazy_static! {
33 static ref TRANS_AES_TO_CANONICAL: AesToBinaryTransformation<PackedAESBinaryField8x32b, PackedBinaryField8x32b> =
34 make_aes_to_binary_packed_transformer::<PackedAESBinaryField8x32b, PackedBinaryField8x32b>();
35 static ref TRANS_CANONICAL_TO_AES: BinaryToAesTransformation<PackedBinaryField8x32b, PackedAESBinaryField8x32b> =
36 make_binary_to_aes_packed_transformer::<PackedBinaryField8x32b, PackedAESBinaryField8x32b>();
37
38
39 static ref PADDING_BLOCK: [u8; RATE_AS_U8] = {
41 let mut block = [0; RATE_AS_U8];
42 block[0] = PADDING_START;
43 block[RATE_AS_U8 - 1] |= PADDING_END;
44 block
45 };
46}
47
/// Incremental hasher that absorbs input in `RATE_AS_U8`-byte blocks through `PERMUTATION`
/// and produces a 32-byte digest read from the first state element.
#[derive(Clone)]
pub struct VisionHasherDigest {
    // Permutation state; the first two elements are overwritten by every absorbed block.
    state: [PackedAESBinaryField8x32b; 3],
    // Holds a partial input block until a full `RATE_AS_U8` bytes are available.
    buffer: [u8; RATE_AS_U8],
    // Number of valid bytes at the start of `buffer`; stays below `RATE_AS_U8` between calls.
    filled_bytes: usize,
}
55
56impl Default for VisionHasherDigest {
57 fn default() -> Self {
58 Self {
59 state: [PackedAESBinaryField8x32b::zero(); 3],
60 buffer: [0; RATE_AS_U8],
61 filled_bytes: 0,
62 }
63 }
64}
65
66impl VisionHasherDigest {
67 fn permute(state: &mut [PackedAESBinaryField8x32b; 3], data: &[u8]) {
68 debug_assert_eq!(data.len(), RATE_AS_U8);
69
70 let mut data_packed = [PackedBinaryField8x32b::zero(); 2];
71 for (i, value_32) in WithUnderlier::to_underliers_ref_mut(
72 PackedBinaryField8x32b::unpack_scalars_mut(&mut data_packed),
73 )
74 .iter_mut()
75 .enumerate()
76 {
77 *value_32 =
78 u32::from_le_bytes(data[i * 4..i * 4 + 4].try_into().expect("chunk is 4 bytes"));
79 }
80
81 for i in 0..2 {
82 state[i] = TRANS_CANONICAL_TO_AES.transform(&data_packed[i]);
83 }
84
85 PERMUTATION.permute_mut(state);
86 }
87
88 fn finalize_into(&mut self, out: &mut digest::Output<Self>) {
89 if self.filled_bytes != 0 {
90 fill_padding(&mut self.buffer[self.filled_bytes..]);
91 Self::permute(&mut self.state, &self.buffer);
92 } else {
93 Self::permute(&mut self.state, &*PADDING_BLOCK);
94 }
95
96 let canonical_tower: PackedBinaryField8x32b =
97 TRANS_AES_TO_CANONICAL.transform(&self.state[0]);
98 out.copy_from_slice(BinaryField8b::to_underliers_ref(
99 PackedBinaryField8x32b::unpack_base_scalars(std::slice::from_ref(&canonical_tower)),
100 ));
101 }
102}
103
// Empty marker impl from the `digest` crate; it adds no methods of its own.
impl HashMarker for VisionHasherDigest {}
105
106impl Update for VisionHasherDigest {
107 fn update(&mut self, mut data: &[u8]) {
108 if self.filled_bytes != 0 {
109 let to_copy = std::cmp::min(data.len(), RATE_AS_U8 - self.filled_bytes);
110 self.buffer[self.filled_bytes..self.filled_bytes + to_copy]
111 .copy_from_slice(&data[..to_copy]);
112 data = &data[to_copy..];
113 self.filled_bytes += to_copy;
114
115 if self.filled_bytes == RATE_AS_U8 {
116 Self::permute(&mut self.state, &self.buffer);
117 self.filled_bytes = 0;
118 }
119 }
120
121 let mut chunks = data.chunks_exact(RATE_AS_U8);
122 for chunk in &mut chunks {
123 Self::permute(&mut self.state, chunk);
124 }
125
126 let remaining = chunks.remainder();
127 if !remaining.is_empty() {
128 self.buffer[..remaining.len()].copy_from_slice(remaining);
129 self.filled_bytes = remaining.len();
130 }
131 }
132}
133
impl OutputSizeUser for VisionHasherDigest {
    // The digest is 32 bytes long.
    type OutputSize = U32;
}
137
impl BlockSizeUser for VisionHasherDigest {
    // NOTE(review): input is absorbed in `RATE_AS_U8` (64) byte blocks, while this declares
    // 96 bytes, which looks like the size of the three-element state. Confirm which one is
    // intended before relying on `BlockSize`.
    type BlockSize = U96;
}
141
142impl FixedOutput for VisionHasherDigest {
143 fn finalize_into(mut self, out: &mut digest::Output<Self>) {
144 Self::finalize_into(&mut self, out);
145 }
146}
147
148impl Reset for VisionHasherDigest {
149 fn reset(&mut self) {
150 bytemuck::fill_zeroes(&mut self.state);
151 bytemuck::fill_zeroes(&mut self.buffer);
152 self.filled_bytes = 0;
153 }
154}
155
156impl FixedOutputReset for VisionHasherDigest {
157 fn finalize_into_reset(&mut self, out: &mut digest::Output<Self>) {
158 Self::finalize_into(self, out);
159 Reset::reset(self);
160 }
161}
162
163#[inline(always)]
165fn fill_padding(data: &mut [u8]) {
166 debug_assert!(!data.is_empty() && data.len() <= RATE_AS_U8);
167
168 data.fill(0);
169 data[0] |= PADDING_START;
170 data[data.len() - 1] |= PADDING_END;
171}
172
/// Hasher that processes `HASHES_PER_BYTE_SLICED_PERMUTATION` independent inputs of equal
/// length at once, using byte-sliced state elements.
#[derive(Clone)]
pub struct VisionHasherDigestByteSliced {
    // Byte-sliced permutation state shared by all hashes; the first 16 elements are
    // overwritten by every absorbed block.
    state: [ByteSlicedAES32x32b; 24],
    // One partial-block buffer per hash.
    buffer: [[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
    // Number of valid bytes at the start of every row of `buffer` (the same for all rows).
    filled_bytes: usize,
}
181
182impl Default for VisionHasherDigestByteSliced {
183 fn default() -> Self {
184 Self {
185 state: [ByteSlicedAES32x32b::zero(); 24],
186 buffer: [[0; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
187 filled_bytes: 0,
188 }
189 }
190}
191
192impl VisionHasherDigestByteSliced {
193 fn permute(
194 state: &mut [ByteSlicedAES32x32b; 24],
195 data: [&[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
196 ) {
197 for row in &data {
198 debug_assert_eq!(row.len(), RATE_AS_U8);
199 }
200
201 for state_element_index in 0..2 {
202 let data_offset = state_element_index * HASHES_PER_BYTE_SLICED_PERMUTATION;
203
204 for (i, state_element) in state[state_element_index * 8..state_element_index * 8 + 8]
205 .iter_mut()
206 .enumerate()
207 {
208 let ordinary_range_data: [PackedAESBinaryField8x32b; 4] = array::from_fn(|j| {
209 let canonical = PackedBinaryField8x32b::from_fn(|k| {
210 u32::from_le_bytes(
211 (data[j * 8 + k][data_offset + 4 * i..data_offset + 4 * i + 4])
212 .try_into()
213 .expect("chunk is 4 bytes"),
214 )
215 .into()
216 });
217
218 TRANS_CANONICAL_TO_AES.transform(&canonical)
219 });
220
221 *state_element = ByteSlicedAES32x32b::transpose_from(&ordinary_range_data);
222 }
223 }
224
225 PERMUTATION.permute_mut(state);
226 }
227
228 fn finalize(
229 &mut self,
230 out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>;
231 HASHES_PER_BYTE_SLICED_PERMUTATION],
232 ) {
233 if self.filled_bytes > 0 {
234 for row in 0..HASHES_PER_BYTE_SLICED_PERMUTATION {
235 fill_padding(&mut self.buffer[row][self.filled_bytes..]);
236 }
237
238 Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
239 } else {
240 Self::permute(&mut self.state, array::from_fn(|_| &*PADDING_BLOCK));
241 }
242
243 let out: &mut [digest::Output<VisionHasherDigest>; HASHES_PER_BYTE_SLICED_PERMUTATION] =
245 unsafe { slice_assume_init_mut(out) }
246 .try_into()
247 .expect("array is 32 elements");
248 for (i, state_data) in self.state[0..8].iter().enumerate() {
249 let mut transposed_aes = Default::default();
250 state_data.transpose_to(&mut transposed_aes);
251
252 for (j, transposed_aes) in transposed_aes.iter().enumerate() {
253 let transposed_canonical: PackedBinaryField8x32b =
254 TRANS_AES_TO_CANONICAL.transform(transposed_aes);
255 for (k, scalar) in transposed_canonical.iter().enumerate() {
256 out[j * 8 + k][i * 4..i * 4 + 4]
257 .copy_from_slice(&scalar.to_underlier().to_le_bytes());
258 }
259 }
260 }
261 }
262}
263
264impl MultiDigest<HASHES_PER_BYTE_SLICED_PERMUTATION> for VisionHasherDigestByteSliced {
265 type Digest = VisionHasherDigest;
266
267 fn new() -> Self {
268 Self::default()
269 }
270
271 fn update(&mut self, data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION]) {
272 for row in 1..HASHES_PER_BYTE_SLICED_PERMUTATION {
273 debug_assert_eq!(data[row].len(), data[0].len());
274 }
275
276 let mut offset = if self.filled_bytes > 0 {
277 let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
278 for (row_i, row) in data
279 .iter()
280 .enumerate()
281 .take(HASHES_PER_BYTE_SLICED_PERMUTATION)
282 {
283 self.buffer[row_i][self.filled_bytes..self.filled_bytes + to_copy]
284 .copy_from_slice(&row[..to_copy]);
285 }
286
287 self.filled_bytes += to_copy;
288
289 if self.filled_bytes == RATE_AS_U8 {
290 Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
291 self.filled_bytes = 0;
292 }
293
294 to_copy
295 } else {
296 0
297 };
298
299 while offset + RATE_AS_U8 <= data[0].len() {
300 let chunk = array::from_fn(|i| {
301 (&data[i][offset..offset + RATE_AS_U8])
302 .try_into()
303 .expect("array is 32 bytes")
304 });
305 Self::permute(&mut self.state, chunk);
306 offset += RATE_AS_U8;
307 }
308
309 if offset < data[0].len() {
310 for (row_i, row) in data
311 .iter()
312 .enumerate()
313 .take(HASHES_PER_BYTE_SLICED_PERMUTATION)
314 {
315 self.buffer[row_i][..row.len() - offset].copy_from_slice(&row[offset..]);
316 }
317
318 self.filled_bytes = data[0].len() - offset;
319 }
320 }
321
322 fn finalize_into(
323 mut self,
324 out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
325 ) {
326 self.finalize(out);
327 }
328
329 fn finalize_into_reset(
330 &mut self,
331 out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
332 ) {
333 self.finalize(out);
334 self.reset();
335 }
336
337 fn reset(&mut self) {
338 bytemuck::fill_zeroes(&mut self.state);
339 self.filled_bytes = 0;
340 }
341
342 fn digest(
343 data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
344 out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
345 ) {
346 let mut digest = Self::default();
347 digest.update(data);
348 digest.finalize_into(out);
349 }
350}
351
/// `ParallelMultidigestImpl` instantiated with the byte-sliced hasher, which processes
/// `HASHES_PER_BYTE_SLICED_PERMUTATION` inputs per permutation call.
pub type Vision32ParallelDigest =
    ParallelMultidigestImpl<VisionHasherDigestByteSliced, HASHES_PER_BYTE_SLICED_PERMUTATION>;
354
#[cfg(test)]
mod tests {
    use std::{array, mem::MaybeUninit};

    use digest::Digest;
    use hex_literal::hex;

    use super::{
        HASHES_PER_BYTE_SLICED_PERMUTATION, MultiDigest, VisionHasherDigest,
        VisionHasherDigestByteSliced,
    };

    /// Hashes a short input that fits into a single partial block.
    #[test]
    fn test_simple_hash() {
        let mut hasher = VisionHasherDigest::default();
        let data = [0xde, 0xad, 0xbe, 0xef];
        hasher.update(data);
        let out = hasher.finalize();
        let expected = &hex!("8ed389809fabe91cead4786eb08e2d32647a9ac69143040de500e4465c72f173");
        assert_eq!(expected, &*out);
    }

    /// Hashes a multi-block input in one call and again split across several `update`
    /// calls at boundaries that do not line up with the block size.
    #[test]
    fn test_multi_block_aligned() {
        let mut hasher = VisionHasherDigest::default();
        let input = "One part of the mysterious existence of Captain Nemo had been unveiled and, if his identity had not been recognised, at least, the nations united against him were no longer hunting a chimerical creature, but a man who had vowed a deadly hatred against them";
        hasher.update(input.as_bytes());
        let out = hasher.finalize();

        let expected = &hex!("b615664d0249149b5655a86919169f0fd4b44fec83d4c43e4f1f124c3f9a82c3");
        assert_eq!(expected, &*out);

        let mut hasher = VisionHasherDigest::default();
        let input_as_b = input.as_bytes();
        hasher.update(&input_as_b[0..63]);
        hasher.update(&input_as_b[63..128]);
        hasher.update(&input_as_b[128..163]);
        hasher.update(&input_as_b[163..]);

        assert_eq!(expected, &*hasher.finalize());
    }

    /// Hashes an input whose length is not a multiple of the block size.
    #[test]
    fn test_multi_block_unaligned() {
        let mut hasher = VisionHasherDigest::default();
        let input = "You can prove anything you want by coldly logical reason--if you pick the proper postulates.";
        hasher.update(input.as_bytes());

        let expected = &hex!("0aa2879dcac953550ebe5d9da2a91d3c0356feca9044acf4edca87b28d9959e1");
        let out = hasher.finalize();
        assert_eq!(expected, &*out);
    }

    /// Feeds the same sequence of chunks to 32 scalar hashers and to one byte-sliced
    /// hasher, then checks that every one of the 32 digests agrees.
    fn check_multihash_consistency(chunks: &[[&[u8]; 32]]) {
        let mut scalar_digests = array::from_fn::<_, 32, _>(|_| VisionHasherDigest::default());
        let mut multidigest = VisionHasherDigestByteSliced::default();

        for chunk in chunks {
            for (scalar_digest, data) in scalar_digests.iter_mut().zip(chunk.iter()) {
                scalar_digest.update(data);
            }

            multidigest.update(*chunk);
        }

        let scalar_digests = scalar_digests.map(|d| d.finalize());
        let mut output = [MaybeUninit::uninit(); 32];
        multidigest.finalize_into(&mut output);
        // SAFETY: `finalize_into` writes every one of the 32 output slots.
        let output = output.map(|slot| unsafe { slot.assume_init() });

        // Compare all digests, not only the first few.
        for (i, (expected, actual)) in scalar_digests.iter().zip(output.iter()).enumerate() {
            assert_eq!(&expected[..], &actual[..], "digest mismatch for hash {}", i);
        }
    }

    /// Inputs much shorter than one block.
    #[test]
    fn test_multihash_consistency_small_data() {
        let data = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
            [i as u8, (i + 1) as _, (i + 2) as _, (i + 3) as _]
        });

        check_multihash_consistency(&[array::from_fn::<
            _,
            { HASHES_PER_BYTE_SLICED_PERMUTATION },
            _,
        >(|i| &data[i][..])]);
    }

    /// Inputs of exactly one block (64 bytes).
    #[test]
    fn test_multihash_consistency_small_rate() {
        let data =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 64]);

        check_multihash_consistency(&[array::from_fn::<
            _,
            { HASHES_PER_BYTE_SLICED_PERMUTATION },
            _,
        >(|i| &data[i][..])]);
    }

    /// Inputs spanning many blocks.
    #[test]
    fn test_multihash_consistency_large_rate() {
        let data =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 1024]);

        check_multihash_consistency(&[array::from_fn::<
            _,
            { HASHES_PER_BYTE_SLICED_PERMUTATION },
            _,
        >(|i| &data[i][..])]);
    }

    /// Several `update` calls of 48, 64 and 128 bytes, so that buffered bytes are carried
    /// across calls and completed by later input.
    #[test]
    fn test_multihash_consistency_several_chunks() {
        let data_0 =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 48]);
        let data_1 =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [(i + 1) as u8; 64]);
        let data_2 = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
            [(i + 2) as u8; 128]
        });

        check_multihash_consistency(&[
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_0[i][..]),
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_1[i][..]),
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_2[i][..]),
        ]);
    }
}