binius_hash/vision/digest.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{array, mem::MaybeUninit};

use binius_field::{
	linear_transformation::Transformation, make_aes_to_binary_packed_transformer,
	make_binary_to_aes_packed_transformer, underlier::WithUnderlier, AesToBinaryTransformation,
	BinaryField8b, BinaryToAesTransformation, ByteSlicedAES32x32b, PackedAESBinaryField8x32b,
	PackedBinaryField8x32b, PackedExtensionIndexable, PackedField, PackedFieldIndexable,
};
use digest::{
	consts::{U32, U96},
	core_api::BlockSizeUser,
	FixedOutput, FixedOutputReset, HashMarker, OutputSizeUser, Reset, Update,
};
use lazy_static::lazy_static;
use stackalloc::helpers::slice_assume_init_mut;

use super::permutation::{HASHES_PER_BYTE_SLICED_PERMUTATION, PERMUTATION};
use crate::{
	multi_digest::{MultiDigest, ParallelMulidigestImpl},
	permutation::Permutation,
};

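// Sponge rate: 16 32-bit field elements, i.e. 64 bytes absorbed per permutation call.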
const RATE_AS_U32: usize = 16;
const RATE_AS_U8: usize = RATE_AS_U32 * std::mem::size_of::<u32>();

const PADDING_START: u8 = 0x80;
const PADDING_END: u8 = 0x01;

lazy_static! {
	static ref TRANS_AES_TO_CANONICAL: AesToBinaryTransformation<PackedAESBinaryField8x32b, PackedBinaryField8x32b> =
		make_aes_to_binary_packed_transformer::<PackedAESBinaryField8x32b, PackedBinaryField8x32b>();
	static ref TRANS_CANONICAL_TO_AES: BinaryToAesTransformation<PackedBinaryField8x32b, PackedAESBinaryField8x32b> =
		make_binary_to_aes_packed_transformer::<PackedBinaryField8x32b, PackedAESBinaryField8x32b>();

	// Padding block for the case when the input is a multiple of the rate.
	static ref PADDING_BLOCK: [u8; RATE_AS_U8] = {
		let mut block = [0; RATE_AS_U8];
		block[0] = PADDING_START;
		block[RATE_AS_U8 - 1] |= PADDING_END;
		block
	};
}

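/// Single-instance Vision-32 hasher exposing the RustCrypto [`digest`] traits.
///
/// The sponge keeps a three-element packed state: `state[0..2]` form the 64-byte
/// rate that input blocks are absorbed into, and `state[2]` is the capacity; the
/// 32-byte digest is read from `state[0]` after the final permutation.
///
/// A minimal usage sketch (mirroring the tests below; assumes the blanket
/// [`digest::Digest`] impl from the `digest` crate is in scope):
///
/// ```ignore
/// use digest::Digest;
///
/// let mut hasher = VisionHasherDigest::default();
/// hasher.update(b"hello world");
/// let digest = hasher.finalize(); // 32 bytes
/// ```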
#[derive(Clone)]
pub struct VisionHasherDigest {
	// The hashed state
	state: [PackedAESBinaryField8x32b; 3],
	buffer: [u8; RATE_AS_U8],
	filled_bytes: usize,
}

impl Default for VisionHasherDigest {
	fn default() -> Self {
		Self {
			state: [PackedAESBinaryField8x32b::zero(); 3],
			buffer: [0; RATE_AS_U8],
			filled_bytes: 0,
		}
	}
}

impl VisionHasherDigest {
	fn permute(state: &mut [PackedAESBinaryField8x32b; 3], data: &[u8]) {
		debug_assert_eq!(data.len(), RATE_AS_U8);

		let mut data_packed = [PackedBinaryField8x32b::zero(); 2];
		for (i, value_32) in WithUnderlier::to_underliers_ref_mut(
			PackedBinaryField8x32b::unpack_scalars_mut(&mut data_packed),
		)
		.iter_mut()
		.enumerate()
		{
			*value_32 =
				u32::from_le_bytes(data[i * 4..i * 4 + 4].try_into().expect("chunk is 4 bytes"));
		}

		for i in 0..2 {
			state[i] = TRANS_CANONICAL_TO_AES.transform(&data_packed[i]);
		}

		PERMUTATION.permute_mut(state);
	}

	fn finalize_into(&mut self, out: &mut digest::Output<Self>) {
		if self.filled_bytes != 0 {
			fill_padding(&mut self.buffer[self.filled_bytes..]);
			Self::permute(&mut self.state, &self.buffer);
		} else {
			Self::permute(&mut self.state, &*PADDING_BLOCK);
		}

		let canonical_tower: PackedBinaryField8x32b =
			TRANS_AES_TO_CANONICAL.transform(&self.state[0]);
		out.copy_from_slice(BinaryField8b::to_underliers_ref(
			PackedBinaryField8x32b::unpack_base_scalars(std::slice::from_ref(&canonical_tower)),
		));
	}
}

impl HashMarker for VisionHasherDigest {}

impl Update for VisionHasherDigest {
	fn update(&mut self, mut data: &[u8]) {
		if self.filled_bytes != 0 {
			let to_copy = std::cmp::min(data.len(), RATE_AS_U8 - self.filled_bytes);
			self.buffer[self.filled_bytes..self.filled_bytes + to_copy]
				.copy_from_slice(&data[..to_copy]);
			data = &data[to_copy..];
			self.filled_bytes += to_copy;

			if self.filled_bytes == RATE_AS_U8 {
				Self::permute(&mut self.state, &self.buffer);
				self.filled_bytes = 0;
			}
		}

		let mut chunks = data.chunks_exact(RATE_AS_U8);
		for chunk in &mut chunks {
			Self::permute(&mut self.state, chunk);
		}

		let remaining = chunks.remainder();
		if !remaining.is_empty() {
			self.buffer[..remaining.len()].copy_from_slice(remaining);
			self.filled_bytes = remaining.len();
		}
	}
}

impl OutputSizeUser for VisionHasherDigest {
	type OutputSize = U32;
}

impl BlockSizeUser for VisionHasherDigest {
	type BlockSize = U96;
}

impl FixedOutput for VisionHasherDigest {
	fn finalize_into(mut self, out: &mut digest::Output<Self>) {
		Self::finalize_into(&mut self, out);
	}
}

impl Reset for VisionHasherDigest {
	fn reset(&mut self) {
		bytemuck::fill_zeroes(&mut self.state);
		bytemuck::fill_zeroes(&mut self.buffer);
		self.filled_bytes = 0;
	}
}

impl FixedOutputReset for VisionHasherDigest {
	fn finalize_into_reset(&mut self, out: &mut digest::Output<Self>) {
		Self::finalize_into(self, out);
		Reset::reset(self);
	}
}

/// Fill `data` with the Keccak-style padding used here: a leading `0x80` byte, zeroes, and `0x01` ORed into the last byte.
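///
/// For illustration (a sketch, assuming the 64-byte rate used in this module):
/// a 4-byte tail `m0..m3` pads to `[m0, m1, m2, m3, 0x80, 0x00, ..., 0x00, 0x01]`,
/// while an input that is an exact multiple of the rate absorbs the all-padding
/// block `PADDING_BLOCK` (`0x80`, zeroes, `0x01`).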
#[inline(always)]
fn fill_padding(data: &mut [u8]) {
	debug_assert!(!data.is_empty() && data.len() <= RATE_AS_U8);

	data.fill(0);
	data[0] |= PADDING_START;
	data[data.len() - 1] |= PADDING_END;
}

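/// Byte-sliced Vision-32 hasher that absorbs `HASHES_PER_BYTE_SLICED_PERMUTATION`
/// (32) independent messages at once, sharing a single permutation call per
/// 64-byte block.
///
/// A minimal sketch of the [`MultiDigest`] flow (mirroring the consistency tests
/// below; the lane data is illustrative and the `MultiDigest` trait must be in
/// scope):
///
/// ```ignore
/// use std::{array, mem::MaybeUninit};
///
/// // One equally sized message per lane.
/// let messages = [[0u8; 64]; HASHES_PER_BYTE_SLICED_PERMUTATION];
///
/// let mut hasher = VisionHasherDigestByteSliced::default();
/// hasher.update(array::from_fn(|i| messages[i].as_slice()));
///
/// let mut out = [MaybeUninit::uninit(); HASHES_PER_BYTE_SLICED_PERMUTATION];
/// hasher.finalize_into(&mut out);
/// // `finalize_into` initializes every element of `out`.
/// let digests = out.map(|d| unsafe { d.assume_init() });
/// ```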
#[derive(Clone)]
pub struct VisionHasherDigestByteSliced {
	// The hashed state
	state: [ByteSlicedAES32x32b; 24],
	// One partially filled rate block buffered per parallel hash
	buffer: [[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
	filled_bytes: usize,
}

impl Default for VisionHasherDigestByteSliced {
	fn default() -> Self {
		Self {
			state: [ByteSlicedAES32x32b::zero(); 24],
			buffer: [[0; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
			filled_bytes: 0,
		}
	}
}

impl VisionHasherDigestByteSliced {
	fn permute(
		state: &mut [ByteSlicedAES32x32b; 24],
		data: [&[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		for row in &data {
			debug_assert_eq!(row.len(), RATE_AS_U8);
		}

		// The 16-word rate occupies `state[0..16]`; it is filled in two groups of
		// 8 byte-sliced words, one group per 32-byte half of each input row.
		for state_element_index in 0..2 {
			// Byte offset of this half of the rate within each 64-byte row.
			let data_offset = state_element_index * HASHES_PER_BYTE_SLICED_PERMUTATION;

			for (i, state_element) in state[state_element_index * 8..state_element_index * 8 + 8]
				.iter_mut()
				.enumerate()
			{
				// Gather 32-bit word `i` of this half from every lane (lane index
				// `j * 8 + k`) and byte-slice the result.
				let ordinary_range_data: [PackedAESBinaryField8x32b; 4] = array::from_fn(|j| {
					let canonical = PackedBinaryField8x32b::from_fn(|k| {
						u32::from_le_bytes(
							(data[j * 8 + k][data_offset + 4 * i..data_offset + 4 * i + 4])
								.try_into()
								.expect("chunk is 4 bytes"),
						)
						.into()
					});

					TRANS_CANONICAL_TO_AES.transform(&canonical)
				});

				*state_element = ByteSlicedAES32x32b::transpose_from(&ordinary_range_data);
			}
		}

		PERMUTATION.permute_mut(state);
	}

	fn finalize(
		&mut self,
		out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>;
			     HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		if self.filled_bytes > 0 {
			for row in 0..HASHES_PER_BYTE_SLICED_PERMUTATION {
				fill_padding(&mut self.buffer[row][self.filled_bytes..]);
			}

			Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
		} else {
			Self::permute(&mut self.state, array::from_fn(|_| &*PADDING_BLOCK));
		}

		// TODO: Use transposition function here as soon as it is merged in.
		let out: &mut [digest::Output<VisionHasherDigest>; HASHES_PER_BYTE_SLICED_PERMUTATION] =
			unsafe { slice_assume_init_mut(out) }
				.try_into()
				.expect("array is 32 elements");
		for (i, state_data) in self.state[0..8].iter().enumerate() {
			let mut transposed_aes = Default::default();
			state_data.transpose_to(&mut transposed_aes);

			for (j, transposed_aes) in transposed_aes.iter().enumerate() {
				let transposed_canonical: PackedBinaryField8x32b =
					TRANS_AES_TO_CANONICAL.transform(transposed_aes);
				for (k, scalar) in transposed_canonical.iter().enumerate() {
					out[j * 8 + k][i * 4..i * 4 + 4]
						.copy_from_slice(&scalar.to_underlier().to_le_bytes());
				}
			}
		}
	}
}

impl MultiDigest<HASHES_PER_BYTE_SLICED_PERMUTATION> for VisionHasherDigestByteSliced {
	type Digest = VisionHasherDigest;

	fn new() -> Self {
		Self::default()
	}

	fn update(&mut self, data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION]) {
		for row in 1..HASHES_PER_BYTE_SLICED_PERMUTATION {
			debug_assert_eq!(data[row].len(), data[0].len());
		}

		let mut offset = if self.filled_bytes > 0 {
			let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
			for (row_i, row) in data
				.iter()
				.enumerate()
				.take(HASHES_PER_BYTE_SLICED_PERMUTATION)
			{
				self.buffer[row_i][self.filled_bytes..self.filled_bytes + to_copy]
					.copy_from_slice(&row[..to_copy]);
			}

			self.filled_bytes += to_copy;

			if self.filled_bytes == RATE_AS_U8 {
				Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
				self.filled_bytes = 0;
			}

			to_copy
		} else {
			0
		};

		while offset + RATE_AS_U8 <= data[0].len() {
			let chunk = array::from_fn(|i| {
				(&data[i][offset..offset + RATE_AS_U8])
					.try_into()
					.expect("chunk is RATE_AS_U8 bytes")
			});
			Self::permute(&mut self.state, chunk);
			offset += RATE_AS_U8;
		}

		if offset < data[0].len() {
			for (row_i, row) in data
				.iter()
				.enumerate()
				.take(HASHES_PER_BYTE_SLICED_PERMUTATION)
			{
				self.buffer[row_i][..row.len() - offset].copy_from_slice(&row[offset..]);
			}

			self.filled_bytes = data[0].len() - offset;
		}
	}

	fn finalize_into(
		mut self,
		out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		self.finalize(out);
	}

	fn finalize_into_reset(
		&mut self,
		out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		self.finalize(out);
		self.reset();
	}

	fn reset(&mut self) {
		bytemuck::fill_zeroes(&mut self.state);
		self.filled_bytes = 0;
	}

	fn digest(
		data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
		out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		let mut digest = Self::default();
		digest.update(data);
		digest.finalize_into(out);
	}
}

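/// Parallel digest over 32 lanes: wraps [`VisionHasherDigestByteSliced`] in the
/// crate's generic multi-digest adapter so callers can hash many messages with
/// one permutation call per block.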
pub type Vision32ParallelDigest =
	ParallelMulidigestImpl<VisionHasherDigestByteSliced, HASHES_PER_BYTE_SLICED_PERMUTATION>;

#[cfg(test)]
mod tests {
	use std::{array, mem::MaybeUninit};

	use digest::Digest;
	use hex_literal::hex;

	use super::{
		MultiDigest, VisionHasherDigest, VisionHasherDigestByteSliced,
		HASHES_PER_BYTE_SLICED_PERMUTATION,
	};

	#[test]
	fn test_simple_hash() {
		let mut hasher = VisionHasherDigest::default();
		let data = [0xde, 0xad, 0xbe, 0xef];
		hasher.update(data);
		let out = hasher.finalize();
		// This hash was obtained from a modified Python implementation using the Keccak padding scheme and the changed MDS matrix.
		let expected = &hex!("8ed389809fabe91cead4786eb08e2d32647a9ac69143040de500e4465c72f173");
		assert_eq!(expected, &*out);
	}

	#[test]
	fn test_multi_block_aligned() {
		let mut hasher = VisionHasherDigest::default();
		let input = "One part of the mysterious existence of Captain Nemo had been unveiled and, if his identity had not been recognised, at least, the nations united against him were no longer hunting a chimerical creature, but a man who had vowed a deadly hatred against them";
		hasher.update(input.as_bytes());
		let out = hasher.finalize();

		let expected = &hex!("b615664d0249149b5655a86919169f0fd4b44fec83d4c43e4f1f124c3f9a82c3");
		assert_eq!(expected, &*out);

		let mut hasher = VisionHasherDigest::default();
		let input_as_b = input.as_bytes();
		hasher.update(&input_as_b[0..63]);
		hasher.update(&input_as_b[63..128]);
		hasher.update(&input_as_b[128..163]);
		hasher.update(&input_as_b[163..]);

		assert_eq!(expected, &*hasher.finalize());
	}

	#[test]
	fn test_multi_block_unaligned() {
		let mut hasher = VisionHasherDigest::default();
		let input = "You can prove anything you want by coldly logical reason--if you pick the proper postulates.";
		hasher.update(input.as_bytes());

		let expected = &hex!("0aa2879dcac953550ebe5d9da2a91d3c0356feca9044acf4edca87b28d9959e1");
		let out = hasher.finalize();
		assert_eq!(expected, &*out);
	}

	fn check_multihash_consistency(chunks: &[[&[u8]; 32]]) {
		let mut scalar_digests = array::from_fn::<_, 32, _>(|_| VisionHasherDigest::default());
		let mut multidigest = VisionHasherDigestByteSliced::default();

		for chunk in chunks {
			for (scalar_digest, data) in scalar_digests.iter_mut().zip(chunk.iter()) {
				scalar_digest.update(data);
			}

			multidigest.update(*chunk);
		}

		let scalar_digests = scalar_digests.map(|d| d.finalize());
		let mut output = [MaybeUninit::uninit(); 32];
		multidigest.finalize_into(&mut output);
		// Compare every one of the 32 lanes against its scalar reference.
		let output = unsafe { array::from_fn::<_, 32, _>(|i| output[i].assume_init()) };

		for i in 0..32 {
			assert_eq!(&*scalar_digests[i], &*output[i]);
		}
	}

	#[test]
	fn test_multihash_consistency_small_data() {
		let data = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
			[i as u8, (i + 1) as _, (i + 2) as _, (i + 3) as _]
		});

		check_multihash_consistency(&[array::from_fn::<
			_,
			{ HASHES_PER_BYTE_SLICED_PERMUTATION },
			_,
		>(|i| &data[i][..])]);
	}

	#[test]
	fn test_multihash_consistency_small_rate() {
		let data =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 64]);

		check_multihash_consistency(&[array::from_fn::<
			_,
			{ HASHES_PER_BYTE_SLICED_PERMUTATION },
			_,
		>(|i| &data[i][..])]);
	}

	#[test]
	fn test_multihash_consistency_large_rate() {
		let data =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 1024]);

		check_multihash_consistency(&[array::from_fn::<
			_,
			{ HASHES_PER_BYTE_SLICED_PERMUTATION },
			_,
		>(|i| &data[i][..])]);
	}

	#[test]
	fn test_multihash_consistency_several_chunks() {
		let data_0 =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 48]);
		let data_1 =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [(i + 1) as u8; 64]);
		let data_2 = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
			[(i + 2) as u8; 128]
		});

		check_multihash_consistency(&[
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_0[i][..]),
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_1[i][..]),
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_2[i][..]),
		]);
	}
}