binius_hash/vision/digest.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{array, mem::MaybeUninit};

use binius_field::{
	AesToBinaryTransformation, BinaryField8b, BinaryToAesTransformation, ByteSlicedAES32x32b,
	PackedAESBinaryField8x32b, PackedBinaryField8x32b, PackedExtensionIndexable, PackedField,
	PackedFieldIndexable, linear_transformation::Transformation,
	make_aes_to_binary_packed_transformer, make_binary_to_aes_packed_transformer,
	underlier::WithUnderlier,
};
use binius_utils::mem::slice_assume_init_mut;
use digest::{
	FixedOutput, FixedOutputReset, HashMarker, OutputSizeUser, Reset, Update,
	consts::{U32, U96},
	core_api::BlockSizeUser,
};
use lazy_static::lazy_static;

use super::permutation::{HASHES_PER_BYTE_SLICED_PERMUTATION, PERMUTATION};
use crate::{
	multi_digest::{MultiDigest, ParallelMultidigestImpl},
	permutation::Permutation,
};

const RATE_AS_U32: usize = 16;
const RATE_AS_U8: usize = RATE_AS_U32 * std::mem::size_of::<u32>();

const PADDING_START: u8 = 0x80;
const PADDING_END: u8 = 0x01;

lazy_static! {
	static ref TRANS_AES_TO_CANONICAL: AesToBinaryTransformation<PackedAESBinaryField8x32b, PackedBinaryField8x32b> =
		make_aes_to_binary_packed_transformer::<PackedAESBinaryField8x32b, PackedBinaryField8x32b>();
	static ref TRANS_CANONICAL_TO_AES: BinaryToAesTransformation<PackedBinaryField8x32b, PackedAESBinaryField8x32b> =
		make_binary_to_aes_packed_transformer::<PackedBinaryField8x32b, PackedAESBinaryField8x32b>();

	// Padding block for the case when the input is a multiple of the rate.
	static ref PADDING_BLOCK: [u8; RATE_AS_U8] = {
		let mut block = [0; RATE_AS_U8];
		block[0] = PADDING_START;
		block[RATE_AS_U8 - 1] |= PADDING_END;
		block
	};
}

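/// Single-lane Vision hasher exposing the RustCrypto [`digest`] traits.
///
/// Input is absorbed in 64-byte rate blocks into a state of 24 32-bit field
/// elements, padded with the Keccak-style `0x80 ... 0x01` scheme, and the first
/// eight state elements are returned as a 32-byte digest.
///
/// Minimal usage sketch (mirrors `test_simple_hash` below):
///
/// ```ignore
/// use digest::Digest;
///
/// let mut hasher = VisionHasherDigest::default();
/// hasher.update(b"some input bytes");
/// let digest = hasher.finalize(); // 32 bytes
/// ```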
#[derive(Clone)]
pub struct VisionHasherDigest {
	// The hash state
	state: [PackedAESBinaryField8x32b; 3],
	buffer: [u8; RATE_AS_U8],
	filled_bytes: usize,
}

impl Default for VisionHasherDigest {
	fn default() -> Self {
		Self {
			state: [PackedAESBinaryField8x32b::zero(); 3],
			buffer: [0; RATE_AS_U8],
			filled_bytes: 0,
		}
	}
}

impl VisionHasherDigest {
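	// Absorbs one 64-byte rate block: the bytes are read as little-endian u32
	// scalars in the canonical tower basis, mapped to the AES basis, written
	// over the rate portion of the state, and then the permutation is applied.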
	fn permute(state: &mut [PackedAESBinaryField8x32b; 3], data: &[u8]) {
		debug_assert_eq!(data.len(), RATE_AS_U8);

		let mut data_packed = [PackedBinaryField8x32b::zero(); 2];
		for (i, value_32) in WithUnderlier::to_underliers_ref_mut(
			PackedBinaryField8x32b::unpack_scalars_mut(&mut data_packed),
		)
		.iter_mut()
		.enumerate()
		{
			*value_32 =
				u32::from_le_bytes(data[i * 4..i * 4 + 4].try_into().expect("chunk is 4 bytes"));
		}

		for i in 0..2 {
			state[i] = TRANS_CANONICAL_TO_AES.transform(&data_packed[i]);
		}

		PERMUTATION.permute_mut(state);
	}

	fn finalize_into(&mut self, out: &mut digest::Output<Self>) {
		if self.filled_bytes != 0 {
			fill_padding(&mut self.buffer[self.filled_bytes..]);
			Self::permute(&mut self.state, &self.buffer);
		} else {
			Self::permute(&mut self.state, &*PADDING_BLOCK);
		}

		let canonical_tower: PackedBinaryField8x32b =
			TRANS_AES_TO_CANONICAL.transform(&self.state[0]);
		out.copy_from_slice(BinaryField8b::to_underliers_ref(
			PackedBinaryField8x32b::unpack_base_scalars(std::slice::from_ref(&canonical_tower)),
		));
	}
}

impl HashMarker for VisionHasherDigest {}

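// Streaming absorption: partial blocks are buffered until a full 64-byte rate
// block is available; full blocks are permuted directly from the input slice.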
impl Update for VisionHasherDigest {
	fn update(&mut self, mut data: &[u8]) {
		if self.filled_bytes != 0 {
			let to_copy = std::cmp::min(data.len(), RATE_AS_U8 - self.filled_bytes);
			self.buffer[self.filled_bytes..self.filled_bytes + to_copy]
				.copy_from_slice(&data[..to_copy]);
			data = &data[to_copy..];
			self.filled_bytes += to_copy;

			if self.filled_bytes == RATE_AS_U8 {
				Self::permute(&mut self.state, &self.buffer);
				self.filled_bytes = 0;
			}
		}

		let mut chunks = data.chunks_exact(RATE_AS_U8);
		for chunk in &mut chunks {
			Self::permute(&mut self.state, chunk);
		}

		let remaining = chunks.remainder();
		if !remaining.is_empty() {
			self.buffer[..remaining.len()].copy_from_slice(remaining);
			self.filled_bytes = remaining.len();
		}
	}
}

impl OutputSizeUser for VisionHasherDigest {
	type OutputSize = U32;
}

impl BlockSizeUser for VisionHasherDigest {
	type BlockSize = U96;
}

impl FixedOutput for VisionHasherDigest {
	fn finalize_into(mut self, out: &mut digest::Output<Self>) {
		Self::finalize_into(&mut self, out);
	}
}

impl Reset for VisionHasherDigest {
	fn reset(&mut self) {
		bytemuck::fill_zeroes(&mut self.state);
		bytemuck::fill_zeroes(&mut self.buffer);
		self.filled_bytes = 0;
	}
}

impl FixedOutputReset for VisionHasherDigest {
	fn finalize_into_reset(&mut self, out: &mut digest::Output<Self>) {
		Self::finalize_into(self, out);
		Reset::reset(self);
	}
}

/// Fill the data with the Keccak padding scheme: `0x80` in the first byte and
/// `0x01` OR-ed into the last byte.
#[inline(always)]
fn fill_padding(data: &mut [u8]) {
	debug_assert!(!data.is_empty() && data.len() <= RATE_AS_U8);

	data.fill(0);
	data[0] |= PADDING_START;
	data[data.len() - 1] |= PADDING_END;
}

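/// Byte-sliced Vision hasher that computes `HASHES_PER_BYTE_SLICED_PERMUTATION`
/// (32) independent digests in parallel, producing the same per-lane output as
/// [`VisionHasherDigest`].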
#[derive(Clone)]
pub struct VisionHasherDigestByteSliced {
	// The hash state for all 32 lanes
	state: [ByteSlicedAES32x32b; 24],
	// Buffer for input that has not yet filled a full rate block
	buffer: [[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
	filled_bytes: usize,
}

impl Default for VisionHasherDigestByteSliced {
	fn default() -> Self {
		Self {
			state: [ByteSlicedAES32x32b::zero(); 24],
			buffer: [[0; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
			filled_bytes: 0,
		}
	}
}

impl VisionHasherDigestByteSliced {
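	// Absorbs one 64-byte block per lane: each lane's bytes are decoded as
	// little-endian u32 scalars in the canonical basis, mapped to the AES basis,
	// transposed into the byte-sliced state layout, and then the permutation is
	// applied to all 32 lanes at once.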
	fn permute(
		state: &mut [ByteSlicedAES32x32b; 24],
		data: [&[u8; RATE_AS_U8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		for row in &data {
			debug_assert_eq!(row.len(), RATE_AS_U8);
		}

		for state_element_index in 0..2 {
			let data_offset = state_element_index * HASHES_PER_BYTE_SLICED_PERMUTATION;

			for (i, state_element) in state[state_element_index * 8..state_element_index * 8 + 8]
				.iter_mut()
				.enumerate()
			{
				let ordinary_range_data: [PackedAESBinaryField8x32b; 4] = array::from_fn(|j| {
					let canonical = PackedBinaryField8x32b::from_fn(|k| {
						u32::from_le_bytes(
							(data[j * 8 + k][data_offset + 4 * i..data_offset + 4 * i + 4])
								.try_into()
								.expect("chunk is 4 bytes"),
						)
						.into()
					});

					TRANS_CANONICAL_TO_AES.transform(&canonical)
				});

				*state_element = ByteSlicedAES32x32b::transpose_from(&ordinary_range_data);
			}
		}

		PERMUTATION.permute_mut(state);
	}

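	// Pads and absorbs any buffered input, then transposes the byte-sliced rate
	// portion of the state back into 32 per-lane, 32-byte digests in the
	// canonical basis.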
	fn finalize(
		&mut self,
		out: &mut [MaybeUninit<digest::Output<VisionHasherDigest>>;
			     HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		if self.filled_bytes > 0 {
			for row in 0..HASHES_PER_BYTE_SLICED_PERMUTATION {
				fill_padding(&mut self.buffer[row][self.filled_bytes..]);
			}

			Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
		} else {
			Self::permute(&mut self.state, array::from_fn(|_| &*PADDING_BLOCK));
		}

		// TODO: Use transposition function here as soon as it is merged in.
		let out: &mut [digest::Output<VisionHasherDigest>; HASHES_PER_BYTE_SLICED_PERMUTATION] =
			unsafe { slice_assume_init_mut(out) }
				.try_into()
				.expect("array is 32 elements");
		for (i, state_data) in self.state[0..8].iter().enumerate() {
			let mut transposed_aes = Default::default();
			state_data.transpose_to(&mut transposed_aes);

			for (j, transposed_aes) in transposed_aes.iter().enumerate() {
				let transposed_canonical: PackedBinaryField8x32b =
					TRANS_AES_TO_CANONICAL.transform(transposed_aes);
				for (k, scalar) in transposed_canonical.iter().enumerate() {
					out[j * 8 + k][i * 4..i * 4 + 4]
						.copy_from_slice(&scalar.to_underlier().to_le_bytes());
				}
			}
		}
	}
}

impl MultiDigest<HASHES_PER_BYTE_SLICED_PERMUTATION> for VisionHasherDigestByteSliced {
	type Digest = VisionHasherDigest;

	fn new() -> Self {
		Self::default()
	}

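	// Streaming absorption across all 32 lanes; each call must supply the same
	// number of bytes for every lane (checked by the `debug_assert_eq!` below).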
	fn update(&mut self, data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION]) {
		for row in 1..HASHES_PER_BYTE_SLICED_PERMUTATION {
			debug_assert_eq!(data[row].len(), data[0].len());
		}

		let mut offset = if self.filled_bytes > 0 {
			let to_copy = std::cmp::min(data[0].len(), RATE_AS_U8 - self.filled_bytes);
			for (row_i, row) in data
				.iter()
				.enumerate()
				.take(HASHES_PER_BYTE_SLICED_PERMUTATION)
			{
				self.buffer[row_i][self.filled_bytes..self.filled_bytes + to_copy]
					.copy_from_slice(&row[..to_copy]);
			}

			self.filled_bytes += to_copy;

			if self.filled_bytes == RATE_AS_U8 {
				Self::permute(&mut self.state, array::from_fn(|i| &self.buffer[i]));
				self.filled_bytes = 0;
			}

			to_copy
		} else {
			0
		};

		while offset + RATE_AS_U8 <= data[0].len() {
			let chunk = array::from_fn(|i| {
				(&data[i][offset..offset + RATE_AS_U8])
					.try_into()
					.expect("chunk is 64 bytes")
			});
			Self::permute(&mut self.state, chunk);
			offset += RATE_AS_U8;
		}

		if offset < data[0].len() {
			for (row_i, row) in data
				.iter()
				.enumerate()
				.take(HASHES_PER_BYTE_SLICED_PERMUTATION)
			{
				self.buffer[row_i][..row.len() - offset].copy_from_slice(&row[offset..]);
			}

			self.filled_bytes = data[0].len() - offset;
		}
	}

	fn finalize_into(
		mut self,
		out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		self.finalize(out);
	}

	fn finalize_into_reset(
		&mut self,
		out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		self.finalize(out);
		self.reset();
	}

	fn reset(&mut self) {
		bytemuck::fill_zeroes(&mut self.state);
		self.filled_bytes = 0;
	}

	fn digest(
		data: [&[u8]; HASHES_PER_BYTE_SLICED_PERMUTATION],
		out: &mut [MaybeUninit<digest::Output<Self::Digest>>; HASHES_PER_BYTE_SLICED_PERMUTATION],
	) {
		let mut digest = Self::default();
		digest.update(data);
		digest.finalize_into(out);
	}
}

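/// Convenience alias wiring the 32-way byte-sliced hasher into the generic
/// [`ParallelMultidigestImpl`] wrapper.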
pub type Vision32ParallelDigest =
	ParallelMultidigestImpl<VisionHasherDigestByteSliced, HASHES_PER_BYTE_SLICED_PERMUTATION>;

#[cfg(test)]
mod tests {
	use std::{array, mem::MaybeUninit};

	use digest::Digest;
	use hex_literal::hex;

	use super::{
		HASHES_PER_BYTE_SLICED_PERMUTATION, MultiDigest, VisionHasherDigest,
		VisionHasherDigestByteSliced,
	};

	#[test]
	fn test_simple_hash() {
		let mut hasher = VisionHasherDigest::default();
		let data = [0xde, 0xad, 0xbe, 0xef];
		hasher.update(data);
		let out = hasher.finalize();
		// This hash was generated by a modified Python reference implementation
		// using the Keccak padding scheme and the updated MDS matrix.
		let expected = &hex!("8ed389809fabe91cead4786eb08e2d32647a9ac69143040de500e4465c72f173");
		assert_eq!(expected, &*out);
	}

	#[test]
	fn test_multi_block_aligned() {
		let mut hasher = VisionHasherDigest::default();
		let input = "One part of the mysterious existence of Captain Nemo had been unveiled and, if his identity had not been recognised, at least, the nations united against him were no longer hunting a chimerical creature, but a man who had vowed a deadly hatred against them";
		hasher.update(input.as_bytes());
		let out = hasher.finalize();

		let expected = &hex!("b615664d0249149b5655a86919169f0fd4b44fec83d4c43e4f1f124c3f9a82c3");
		assert_eq!(expected, &*out);

		let mut hasher = VisionHasherDigest::default();
		let input_as_b = input.as_bytes();
		hasher.update(&input_as_b[0..63]);
		hasher.update(&input_as_b[63..128]);
		hasher.update(&input_as_b[128..163]);
		hasher.update(&input_as_b[163..]);

		assert_eq!(expected, &*hasher.finalize());
	}

	#[test]
	fn test_multi_block_unaligned() {
		let mut hasher = VisionHasherDigest::default();
		let input = "You can prove anything you want by coldly logical reason--if you pick the proper postulates.";
		hasher.update(input.as_bytes());

		let expected = &hex!("0aa2879dcac953550ebe5d9da2a91d3c0356feca9044acf4edca87b28d9959e1");
		let out = hasher.finalize();
		assert_eq!(expected, &*out);
	}

	fn check_multihash_consistency(chunks: &[[&[u8]; 32]]) {
		let mut scalar_digests = array::from_fn::<_, 32, _>(|_| VisionHasherDigest::default());
		let mut multidigest = VisionHasherDigestByteSliced::default();

		for chunk in chunks {
			for (scalar_digest, data) in scalar_digests.iter_mut().zip(chunk.iter()) {
				scalar_digest.update(data);
			}

			multidigest.update(*chunk);
		}

		let scalar_digests = scalar_digests.map(|d| d.finalize());
		let mut output = [MaybeUninit::uninit(); 32];
		multidigest.finalize_into(&mut output);
		let output = unsafe { array::from_fn::<_, 32, _>(|i| output[i].assume_init()) };

		for i in 0..32 {
			assert_eq!(&*scalar_digests[i], &*output[i]);
		}
	}

	#[test]
	fn test_multihash_consistency_small_data() {
		let data = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
			[i as u8, (i + 1) as _, (i + 2) as _, (i + 3) as _]
		});

		check_multihash_consistency(&[array::from_fn::<
			_,
			{ HASHES_PER_BYTE_SLICED_PERMUTATION },
			_,
		>(|i| &data[i][..])]);
	}

	#[test]
	fn test_multihash_consistency_small_rate() {
		let data =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 64]);

		check_multihash_consistency(&[array::from_fn::<
			_,
			{ HASHES_PER_BYTE_SLICED_PERMUTATION },
			_,
		>(|i| &data[i][..])]);
	}

	#[test]
	fn test_multihash_consistency_large_rate() {
		let data =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 1024]);

		check_multihash_consistency(&[array::from_fn::<
			_,
			{ HASHES_PER_BYTE_SLICED_PERMUTATION },
			_,
		>(|i| &data[i][..])]);
	}

	#[test]
	fn test_multihash_consistency_several_chunks() {
		let data_0 =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [i as u8; 48]);
		let data_1 =
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| [(i + 1) as u8; 64]);
		let data_2 = array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| {
			[(i + 2) as u8; 128]
		});

		check_multihash_consistency(&[
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_0[i][..]),
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_1[i][..]),
			array::from_fn::<_, { HASHES_PER_BYTE_SLICED_PERMUTATION }, _>(|i| &data_2[i][..]),
		]);
	}
485}