1use std::{array, mem::MaybeUninit};
4
5use binius_field::{
6 linear_transformation::Transformation, make_aes_to_binary_packed_transformer,
7 make_binary_to_aes_packed_transformer, underlier::WithUnderlier, AesToBinaryTransformation,
8 BinaryField8b, BinaryToAesTransformation, ByteSlicedAES32x32b, PackedAESBinaryField8x32b,
9 PackedBinaryField8x32b, PackedExtensionIndexable, PackedField, PackedFieldIndexable,
10};
11use digest::{
12 consts::{U32, U96},
13 core_api::BlockSizeUser,
14 FixedOutput, FixedOutputReset, HashMarker, OutputSizeUser, Reset, Update,
15};
16use lazy_static::lazy_static;
17use stackalloc::helpers::slice_assume_init_mut;
18
19use super::permutation::{HASHES_PER_BYTE_SLICED_PERMUTATION, PERMUTATION};
20use crate::{
21 multi_digest::{MultiDigest, ParallelMulidigestImpl},
22 permutation::Permutation,
23};
24
/// Sponge rate of the Vision-32 hasher: the number of 32-bit words consumed by each
/// permutation call.
const RATE_AS_U32: usize = 16;
/// Sponge rate in bytes: `RATE_AS_U32` words of `u32::BITS / 8` bytes each.
const RATE_AS_U8: usize = RATE_AS_U32 * (u32::BITS as usize / 8);

