Skip to main content

binius_hash/
sha256.rs

1// Copyright 2024-2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4//! SHA-256 compression function for use in Merkle tree constructions.
5
6use std::mem::MaybeUninit;
7
8use binius_utils::{
9	FixedSizeSerializeBytes, SerializeBytes,
10	rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
11};
12use bytemuck::{bytes_of_mut, must_cast};
13use digest::Digest;
14use sha2::{Sha256, block_api::compress256, digest::Output};
15
16use super::{
17	binary_merkle_tree::HashSuite,
18	compress::{CompressionFunction, PseudoCompressionFunction},
19	parallel_compression::ParallelCompressionAdaptor,
20	parallel_digest::{ParallelDigest, ParallelDigestAdapter},
21};
22
/// The eight 32-bit words of the standard SHA-256 initial hash value.
///
/// A raw block compression that starts from this state and consumes a single, fully padded
/// block yields the ordinary SHA-256 digest of the message held in that block.
const SHA256_IV: [u32; 8] = [
	0x6a09_e667, 0xbb67_ae85, 0x3c6e_f372, 0xa54f_f53a,
	0x510e_527f, 0x9b05_688c, 0x1f83_d9ab, 0x5be0_cd19,
];
27
/// Maximum leaf size, in bytes, for which the message and its SHA-256 padding share one block.
///
/// A block is 64 bytes; the padding needs at least one byte for the `0x80` terminator plus eight
/// bytes for the big-endian bit length, leaving 55 bytes for the message itself.
const SINGLE_BLOCK_MAX_LEN: usize = 64 - (1 + 8);
31
/// Two-to-one compression of SHA-256 digests, built on the raw SHA-256 block function.
///
/// The value holds only the chaining state from which every compression starts.
#[derive(Clone, Debug)]
pub struct Sha256Compression {
	/// State words loaded before each call to the block function.
	initial_state: [u32; 8],
}
37
38impl Default for Sha256Compression {
39	fn default() -> Self {
40		let initial_state_bytes = Sha256::digest(b"BINIUS SHA-256 COMPRESS");
41		let mut initial_state = [0u32; 8];
42		bytes_of_mut(&mut initial_state).copy_from_slice(&initial_state_bytes);
43		Self { initial_state }
44	}
45}
46
47impl PseudoCompressionFunction<Output<Sha256>, 2> for Sha256Compression {
48	fn compress(&self, input: [Output<Sha256>; 2]) -> Output<Sha256> {
49		let mut ret = self.initial_state;
50		let mut block = [0u8; 64];
51		block[..32].copy_from_slice(input[0].as_slice());
52		block[32..].copy_from_slice(input[1].as_slice());
53		compress256(&mut ret, &[block]);
54		must_cast::<[u32; 8], [u8; 32]>(ret).into()
55	}
56}
57
// Marker impl: opts `Sha256Compression` into `CompressionFunction` for pairs of SHA-256 digests.
// The body is empty, so no behavior is defined here beyond the `PseudoCompressionFunction` impl.
impl CompressionFunction<Output<Sha256>, 2> for Sha256Compression {}
59
/// The SHA-256 [`HashSuite`].
///
/// Leaves are hashed with SHA-256 and inner nodes are combined with [`Sha256Compression`].
/// This is a zero-sized marker type; all configuration lives in its [`HashSuite`] impl.
#[derive(Clone, Debug, Default)]
pub struct Sha256HashSuite;
63
64impl HashSuite for Sha256HashSuite {
65	type LeafHash = Sha256;
66	type Compression = Sha256Compression;
67	type ParLeafHash = ParallelSha256Digest;
68	type ParCompression = ParallelCompressionAdaptor<Sha256Compression>;
69}
70
/// SHA-256 [`ParallelDigest`] with a fast path for short, fixed-length leaves.
///
/// [`digest_with_const_len`](ParallelDigest::digest_with_const_len) is specialized: if each leaf
/// serializes to no more than `SINGLE_BLOCK_MAX_LEN` bytes, then the message, its padding, and
/// the length suffix all fit into one 64-byte block. The digest is then computed with a single
/// call to the raw [`compress256`] block function, started from the SHA-256 IV, which avoids
/// the per-leaf `update`/`finalize` bookkeeping of the generic [`ParallelDigestAdapter`].
///
/// Leaves longer than that are handed to [`ParallelDigestAdapter`].
#[derive(Clone, Debug, Default)]
pub struct ParallelSha256Digest;
83
84impl ParallelDigest for ParallelSha256Digest {
85	type Digest = Sha256;
86
87	fn new() -> Self {
88		Self
89	}
90
91	fn digest<I: IntoIterator<Item: SerializeBytes>>(
92		&self,
93		source: impl IndexedParallelIterator<Item = I>,
94		out: &mut [MaybeUninit<Output<Sha256>>],
95	) {
96		ParallelDigestAdapter::<Sha256>::new().digest(source, out);
97	}
98
99	fn digest_with_const_len<I: IntoIterator<Item: FixedSizeSerializeBytes>>(
100		&self,
101		n_items_per_input: usize,
102		source: impl IndexedParallelIterator<Item = I>,
103		out: &mut [MaybeUninit<Output<Sha256>>],
104	) {
105		let leaf_len = n_items_per_input * <I::Item as FixedSizeSerializeBytes>::BYTE_SIZE;
106		if leaf_len > SINGLE_BLOCK_MAX_LEN {
107			self.digest(source, out);
108			return;
109		}
110
111		// Precompute the padding suffix once: a `0x80` terminator immediately after the message,
112		// then zeros, then the 64-bit big-endian message bit length. Because `leaf_len` is constant
113		// for every leaf, this suffix is identical across leaves; each leaf only overwrites the
114		// `leaf_len`-byte message prefix.
115		let mut block_template = [0u8; 64];
116		block_template[leaf_len] = 0x80;
117		block_template[56..64].copy_from_slice(&((leaf_len as u64) * 8).to_be_bytes());
118
119		source
120			.zip(out.par_iter_mut())
121			.for_each_with(block_template, |block, (items, out)| {
122				// Overwrite the message prefix; the padding suffix stays untouched.
123				let mut cursor = &mut block[..leaf_len];
124				let mut n_items = 0;
125				for item in items {
126					item.serialize(&mut cursor)
127						.expect("pre-condition: items must serialize without error");
128					n_items += 1;
129				}
130				debug_assert_eq!(n_items, n_items_per_input);
131				debug_assert!(cursor.is_empty(), "pre-condition: each leaf serializes to leaf_len");
132
133				let mut state = SHA256_IV;
134				compress256(&mut state, std::slice::from_ref(&*block));
135
136				// SHA-256 emits its state words in big-endian byte order.
137				let mut digest = Output::<Sha256>::default();
138				for (chunk, word) in digest.chunks_exact_mut(4).zip(state) {
139					chunk.copy_from_slice(&word.to_be_bytes());
140				}
141				out.write(digest);
142			});
143	}
144}
145
#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use binius_utils::rayon::iter::{IntoParallelRefIterator, ParallelIterator};
	use rand::{RngExt, SeedableRng, rngs::StdRng};

	use super::*;

	/// The specialized digest must agree with `Sha256::digest` applied to the serialized leaf
	/// bytes, on the single-block fast path as well as on the multi-block fallback.
	#[test]
	fn test_parallel_sha256_matches_serial() {
		const N_LEAVES: usize = 50;

		let mut rng = StdRng::seed_from_u64(0);
		// Each `u128` serializes to 16 little-endian bytes. One to three items give leaves of
		// 16, 32 and 48 bytes (single block); four items give 64 bytes, which is above
		// SINGLE_BLOCK_MAX_LEN and therefore exercises the fallback.
		for n_items_per_input in 1..=4 {
			let leaves: Vec<Vec<u128>> = repeat_with(|| {
				repeat_with(|| rng.random::<u128>())
					.take(n_items_per_input)
					.collect()
			})
			.take(N_LEAVES)
			.collect();

			let mut digests: Vec<MaybeUninit<Output<Sha256>>> =
				repeat_with(MaybeUninit::uninit).take(N_LEAVES).collect();

			let hasher = ParallelSha256Digest::new();
			hasher.digest_with_const_len(
				n_items_per_input,
				leaves.par_iter().map(|leaf| leaf.iter().copied()),
				&mut digests,
			);

			for (slot, leaf) in digests.into_iter().zip(&leaves) {
				let serialized: Vec<u8> = leaf
					.iter()
					.flat_map(|item| item.to_le_bytes())
					.collect();
				let expected = <Sha256 as Digest>::digest(&serialized);

				// SAFETY: `digests` and `leaves` both hold N_LEAVES entries, and
				// `digest_with_const_len` writes one digest per zipped (leaf, slot) pair, so
				// every slot has been initialized.
				let actual = unsafe { slot.assume_init() };
				assert_eq!(actual, expected);
			}
		}
	}
}