1use std::{cmp::min, mem};
4
5use bytes::{buf::UninitSlice, Buf, BufMut};
6use digest::{
7 core_api::{Block, BlockSizeUser},
8 Digest, FixedOutputReset, Output,
9};
10
11use super::Challenger;
12
/// Sampling half of a [`HasherChallenger`]: hands out challenge bytes read from the most
/// recently produced digest block.
#[derive(Debug, Default, Clone)]
pub struct HasherSampler<H: Digest> {
    /// Number of bytes of `buffer` already handed out. It equals the digest output size when
    /// the block is exhausted, or when it has not been filled yet (right after switching from
    /// observing).
    index: usize,
    /// The current digest block whose bytes are being sampled.
    buffer: Output<H>,
    /// Running hasher. Whenever a block is produced, the hasher is reset and re-seeded with
    /// that block, so the next block chains from it.
    hasher: H,
}
20
/// Observing half of a [`HasherChallenger`]: stages observed bytes in a one-block buffer and
/// feeds them to the hasher whenever the block fills up (or on conversion to a sampler).
#[derive(Debug, Default, Clone)]
pub struct HasherObserver<H: Digest + BlockSizeUser> {
    /// Number of bytes currently staged in `buffer`. It is always below the block size
    /// between calls, because a full block is flushed as soon as it is reached.
    index: usize,
    /// Staging area, exactly one hash block in size.
    buffer: Block<H>,
    /// Running hasher that absorbs flushed bytes.
    hasher: H,
}
28
/// Hash-based challenger that alternates between observing (absorbing) and sampling
/// (squeezing).
///
/// [`Challenger::observer`] and [`Challenger::sampler`] switch between the two states
/// lazily, carrying the running hasher across each transition.
#[derive(Debug, Clone)]
pub enum HasherChallenger<H: Digest + BlockSizeUser> {
    /// Currently absorbing observed bytes.
    Observer(HasherObserver<H>),
    /// Currently handing out challenge bytes.
    Sampler(HasherSampler<H>),
}
37
38impl<H> HasherChallenger<H>
39where
40 H: Digest + BlockSizeUser,
41{
42 fn new(initial_digest: Output<H>) -> Self {
43 let mut hasher = H::new();
44 Digest::update(&mut hasher, &initial_digest);
45
46 Self::Sampler(HasherSampler {
47 hasher,
48 index: 0,
49 buffer: initial_digest,
50 })
51 }
52}
53
54impl<H> Default for HasherChallenger<H>
55where
56 H: Digest + BlockSizeUser + FixedOutputReset,
57{
58 fn default() -> Self {
59 Self::new(H::digest([]))
60 }
61}
62
63impl<H: Digest + BlockSizeUser + FixedOutputReset + Default> Challenger for HasherChallenger<H> {
64 fn observer(&mut self) -> &mut impl BufMut {
66 match self {
67 Self::Observer(observer) => observer,
68 Self::Sampler(sampler) => {
69 *self = Self::Observer(mem::take(sampler).into_observer());
70 match self {
71 Self::Observer(observer) => observer,
72 _ => unreachable!(),
73 }
74 }
75 }
76 }
77
78 fn sampler(&mut self) -> &mut impl Buf {
80 match self {
81 Self::Sampler(sampler) => sampler,
82 Self::Observer(observer) => {
83 *self = Self::Sampler(mem::take(observer).into_sampler());
84 match self {
85 Self::Sampler(sampler) => sampler,
86 _ => unreachable!(),
87 }
88 }
89 }
90 }
91}
92
93impl<H> HasherSampler<H>
94where
95 H: Digest + Default + BlockSizeUser,
96{
97 fn into_observer(mut self) -> HasherObserver<H> {
98 Digest::update(&mut self.hasher, self.index.to_le_bytes());
99
100 HasherObserver {
101 hasher: self.hasher,
102 index: 0,
103 buffer: Block::<H>::default(),
104 }
105 }
106}
107
108impl<H> HasherSampler<H>
109where
110 H: Digest + FixedOutputReset,
111{
112 fn fill_buffer(&mut self) {
113 let digest = self.hasher.finalize_reset();
114
115 Digest::update(&mut self.hasher, &digest);
117
118 self.buffer = digest;
119 self.index = 0
120 }
121}
122
123impl<H> HasherObserver<H>
124where
125 H: Digest + BlockSizeUser + Default,
126{
127 fn into_sampler(mut self) -> HasherSampler<H> {
128 self.flush();
129 HasherSampler {
130 hasher: self.hasher,
131 index: <H as Digest>::output_size(),
132 buffer: Output::<H>::default(),
133 }
134 }
135}
136
137impl<H> HasherObserver<H>
138where
139 H: Digest + BlockSizeUser,
140{
141 fn flush(&mut self) {
142 self.hasher.update(&self.buffer[..self.index]);
143 self.index = 0
144 }
145}
146
147impl<H> Buf for HasherSampler<H>
148where
149 H: Digest + FixedOutputReset + Default,
150{
151 fn remaining(&self) -> usize {
152 usize::MAX
153 }
154
155 fn chunk(&self) -> &[u8] {
156 &self.buffer[self.index..]
157 }
158
159 fn advance(&mut self, mut cnt: usize) {
160 if self.index == <H as Digest>::output_size() {
162 self.fill_buffer();
163 }
164
165 while cnt > 0 {
166 let remaining = min(<H as Digest>::output_size() - self.index, cnt);
167 if remaining == 0 {
168 self.fill_buffer();
169 continue;
170 }
171 cnt -= remaining;
172 self.index += remaining;
173 }
174 }
175}
176
177unsafe impl<H> BufMut for HasherObserver<H>
178where
179 H: Digest + BlockSizeUser,
180{
181 fn remaining_mut(&self) -> usize {
182 usize::MAX
183 }
184
185 unsafe fn advance_mut(&mut self, mut cnt: usize) {
186 while cnt > 0 {
187 let remaining = min(<H as BlockSizeUser>::block_size() - self.index, cnt);
188 cnt -= remaining;
189 self.index += remaining;
190 if self.index == <H as BlockSizeUser>::block_size() {
191 self.flush();
192 }
193 }
194 }
195
196 fn chunk_mut(&mut self) -> &mut UninitSlice {
197 let buffer = &mut self.buffer[self.index..];
198 buffer.into()
199 }
200}
201
#[cfg(test)]
mod tests {
    use binius_hash::groestl::Groestl256;
    use rand::{thread_rng, RngCore};