/// First byte of the padding: a lone high bit (`0b1000_0000`).
const PADDING_START: u8 = 1 << 7;
/// Last byte of the padding: a lone low bit (`0b0000_0001`).
const PADDING_END: u8 = 1;
30
31lazy_static! {
32 static ref TRANS_AES_TO_CANONICAL: AesToBinaryTransformation<PackedAESBinaryField8x32b, PackedBinaryField8x32b> =
33 make_aes_to_binary_packed_transformer::<PackedAESBinaryField8x32b, PackedBinaryField8x32b>();
34 static ref TRANS_CANONICAL_TO_AES: BinaryToAesTransformation<PackedBinaryField8x32b, PackedAESBinaryField8x32b> =
35 make_binary_to_aes_packed_transformer::<PackedBinaryField8x32b, PackedAESBinaryField8x32b>();
36
37
38 static ref PADDING_BLOCK: [u8; RATE_AS_U8] = {
40 let mut block = [0; RATE_AS_U8];
41 block[0] = PADDING_START;
42 block[RATE_AS_U8 - 1] |= PADDING_END;
43 block
44 };
45}
46
47#[derive(Clone)]
48pub struct VisionHasherDigest {
49 state: [PackedAESBinaryField8x32b; 3],
51 buffer: [u8; RATE_AS_U8],
52 filled_bytes: usize,
53}
54
55impl Default for VisionHasherDigest {
56 fn default() -> Self {
57 Self {
58 state: [PackedAESBinaryField8x32b::zero(); 3],
59 buffer: [0; RATE_AS_U8],
60 filled_bytes: 0,
61 }
62 }
63}
64
65impl VisionHasherDigest {
66 fn permute(state: &mut [PackedAESBinaryField8x32b; 3], data: &[u8]) {
67 debug_assert_eq!(data.len(), RATE_AS_U8);
68
69 let mut data_packed = [PackedBinaryField8x32b::zero(); 2];
70 for (i, value_32) in WithUnderlier::to_underliers_ref_mut(
71 PackedBinaryField8x32b::unpack_scalars_mut(&mut data_packed),
72 )
73 .iter_mut()
74 .enumerate()
75 {
76 *value_32 =
77 u32::from_le_bytes(data[i * 4..i * 4 + 4].try_into().expect("chunk is 4 bytes"));
78 }
79
80 for i in 0..2 {
81 state[i] = TRANS_CANONICAL_TO_AES.transform(&data_packed[i]);
82 }
83
84 PERMUTATION.permute_mut(state);
85 }
86
87 fn finalize_into(&mut self, out: &mut digest::Output<Self>) {
88 if self.filled_bytes != 0 {
89 fill_padding(&mut self.buffer[self.filled_bytes..]);
90 Self::permute(&mut self.state, &self.buffer);
91 } else {
92 Self::permute(&mut self.state, &*PADDING_BLOCK);
93 }
94
95 let canonical_tower: PackedBinaryField8x32b =
96 TRANS_AES_TO_CANONICAL.transform(&self.state[0]);
97 out.copy_from_slice(BinaryField8b::to_underliers_ref(
98 PackedBinaryField8x32b::unpack_base_scalars(std::slice::from_ref(&canonical_tower)),
99 ));
100 }
101}
102
103impl HashMarker for VisionHasherDigest {}
104
105impl Update for VisionHasherDigest {
106 fn update(&mut self, mut data: &[u8]) {
107 if self.filled_bytes != 0 {
108 let to_copy = std::cmp::min(data.len(), RATE_AS_U8 - self.filled_bytes);
109 self.buffer[self.filled_bytes..self.filled_bytes + to_copy]
110 .copy_from_slice(&data[..to_copy]);
111 data = &data[to_copy..];
112 self.filled_bytes += to_copy;
113
114 if self.filled_bytes == RATE_AS_U8 {
115 Self::permute(&mut self.state, &self.buffer);
116 self.filled_bytes = 0;
117 }
118 }
119
120 let mut chunks = data.chunks_exact(RATE_AS_U8);
121 for chunk in &mut chunks {
122 Self::permute(&mut self.state, chunk);
123 }
124
125 let remaining = chunks.remainder();
126 if !remaining.is_empty() {
127 self.buffer[..remaining.len()].copy_from_slice(remaining);
128 self.filled_bytes = remaining.len();
129 }
130 }
131}
132
133impl OutputSizeUser for VisionHasherDigest {
134 type OutputSize = U32;
135}
136
137impl BlockSizeUser for VisionHasherDigest {
138 type BlockSize = U96;
139}
140
141impl FixedOutput for VisionHasherDigest {
142 fn finalize_into(mut self, out: &mut digest::Output<Self>) {
143 Self::finalize_into(&mut self, out);
144 }
145}
146
147impl Reset for VisionHasherDigest {
148 fn reset(&mut self) {
149 bytemuck::fill_zeroes(&mut self.state);
150 bytemuck::fill_zeroes(&mut self.buffer);
151 self.filled_bytes = 0;
152 }
153}
154
155impl FixedOutputReset for VisionHasherDigest {
156 fn finalize_into_reset(&mut self, out: &mut digest::Output<Self>) {
157 Self::finalize_into(self, out);
158 Reset::reset(self);
159 }
160}
161
162#[inline(always)]
164fn fill_padding(data: &mut [u8]) {
165 debug_assert!(!data.is_empty() && data.len() <= RATE_AS_U8);
166
167 data.fill(0);
168 data[0] |= PADDING_START;
169 data[data.len() - 1] |= PADDING_END;
170}
171
172#[derive(Clone)]
173pub struct VisionHasherDigestByteSliced {
174 state: [ByteSlicedAES32x32b; 24],
176 buffer: [[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
178 filled_bytes: usize,
179}
180
181impl Default for VisionHasherDigestByteSliced {
182 fn default() -> Self {
183 Self {
184 state: [ByteSlicedAES32x32b::zero(); 24],
185 buffer: [[0; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
186 filled_bytes: 0,
187 }
188 }
189}
190
191impl VisionHasherDigestByteSliced {
192 fn permute(
193 state: &mut [ByteSlicedAES32x32b; 24],
194 data: [&[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
195 ) {
196 for row in &data {
197 debug_assert_eq!(row.len(), RATE_AS_U8);
198 }
199
200 for state_element_index in 0..2 {
201 let data_offset = state_element_index * HASHES_PER_BYTE_SLICED_PERMUTATION;
202
203 for (i, state_element) in state[state_element_index * 8..state_element_index * 8 + 8]
204 .iter_mut()
205 .enumerate()
206 {
207 let ordinary_range_data: [PackedAESBinaryField8x32b; 4] = array::from_fn(|j| {
208 let canonical = PackedBinaryField8x32b::from_fn(|k| {
209 u32::from_le_bytes(
210 (data[j * 8 + k][data_offset + 4 * i..data_offset + 4 * i + 4])
211 .try_into()
212 .expect("chunk is 4 bytes"),
213 )
214 .into()
215 });
216
217 TRANS_CANONICAL_TO_AES.transform(&canonical)
218 });
219
220 *state_element = ByteSlicedAES32x32b::transpose_from(&ordinary_range_data);
221 }
222 }
223
224 PERMUTATION.permute_mut(state);
225 }
226
227 fn finalize(
228 &mut self,
229 out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>;
230 HASHES_PER_BYTE_SLICED_PERMUTATION],
231 ) {
232 if self.filled_bytes > 0 {
233 for row in 0..HASHES_PER_BYTE_SLICED_PERMUTATION {
234 fill_padding(&mut self.buffer[row][self.filled_bytes..]);
235 }
236
237 Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
238 } else {
239 Self::permute(&mut self.state, array::from_fn(|_| &*PADDING_BLOCK));
240 }
241
242 let out: &mut [digest::Output<VisionHasherDigest>; HASHES_PER_BYTE_SLICED_PERMUTATION] =
244 unsafe { slice_assume_init_mut(out) }
245 .try_into()
246 .expect("array is 32 elements");
247 for (i, state_data) in self.state[0..8].iter().enumerate() {
248 let mut transposed_aes = Default::default();
249 state_data.transpose_to(&mut transposed_aes);
250
251 for (j, transposed_aes) in transposed_aes.iter().enumerate() {
252 let transposed_canonical: PackedBinaryField8x32b =
253 TRANS_AES_TO_CANONICAL.transform(transposed_aes);
254 for (k, scalar) in transposed_canonical.iter().enumerate() {
255 out[j * 8 + k][i * 4..i * 4 + 4]
256 .copy_from_slice(&scalar.to_underlier().to_le_bytes());
257 }
258 }
259 }
260 }
261}
262
263impl MultiDigest<HASHES_PER_BYTE_SLICED_PERMUTATION> for VisionHasherDigestByteSliced {
264 type Digest = VisionHasherDigest;
265
266 fn new() -> Self {
267 Self::default()
268 }
269
270 fn update(&mut self, data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION]) {
271 for row in 1..HASHES_PER_BYTE_SLICED_PERMUTATION {
272 debug_assert_eq!(data[row].len(), data[0].len());
273 }
274
275 let mut offset = if self.filled_bytes > 0 {
276 let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
277 for (row_i, row) in data
278 .iter()
279 .enumerate()
280 .take(HASHES_PER_BYTE_SLICED_PERMUTATION)
281 {
282 self.buffer[row_i][self.filled_bytes..self.filled_bytes + to_copy]
283 .copy_from_slice(&row[..to_copy]);
284 }
285
286 self.filled_bytes += to_copy;
287
288 if self.filled_bytes == RATE_AS_U8 {
289 Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
290 self.filled_bytes = 0;
291 }
292
293 to_copy
294 } else {
295 0
296 };
297
298 while offset + RATE_AS_U8 <= data[0].len() {
299 let chunk = array::from_fn(|i| {
300 (&data[i][offset..offset + RATE_AS_U8])
301 .try_into()
302 .expect("array is 32 bytes")
303 });
304 Self::permute(&mut self.state, chunk);
305 offset += RATE_AS_U8;
306 }
307
308 if offset < data[0].len() {
309 for (row_i, row) in data
310 .iter()
311 .enumerate()
312 .take(HASHES_PER_BYTE_SLICED_PERMUTATION)
313 {
314 self.buffer[row_i][..row.len() - offset].copy_from_slice(&row[offset..]);
315 }
316
317 self.filled_bytes = data[0].len() - offset;
318 }
319 }
320
321 fn finalize_into(
322 mut self,
323 out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
324 ) {
325 self.finalize(out);
326 }
327
328 fn finalize_into_reset(
329 &mut self,
330 out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
331 ) {
332 self.finalize(out);
333 self.reset();
334 }
335
336 fn reset(&mut self) {
337 bytemuck::fill_zeroes(&mut self.state);
338 self.filled_bytes = 0;
339 }
340
341 fn digest(
342 data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
343 out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
344 ) {
345 let mut digest = Self::default();
346 digest.update(data);
347 digest.finalize_into(out);
348 }
349}
350
351pub type Vision32ParallelDigest =
352 ParallelMulidigestImpl<VisionHasherDigestByteSliced, HASHES_PER_BYTE_SLICED_PERMUTATION>;
353
#[cfg(test)]
mod tests {
    use std::{array, mem::MaybeUninit};

    use digest::Digest;
    use hex_literal::hex;

    use super::{
        MultiDigest, VisionHasherDigest, VisionHasherDigestByteSliced,
        HASHES_PER_BYTE_SLICED_PERMUTATION,
    };

    #[test]
    fn test_simple_hash() {
        let mut hasher = VisionHasherDigest::default();
        let data = [0xde, 0xad, 0xbe, 0xef];
        hasher.update(data);
        let out = hasher.finalize();
        let expected = &hex!("8ed389809fabe91cead4786eb08e2d32647a9ac69143040de500e4465c72f173");
        assert_eq!(expected, &*out);
    }

    #[test]
    fn test_multi_block_aligned() {
        let mut hasher = VisionHasherDigest::default();
        let input = "One part of the mysterious existence of Captain Nemo had been unveiled and, if his identity had not been recognised, at least, the nations united against him were no longer hunting a chimerical creature, but a man who had vowed a deadly hatred against them";
        hasher.update(input.as_bytes());
        let out = hasher.finalize();

        let expected = &hex!("b615664d0249149b5655a86919169f0fd4b44fec83d4c43e4f1f124c3f9a82c3");
        assert_eq!(expected, &*out);

        let mut hasher = VisionHasherDigest::default();
        let input_as_b = input.as_bytes();
        hasher.update(&input_as_b[0..63]);
        hasher.update(&input_as_b[63..128]);
        hasher.update(&input_as_b[128..163]);
        hasher.update(&input_as_b[163..]);

        assert_eq!(expected, &*hasher.finalize());
    }

    #[test]
    fn test_multi_block_unaligned() {
        let mut hasher = VisionHasherDigest::default();
        let input = "You can prove anything you want by coldly logical reason--if you pick the proper postulates.";
        hasher.update(input.as_bytes());

        let expected = &hex!("0aa2879dcac953550ebe5d9da2a91d3c0356feca9044acf4edca87b28d9959e1");
        let out = hasher.finalize();
        assert_eq!(expected, &*out);
    }

    /// Hashes every message with both the scalar and the byte-sliced hasher, feeding each
    /// entry of `chunks` as one `update` call, and asserts that all digests agree.
    fn check_multihash_consistency(chunks: &[[&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION]]) {
        let mut scalar_digests =
            array::from_fn::<_, HASHES_PER_BYTE_SLICED_PERMUTATION, _>(|_| {
                VisionHasherDigest::default()
            });
        let mut multidigest = VisionHasherDigestByteSliced::default();

        for chunk in chunks {
            for (scalar_digest, data) in scalar_digests.iter_mut().zip(chunk.iter()) {
                scalar_digest.update(data);
            }

            multidigest.update(*chunk);
        }

        let scalar_digests = scalar_digests.map(|d| d.finalize());
        let mut output = [MaybeUninit::uninit(); HASHES_PER_BYTE_SLICED_PERMUTATION];
        multidigest.finalize_into(&mut output);
        // SAFETY: `finalize_into` initialises every element of `output`.
        let output = output.map(|slot| unsafe { slot.assume_init() });

        // Compare every lane, not just the first few, so lane-ordering bugs are caught.
        for (i, (expected, actual)) in scalar_digests.iter().zip(output.iter()).enumerate() {
            assert_eq!(&**expected, &**actual, "digest mismatch for message {i}");
        }
    }

    #[test]
    fn test_multihash_consistency_small_data() {
        let data = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
            [i as u8, (i + 1) as _, (i + 2) as _, (i + 3) as _]
        });

        check_multihash_consistency(&[array::from_fn::<
            _,
            { HASHES_PER_BYTE_SLICED_PERMUTATION },
            _,
        >(|i| &data[i][..])]);
    }

    #[test]
    fn test_multihash_consistency_small_rate() {
        // Exactly one rate block (64 bytes) per message: a repeat expression `[x; 64]`,
        // not the two-element literal `[x, 64]`.
        let data =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 64]);

        check_multihash_consistency(&[array::from_fn::<
            _,
            { HASHES_PER_BYTE_SLICED_PERMUTATION },
            _,
        >(|i| &data[i][..])]);
    }

    #[test]
    fn test_multihash_consistency_large_rate() {
        let data =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 1024]);

        check_multihash_consistency(&[array::from_fn::<
            _,
            { HASHES_PER_BYTE_SLICED_PERMUTATION },
            _,
        >(|i| &data[i][..])]);
    }

    #[test]
    fn test_multihash_consistency_several_chunks() {
        // 48 + 64 + 128 bytes per message, so the updates straddle rate-block boundaries.
        let data_0 =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 48]);
        let data_1 =
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [(i + 1) as u8; 64]);
        let data_2 = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
            [(i + 2) as u8; 128]
        });

        check_multihash_consistency(&[
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_0[i][..]),
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_1[i][..]),
            array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_2[i][..]),
        ]);
    }
}
483}