binius_core/fiat_shamir/
hasher_challenger.rs1use std::{cmp::min, mem};
4
5use bytes::{buf::UninitSlice, Buf, BufMut};
6use digest::{
7 core_api::{Block, BlockSizeUser},
8 Digest, FixedOutputReset, Output,
9};
10
11use super::Challenger;
12
13#[derive(Debug, Default)]
15pub struct HasherSampler<H: Digest> {
16 index: usize,
17 buffer: Output<H>,
18 hasher: H,
19}
20
21#[derive(Debug, Default)]
23pub struct HasherObserver<H: Digest + BlockSizeUser> {
24 index: usize,
25 buffer: Block<H>,
26 hasher: H,
27}
28
29#[derive(Debug)]
33pub enum HasherChallenger<H: Digest + BlockSizeUser> {
34 Observer(HasherObserver<H>),
35 Sampler(HasherSampler<H>),
36}
37
38impl<H> Default for HasherChallenger<H>
39where
40 H: Digest + BlockSizeUser,
41{
42 fn default() -> Self {
43 Self::Observer(HasherObserver {
44 hasher: H::new(),
45 index: 0,
46 buffer: Block::<H>::default(),
47 })
48 }
49}
50
51impl<H: Digest + BlockSizeUser + FixedOutputReset + Default> Challenger for HasherChallenger<H> {
52 fn observer(&mut self) -> &mut impl BufMut {
54 match self {
55 Self::Observer(observer) => observer,
56 Self::Sampler(sampler) => {
57 *self = Self::Observer(mem::take(sampler).into_observer());
58 match self {
59 Self::Observer(observer) => observer,
60 _ => unreachable!(),
61 }
62 }
63 }
64 }
65
66 fn sampler(&mut self) -> &mut impl Buf {
68 match self {
69 Self::Sampler(sampler) => sampler,
70 Self::Observer(observer) => {
71 *self = Self::Sampler(mem::take(observer).into_sampler());
72 match self {
73 Self::Sampler(sampler) => sampler,
74 _ => unreachable!(),
75 }
76 }
77 }
78 }
79}
80
81impl<H> HasherSampler<H>
82where
83 H: Digest + Default + BlockSizeUser,
84{
85 fn into_observer(mut self) -> HasherObserver<H> {
86 Digest::update(&mut self.hasher, self.index.to_le_bytes());
87
88 HasherObserver {
89 hasher: self.hasher,
90 index: 0,
91 buffer: Block::<H>::default(),
92 }
93 }
94}
95
96impl<H> HasherSampler<H>
97where
98 H: Digest + FixedOutputReset,
99{
100 fn fill_buffer(&mut self) {
101 let digest = self.hasher.finalize_reset();
102
103 Digest::update(&mut self.hasher, &digest);
105
106 self.buffer = digest;
107 self.index = 0
108 }
109}
110
111impl<H> HasherObserver<H>
112where
113 H: Digest + BlockSizeUser + Default,
114{
115 fn into_sampler(mut self) -> HasherSampler<H> {
116 self.flush();
117 HasherSampler {
118 hasher: self.hasher,
119 index: <H as Digest>::output_size(),
120 buffer: Output::<H>::default(),
121 }
122 }
123}
124
125impl<H> HasherObserver<H>
126where
127 H: Digest + BlockSizeUser,
128{
129 fn flush(&mut self) {
130 self.hasher.update(&self.buffer[..self.index]);
131 self.index = 0
132 }
133}
134
135impl<H> Buf for HasherSampler<H>
136where
137 H: Digest + FixedOutputReset + Default,
138{
139 fn remaining(&self) -> usize {
140 usize::MAX
141 }
142
143 fn chunk(&self) -> &[u8] {
144 &self.buffer[self.index..]
145 }
146
147 fn advance(&mut self, mut cnt: usize) {
148 if self.index == <H as Digest>::output_size() {
150 self.fill_buffer();
151 }
152
153 while cnt > 0 {
154 let remaining = min(<H as Digest>::output_size() - self.index, cnt);
155 if remaining == 0 {
156 self.fill_buffer();
157 continue;
158 }
159 cnt -= remaining;
160 self.index += remaining;
161 }
162 }
163}
164
165unsafe impl<H> BufMut for HasherObserver<H>
166where
167 H: Digest + BlockSizeUser,
168{
169 fn remaining_mut(&self) -> usize {
170 usize::MAX
171 }
172
173 unsafe fn advance_mut(&mut self, mut cnt: usize) {
174 while cnt > 0 {
175 let remaining = min(<H as BlockSizeUser>::block_size() - self.index, cnt);
176 cnt -= remaining;
177 self.index += remaining;
178 if self.index == <H as BlockSizeUser>::block_size() {
179 self.flush();
180 }
181 }
182 }
183
184 fn chunk_mut(&mut self) -> &mut UninitSlice {
185 let buffer = &mut self.buffer[self.index..];
186 buffer.into()
187 }
188}
189
#[cfg(test)]
mod tests {
    use groestl_crypto::Groestl256;
    use rand::{thread_rng, RngCore};

