binius_core/fiat_shamir/
hasher_challenger.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cmp::min, mem};
4
5use bytes::{buf::UninitSlice, Buf, BufMut};
6use digest::{
7	core_api::{Block, BlockSizeUser},
8	Digest, FixedOutputReset, Output,
9};
10
11use super::Challenger;
12
13/// Challenger type which implements `[Buf]` that has similar functionality as `[CanSample]`
14#[derive(Debug, Default)]
15pub struct HasherSampler<H: Digest> {
16	index: usize,
17	buffer: Output<H>,
18	hasher: H,
19}
20
21/// Challenger type which implements `[BufMut]` that has similar functionality as `[CanObserve]`
22#[derive(Debug, Default)]
23pub struct HasherObserver<H: Digest + BlockSizeUser> {
24	index: usize,
25	buffer: Block<H>,
26	hasher: H,
27}
28
29/// Challenger interface over hashes that implement `[Digest]` trait,
30///
31/// This challenger works over bytes instead of Field elements
32#[derive(Debug)]
33pub enum HasherChallenger<H: Digest + BlockSizeUser> {
34	Observer(HasherObserver<H>),
35	Sampler(HasherSampler<H>),
36}
37
38impl<H> Default for HasherChallenger<H>
39where
40	H: Digest + BlockSizeUser,
41{
42	fn default() -> Self {
43		Self::Observer(HasherObserver {
44			hasher: H::new(),
45			index: 0,
46			buffer: Block::<H>::default(),
47		})
48	}
49}
50
51impl<H: Digest + BlockSizeUser + FixedOutputReset + Default> Challenger for HasherChallenger<H> {
52	/// This returns the inner challenger which implements `[BufMut]`
53	fn observer(&mut self) -> &mut impl BufMut {
54		match self {
55			Self::Observer(observer) => observer,
56			Self::Sampler(sampler) => {
57				*self = Self::Observer(mem::take(sampler).into_observer());
58				match self {
59					Self::Observer(observer) => observer,
60					_ => unreachable!(),
61				}
62			}
63		}
64	}
65
66	/// This returns the inner challenger which implements `[Buf]`
67	fn sampler(&mut self) -> &mut impl Buf {
68		match self {
69			Self::Sampler(sampler) => sampler,
70			Self::Observer(observer) => {
71				*self = Self::Sampler(mem::take(observer).into_sampler());
72				match self {
73					Self::Sampler(sampler) => sampler,
74					_ => unreachable!(),
75				}
76			}
77		}
78	}
79}
80
81impl<H> HasherSampler<H>
82where
83	H: Digest + Default + BlockSizeUser,
84{
85	fn into_observer(mut self) -> HasherObserver<H> {
86		Digest::update(&mut self.hasher, self.index.to_le_bytes());
87
88		HasherObserver {
89			hasher: self.hasher,
90			index: 0,
91			buffer: Block::<H>::default(),
92		}
93	}
94}
95
96impl<H> HasherSampler<H>
97where
98	H: Digest + FixedOutputReset,
99{
100	fn fill_buffer(&mut self) {
101		let digest = self.hasher.finalize_reset();
102
103		// feed forward to the empty state
104		Digest::update(&mut self.hasher, &digest);
105
106		self.buffer = digest;
107		self.index = 0
108	}
109}
110
111impl<H> HasherObserver<H>
112where
113	H: Digest + BlockSizeUser + Default,
114{
115	fn into_sampler(mut self) -> HasherSampler<H> {
116		self.flush();
117		HasherSampler {
118			hasher: self.hasher,
119			index: <H as Digest>::output_size(),
120			buffer: Output::<H>::default(),
121		}
122	}
123}
124
125impl<H> HasherObserver<H>
126where
127	H: Digest + BlockSizeUser,
128{
129	fn flush(&mut self) {
130		self.hasher.update(&self.buffer[..self.index]);
131		self.index = 0
132	}
133}
134
135impl<H> Buf for HasherSampler<H>
136where
137	H: Digest + FixedOutputReset + Default,
138{
139	fn remaining(&self) -> usize {
140		usize::MAX
141	}
142
143	fn chunk(&self) -> &[u8] {
144		&self.buffer[self.index..]
145	}
146
147	fn advance(&mut self, mut cnt: usize) {
148		// Must handle the case when `cnt` is 0
149		if self.index == <H as Digest>::output_size() {
150			self.fill_buffer();
151		}
152
153		while cnt > 0 {
154			let remaining = min(<H as Digest>::output_size() - self.index, cnt);
155			if remaining == 0 {
156				self.fill_buffer();
157				continue;
158			}
159			cnt -= remaining;
160			self.index += remaining;
161		}
162	}
163}
164
165unsafe impl<H> BufMut for HasherObserver<H>
166where
167	H: Digest + BlockSizeUser,
168{
169	fn remaining_mut(&self) -> usize {
170		usize::MAX
171	}
172
173	unsafe fn advance_mut(&mut self, mut cnt: usize) {
174		while cnt > 0 {
175			let remaining = min(<H as BlockSizeUser>::block_size() - self.index, cnt);
176			cnt -= remaining;
177			self.index += remaining;
178			if self.index == <H as BlockSizeUser>::block_size() {
179				self.flush();
180			}
181		}
182	}
183
184	fn chunk_mut(&mut self) -> &mut UninitSlice {
185		let buffer = &mut self.buffer[self.index..];
186		buffer.into()
187	}
188}
189
#[cfg(test)]
mod tests {
	use groestl_crypto::Groestl256;
	use rand::{thread_rng, RngCore};

