1use std::mem::MaybeUninit;
7
8use binius_utils::{
9 FixedSizeSerializeBytes, SerializeBytes,
10 rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
11};
12use bytemuck::{bytes_of_mut, must_cast};
13use digest::Digest;
14use sha2::{Sha256, block_api::compress256, digest::Output};
15
16use super::{
17 binary_merkle_tree::HashSuite,
18 compress::{CompressionFunction, PseudoCompressionFunction},
19 parallel_compression::ParallelCompressionAdaptor,
20 parallel_digest::{ParallelDigest, ParallelDigestAdapter},
21};
22
/// Standard SHA-256 initial hash value `H(0)` (FIPS 180-4, section 5.3.3).
///
/// The single-block fast path calls the raw `compress256` function directly, so it must supply
/// this starting state itself.
const SHA256_IV: [u32; 8] = [
    0x6a09e667, // H0
    0xbb67ae85, // H1
    0x3c6ef372, // H2
    0xa54ff53a, // H3
    0x510e527f, // H4
    0x9b05688c, // H5
    0x1f83d9ab, // H6
    0x5be0cd19, // H7
];
27
/// Longest message, in bytes, that fits in one 64-byte SHA-256 block together with its padding.
///
/// The padding needs one `0x80` terminator byte plus an 8-byte message bit length.
const SINGLE_BLOCK_MAX_LEN: usize = 64 - (1 + 8);
31
32#[derive(Debug, Clone)]
34pub struct Sha256Compression {
35 initial_state: [u32; 8],
36}
37
38impl Default for Sha256Compression {
39 fn default() -> Self {
40 let initial_state_bytes = Sha256::digest(b"BINIUS SHA-256 COMPRESS");
41 let mut initial_state = [0u32; 8];
42 bytes_of_mut(&mut initial_state).copy_from_slice(&initial_state_bytes);
43 Self { initial_state }
44 }
45}
46
47impl PseudoCompressionFunction<Output<Sha256>, 2> for Sha256Compression {
48 fn compress(&self, input: [Output<Sha256>; 2]) -> Output<Sha256> {
49 let mut ret = self.initial_state;
50 let mut block = [0u8; 64];
51 block[..32].copy_from_slice(input[0].as_slice());
52 block[32..].copy_from_slice(input[1].as_slice());
53 compress256(&mut ret, &[block]);
54 must_cast::<[u32; 8], [u8; 32]>(ret).into()
55 }
56}
57
58impl CompressionFunction<Output<Sha256>, 2> for Sha256Compression {}
59
60#[derive(Debug, Clone, Default)]
62pub struct Sha256HashSuite;
63
64impl HashSuite for Sha256HashSuite {
65 type LeafHash = Sha256;
66 type Compression = Sha256Compression;
67 type ParLeafHash = ParallelSha256Digest;
68 type ParCompression = ParallelCompressionAdaptor<Sha256Compression>;
69}
70
71#[derive(Debug, Clone, Default)]
82pub struct ParallelSha256Digest;
83
84impl ParallelDigest for ParallelSha256Digest {
85 type Digest = Sha256;
86
87 fn new() -> Self {
88 Self
89 }
90
91 fn digest<I: IntoIterator<Item: SerializeBytes>>(
92 &self,
93 source: impl IndexedParallelIterator<Item = I>,
94 out: &mut [MaybeUninit<Output<Sha256>>],
95 ) {
96 ParallelDigestAdapter::<Sha256>::new().digest(source, out);
97 }
98
99 fn digest_with_const_len<I: IntoIterator<Item: FixedSizeSerializeBytes>>(
100 &self,
101 n_items_per_input: usize,
102 source: impl IndexedParallelIterator<Item = I>,
103 out: &mut [MaybeUninit<Output<Sha256>>],
104 ) {
105 let leaf_len = n_items_per_input * <I::Item as FixedSizeSerializeBytes>::BYTE_SIZE;
106 if leaf_len > SINGLE_BLOCK_MAX_LEN {
107 self.digest(source, out);
108 return;
109 }
110
111 let mut block_template = [0u8; 64];
116 block_template[leaf_len] = 0x80;
117 block_template[56..64].copy_from_slice(&((leaf_len as u64) * 8).to_be_bytes());
118
119 source
120 .zip(out.par_iter_mut())
121 .for_each_with(block_template, |block, (items, out)| {
122 let mut cursor = &mut block[..leaf_len];
124 let mut n_items = 0;
125 for item in items {
126 item.serialize(&mut cursor)
127 .expect("pre-condition: items must serialize without error");
128 n_items += 1;
129 }
130 debug_assert_eq!(n_items, n_items_per_input);
131 debug_assert!(cursor.is_empty(), "pre-condition: each leaf serializes to leaf_len");
132
133 let mut state = SHA256_IV;
134 compress256(&mut state, std::slice::from_ref(&*block));
135
136 let mut digest = Output::<Sha256>::default();
138 for (chunk, word) in digest.chunks_exact_mut(4).zip(state) {
139 chunk.copy_from_slice(&word.to_be_bytes());
140 }
141 out.write(digest);
142 });
143 }
144}
145
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_utils::rayon::iter::{IntoParallelRefIterator, ParallelIterator};
    use rand::{RngExt, SeedableRng, rngs::StdRng};

    use super::*;

    /// Checks `digest_with_const_len` against the serial `Sha256` hasher.
    ///
    /// Items are 16-byte `u128`s:
    /// - 0 to 3 items per leaf (0 to 48 bytes) fit in one padded block and take the
    ///   single-compression fast path.
    /// - 4 items (64 bytes) exceed `SINGLE_BLOCK_MAX_LEN` and take the fallback.
    ///
    /// The empty leaf checks that the fast path pads a zero-length message correctly, with the
    /// terminator at byte 0.
    #[test]
    fn test_parallel_sha256_matches_serial() {
        let mut rng = StdRng::seed_from_u64(0);
        for n_items_per_input in [0, 1, 2, 3, 4] {
            let n_leaves = 50;
            let leaves: Vec<Vec<u128>> = (0..n_leaves)
                .map(|_| {
                    (0..n_items_per_input)
                        .map(|_| rng.random::<u128>())
                        .collect()
                })
                .collect();

            let digest = ParallelSha256Digest::new();
            let mut results = repeat_with(MaybeUninit::<Output<Sha256>>::uninit)
                .take(n_leaves)
                .collect::<Vec<_>>();
            digest.digest_with_const_len(
                n_items_per_input,
                leaves.par_iter().map(|leaf| leaf.iter().copied()),
                &mut results,
            );

            for (result, leaf) in results.into_iter().zip(&leaves) {
                // `u128` items are expected to serialize as their little-endian bytes.
                let bytes: Vec<u8> = leaf.iter().flat_map(|item| item.to_le_bytes()).collect();
                // SAFETY: `results` has exactly one slot per leaf, and the call above is
                // expected to initialize every slot; that contract is what this test checks.
                let result = unsafe { result.assume_init() };
                assert_eq!(result, <Sha256 as Digest>::digest(&bytes));
            }
        }
    }
}