    use super::*;

    /// Samples first, then observes, then samples again. Each output is checked against a
    /// reference `Groestl256` hasher driven by hand through the same transcript.
    #[test]
    fn test_starting_sampler() {
        let mut hasher = Groestl256::default();
        let mut challenger = HasherChallenger::<Groestl256>::default();
        let mut out = [0u8; 8];
        challenger.sampler().copy_to_slice(&mut out);

        // Nothing was observed, so the first squeeze is the digest of the empty input.
        // The challenger then feeds that digest back into its reset hasher.
        let first_hash_out = hasher.finalize_reset();
        hasher.update(first_hash_out);

        assert_eq!(first_hash_out[0..8], out);

        // The remaining 24 bytes come from the same digest (8 + 24 bytes in total).
        let mut out = [0u8; 24];
        challenger.sampler().copy_to_slice(&mut out);
        assert_eq!(first_hash_out[8..], out);

        // Switching back to observing first absorbs the sampler's read position (32),
        // written here as 8 little-endian bytes, followed by the observed data.
        challenger.observer().put_slice(&[0x48, 0x55]);
        hasher.update([32, 0, 0, 0, 0, 0, 0, 0]);
        hasher.update([0x48, 0x55]);

        let mut out_after_observe = [0u8; 2];
        challenger.sampler().copy_to_slice(&mut out_after_observe);

        let second_hash_out = hasher.finalize_reset();

        assert_eq!(out_after_observe, second_hash_out[..2]);
    }

    /// Observes random data in several uneven chunks, then alternates sampling and
    /// observing. Every output is checked against a reference `Groestl256` hasher that
    /// absorbs the same data in one piece.
    #[test]
    fn test_starting_observer() {
        let mut hasher = Groestl256::default();
        let mut challenger = HasherChallenger::<Groestl256>::default();
        let mut observable = [0u8; 1019];
        thread_rng().fill_bytes(&mut observable);
        // Chunked writes must hash the same as one contiguous update.
        challenger.observer().put_slice(&observable[..39]);
        challenger.observer().put_slice(&observable[39..300]);
        challenger.observer().put_slice(&observable[300..987]);
        challenger.observer().put_slice(&observable[987..]);
        hasher.update(observable);

        let mut out = [0u8; 7];
        challenger.sampler().copy_to_slice(&mut out);

        let first_hash_out = hasher.finalize_reset();
        hasher.update(first_hash_out);

        assert_eq!(first_hash_out[..7], out);

        thread_rng().fill_bytes(&mut observable);
        challenger.observer().put_slice(&observable[..128]);
        // Only 7 bytes of the digest were read before switching back to observing, so the
        // absorbed read position is 7 (as 8 little-endian bytes).
        hasher.update([7, 0, 0, 0, 0, 0, 0, 0]);

        hasher.update(&observable[..128]);

        // Reading a full 32-byte digest leaves the read position at 32.
        let mut out = [0u8; 32];
        challenger.sampler().copy_to_slice(&mut out);

        let second_hasher_out = hasher.finalize_reset();
        hasher.update(second_hasher_out);

        assert_eq!(second_hasher_out[..], out);

        challenger.observer().put_slice(&observable[128..]);
        hasher.update([32, 0, 0, 0, 0, 0, 0, 0]);
        hasher.update(&observable[128..]);

        let mut out_again = [0u8; 7];
        challenger.sampler().copy_to_slice(&mut out_again);

        let final_hasher_out = hasher.finalize_reset();
        assert_eq!(final_hasher_out[..7], out_again);
    }
}