    use super::*;

    /// Starting from the default (sampling) state:
    /// - the first 32 sampled bytes are the digest of the empty string;
    /// - a later observe absorbs the consumed-byte count (32, as a little-endian u64) before
    ///   the observed bytes.
    #[test]
    fn test_starting_sampler() {
        let empty_string_digest = Groestl256::digest([]);

        // Reference hasher that mirrors the challenger's initial seeding.
        let mut hasher = {
            let mut hasher = Groestl256::default();
            Digest::update(&mut hasher, empty_string_digest);
            hasher
        };

        let mut challenger = HasherChallenger::<Groestl256>::default();

        let mut out = [0u8; 8];
        // The initial block is the empty-string digest itself.
        challenger.sampler().copy_to_slice(&mut out);

        let first_hash_out = empty_string_digest;

        assert_eq!(first_hash_out[0..8], out);

        let mut out = [0u8; 24];
        // Continuing the same sampler drains the rest of the initial block.
        challenger.sampler().copy_to_slice(&mut out);
        assert_eq!(first_hash_out[8..], out);

        challenger.observer().put_slice(&[0x48, 0x55]);
        // Switching to observing absorbs the consumed count (32 bytes, little-endian u64),
        // followed by the observed bytes.
        hasher.update([32, 0, 0, 0, 0, 0, 0, 0]);
        hasher.update([0x48, 0x55]);

        let mut out_after_observe = [0u8; 2];
        // Switching back to sampling produces a fresh digest of everything absorbed so far.
        challenger.sampler().copy_to_slice(&mut out_after_observe);

        let second_hash_out = hasher.finalize_reset();

        assert_eq!(out_after_observe, second_hash_out[..2]);
    }

    /// Observing first, in several chunk sizes that straddle block boundaries, then
    /// alternating observe/sample rounds. Each round is checked against a reference hasher
    /// that absorbs the consumed-byte count (little-endian u64) on every switch to observing.
    #[test]
    fn test_starting_observer() {
        let empty_string_digest = Groestl256::digest([]);

        // Reference hasher that mirrors the challenger's initial seeding.
        let mut hasher = {
            let mut hasher = Groestl256::default();
            Digest::update(&mut hasher, empty_string_digest);
            hasher
        };

        let mut challenger = HasherChallenger::<Groestl256>::default();

        let mut observable = [0u8; 1019];
        // Random input, split into uneven chunks so that observing crosses block boundaries.
        thread_rng().fill_bytes(&mut observable);
        challenger.observer().put_slice(&observable[..39]);
        challenger.observer().put_slice(&observable[39..300]);
        challenger.observer().put_slice(&observable[300..987]);
        challenger.observer().put_slice(&observable[987..]);
        // Nothing was sampled from the initial block, so the absorbed count is 0.
        hasher.update([0, 0, 0, 0, 0, 0, 0, 0]);
        hasher.update(observable);

        let mut out = [0u8; 7];
        // Sampling finalizes everything absorbed so far into a new block.
        challenger.sampler().copy_to_slice(&mut out);

        let first_hash_out = hasher.finalize_reset();
        // The challenger re-seeds its hasher with each produced block; mirror that here.
        hasher.update(first_hash_out);

        assert_eq!(first_hash_out[..7], out);

        thread_rng().fill_bytes(&mut observable);
        // 7 bytes were consumed from the block, so the count 7 is absorbed before the data.
        challenger.observer().put_slice(&observable[..128]);
        hasher.update([7, 0, 0, 0, 0, 0, 0, 0]);

        hasher.update(&observable[..128]);

        let mut out = [0u8; 32];
        // Sample one full digest-sized block.
        challenger.sampler().copy_to_slice(&mut out);

        let second_hasher_out = hasher.finalize_reset();
        hasher.update(second_hasher_out);

        assert_eq!(second_hasher_out[..], out);

        challenger.observer().put_slice(&observable[128..]);
        // The whole 32-byte block was consumed, so the absorbed count is 32.
        hasher.update([32, 0, 0, 0, 0, 0, 0, 0]);
        hasher.update(&observable[128..]);

        let mut out_again = [0u8; 7];
        // Sampling again finalizes the latest observations into a new block.
        challenger.sampler().copy_to_slice(&mut out_again);

        let final_hasher_out = hasher.finalize_reset();
        assert_eq!(final_hasher_out[..7], out_again);
    }
}