	use super::*;

	/// Finalizes `reference`, re-seeds it with the produced digest (mirroring the sampler's
	/// feed-forward step), and returns that digest.
	fn finalize_and_feed_forward(reference: &mut Groestl256) -> Output<Groestl256> {
		let digest = reference.finalize_reset();
		reference.update(digest);
		digest
	}

	#[test]
	fn test_starting_sampler() {
		let mut reference = Groestl256::default();
		let mut challenger = HasherChallenger::<Groestl256>::default();

		// Sampling straight away yields the digest of the empty transcript.
		let mut head = [0u8; 8];
		challenger.sampler().copy_to_slice(&mut head);
		let first_digest = finalize_and_feed_forward(&mut reference);
		assert_eq!(first_digest[0..8], head);

		// The rest of the same digest comes out next.
		let mut tail = [0u8; 24];
		challenger.sampler().copy_to_slice(&mut tail);
		assert_eq!(first_digest[8..], tail);

		// Switching back to observing absorbs the sampler's read position (32, little-endian)
		// before the newly observed bytes.
		challenger.observer().put_slice(&[0x48, 0x55]);
		reference.update([32, 0, 0, 0, 0, 0, 0, 0]);
		reference.update([0x48, 0x55]);

		let mut after_observe = [0u8; 2];
		challenger.sampler().copy_to_slice(&mut after_observe);
		let second_digest = reference.finalize_reset();
		assert_eq!(after_observe, second_digest[..2]);
	}

	#[test]
	fn test_starting_observer() {
		let mut reference = Groestl256::default();
		let mut challenger = HasherChallenger::<Groestl256>::default();
		let mut observable = [0u8; 1019];
		thread_rng().fill_bytes(&mut observable);

		// Observe the data in uneven pieces; the result must match hashing it in one go.
		for range in [0..39, 39..300, 300..987, 987..observable.len()] {
			challenger.observer().put_slice(&observable[range]);
		}
		reference.update(observable);

		let mut partial = [0u8; 7];
		challenger.sampler().copy_to_slice(&mut partial);
		let first_digest = finalize_and_feed_forward(&mut reference);
		assert_eq!(first_digest[..7], partial);

		thread_rng().fill_bytes(&mut observable);
		challenger.observer().put_slice(&observable[..128]);
		// The sampler's read position (7) is absorbed before the new observations.
		reference.update([7, 0, 0, 0, 0, 0, 0, 0]);
		reference.update(&observable[..128]);

		let mut whole_digest = [0u8; 32];
		challenger.sampler().copy_to_slice(&mut whole_digest);
		let second_digest = finalize_and_feed_forward(&mut reference);
		assert_eq!(second_digest[..], whole_digest);

		challenger.observer().put_slice(&observable[128..]);
		reference.update([32, 0, 0, 0, 0, 0, 0, 0]);
		reference.update(&observable[128..]);

		let mut partial_again = [0u8; 7];
		challenger.sampler().copy_to_slice(&mut partial_again);
		let final_digest = reference.finalize_reset();
		assert_eq!(final_digest[..7], partial_again);
	}
}
270}