binius_core/fiat_shamir/hasher_challenger.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cmp::min, mem};
4
5use bytes::{buf::UninitSlice, Buf, BufMut};
6use digest::{
7	core_api::{Block, BlockSizeUser},
8	Digest, FixedOutputReset, Output,
9};
10
11use super::Challenger;
12
13/// Challenger type which implements `[Buf]` that has similar functionality as `[CanSample]`
14#[derive(Debug, Default, Clone)]
15pub struct HasherSampler<H: Digest> {
16	index: usize,
17	buffer: Output<H>,
18	hasher: H,
19}
20
21/// Challenger type which implements `[BufMut]` that has similar functionality as `[CanObserve]`
22#[derive(Debug, Default, Clone)]
23pub struct HasherObserver<H: Digest + BlockSizeUser> {
24	index: usize,
25	buffer: Block<H>,
26	hasher: H,
27}
28
29/// Challenger interface over hashes that implement `[Digest]` trait,
30///
31/// This challenger works over bytes instead of Field elements
32#[derive(Debug, Clone)]
33pub enum HasherChallenger<H: Digest + BlockSizeUser> {
34	Observer(HasherObserver<H>),
35	Sampler(HasherSampler<H>),
36}
37
38impl<H> HasherChallenger<H>
39where
40	H: Digest + BlockSizeUser,
41{
42	fn new(initial_digest: Output<H>) -> Self {
43		let mut hasher = H::new();
44		Digest::update(&mut hasher, &initial_digest);
45
46		Self::Sampler(HasherSampler {
47			hasher,
48			index: 0,
49			buffer: initial_digest,
50		})
51	}
52}
53
54impl<H> Default for HasherChallenger<H>
55where
56	H: Digest + BlockSizeUser + FixedOutputReset,
57{
58	fn default() -> Self {
59		Self::new(H::digest([]))
60	}
61}
62
63impl<H: Digest + BlockSizeUser + FixedOutputReset + Default> Challenger for HasherChallenger<H> {
64	/// This returns the inner challenger which implements `[BufMut]`
65	fn observer(&mut self) -> &mut impl BufMut {
66		match self {
67			Self::Observer(observer) => observer,
68			Self::Sampler(sampler) => {
69				*self = Self::Observer(mem::take(sampler).into_observer());
70				match self {
71					Self::Observer(observer) => observer,
72					_ => unreachable!(),
73				}
74			}
75		}
76	}
77
78	/// This returns the inner challenger which implements `[Buf]`
79	fn sampler(&mut self) -> &mut impl Buf {
80		match self {
81			Self::Sampler(sampler) => sampler,
82			Self::Observer(observer) => {
83				*self = Self::Sampler(mem::take(observer).into_sampler());
84				match self {
85					Self::Sampler(sampler) => sampler,
86					_ => unreachable!(),
87				}
88			}
89		}
90	}
91}
92
93impl<H> HasherSampler<H>
94where
95	H: Digest + Default + BlockSizeUser,
96{
97	fn into_observer(mut self) -> HasherObserver<H> {
98		Digest::update(&mut self.hasher, self.index.to_le_bytes());
99
100		HasherObserver {
101			hasher: self.hasher,
102			index: 0,
103			buffer: Block::<H>::default(),
104		}
105	}
106}
107
108impl<H> HasherSampler<H>
109where
110	H: Digest + FixedOutputReset,
111{
112	fn fill_buffer(&mut self) {
113		let digest = self.hasher.finalize_reset();
114
115		// feed forward to the empty state
116		Digest::update(&mut self.hasher, &digest);
117
118		self.buffer = digest;
119		self.index = 0
120	}
121}
122
123impl<H> HasherObserver<H>
124where
125	H: Digest + BlockSizeUser + Default,
126{
127	fn into_sampler(mut self) -> HasherSampler<H> {
128		self.flush();
129		HasherSampler {
130			hasher: self.hasher,
131			index: <H as Digest>::output_size(),
132			buffer: Output::<H>::default(),
133		}
134	}
135}
136
137impl<H> HasherObserver<H>
138where
139	H: Digest + BlockSizeUser,
140{
141	fn flush(&mut self) {
142		self.hasher.update(&self.buffer[..self.index]);
143		self.index = 0
144	}
145}
146
147impl<H> Buf for HasherSampler<H>
148where
149	H: Digest + FixedOutputReset + Default,
150{
151	fn remaining(&self) -> usize {
152		usize::MAX
153	}
154
155	fn chunk(&self) -> &[u8] {
156		&self.buffer[self.index..]
157	}
158
159	fn advance(&mut self, mut cnt: usize) {
160		// Must handle the case when `cnt` is 0
161		if self.index == <H as Digest>::output_size() {
162			self.fill_buffer();
163		}
164
165		while cnt > 0 {
166			let remaining = min(<H as Digest>::output_size() - self.index, cnt);
167			if remaining == 0 {
168				self.fill_buffer();
169				continue;
170			}
171			cnt -= remaining;
172			self.index += remaining;
173		}
174	}
175}
176
177unsafe impl<H> BufMut for HasherObserver<H>
178where
179	H: Digest + BlockSizeUser,
180{
181	fn remaining_mut(&self) -> usize {
182		usize::MAX
183	}
184
185	unsafe fn advance_mut(&mut self, mut cnt: usize) {
186		while cnt > 0 {
187			let remaining = min(<H as BlockSizeUser>::block_size() - self.index, cnt);
188			cnt -= remaining;
189			self.index += remaining;
190			if self.index == <H as BlockSizeUser>::block_size() {
191				self.flush();
192			}
193		}
194	}
195
196	fn chunk_mut(&mut self) -> &mut UninitSlice {
197		let buffer = &mut self.buffer[self.index..];
198		buffer.into()
199	}
200}
201
#[cfg(test)]
mod tests {
	use binius_hash::groestl::Groestl256;
	use rand::{thread_rng, RngCore};

	use super::*;

	/// Builds the reference hash state that a default challenger starts from: the digest of
	/// the empty message absorbed into a fresh `Groestl256`. Also returns that seed digest.
	fn reference_hasher() -> (Groestl256, Output<Groestl256>) {
		let seed = Groestl256::digest([]);
		let mut reference = Groestl256::default();
		Digest::update(&mut reference, seed);
		(reference, seed)
	}

	#[test]
	fn test_starting_sampler() {
		let (mut reference, seed) = reference_hasher();
		let mut challenger = HasherChallenger::<Groestl256>::default();

		// The first 32 sampled bytes are the seed digest itself, read in two pieces.
		let mut head = [0u8; 8];
		challenger.sampler().copy_to_slice(&mut head);
		assert_eq!(seed[0..8], head);

		let mut tail = [0u8; 24];
		challenger.sampler().copy_to_slice(&mut tail);
		assert_eq!(seed[8..], tail);

		// Switching to observing absorbs the 32 consumed bytes, then the new data.
		challenger.observer().put_slice(&[0x48, 0x55]);
		reference.update([32, 0, 0, 0, 0, 0, 0, 0]);
		reference.update([0x48, 0x55]);

		let mut after_observe = [0u8; 2];
		challenger.sampler().copy_to_slice(&mut after_observe);
		let expected = reference.finalize_reset();
		assert_eq!(after_observe, expected[..2]);
	}

	#[test]
	fn test_starting_observer() {
		let (mut reference, _) = reference_hasher();
		let mut challenger = HasherChallenger::<Groestl256>::default();
		let mut rng = thread_rng();

		// Observe random bytes in uneven pieces; how they are split must not matter.
		let mut observable = [0u8; 1019];
		rng.fill_bytes(&mut observable);
		for (start, end) in [(0, 39), (39, 300), (300, 987), (987, observable.len())] {
			challenger.observer().put_slice(&observable[start..end]);
		}
		reference.update([0, 0, 0, 0, 0, 0, 0, 0]);
		reference.update(observable);

		let mut first = [0u8; 7];
		challenger.sampler().copy_to_slice(&mut first);
		let first_digest = reference.finalize_reset();
		reference.update(first_digest);
		assert_eq!(first_digest[..7], first);

		// Observe again after consuming 7 bytes of the first block.
		rng.fill_bytes(&mut observable);
		challenger.observer().put_slice(&observable[..128]);
		reference.update([7, 0, 0, 0, 0, 0, 0, 0]);
		reference.update(&observable[..128]);

		let mut second = [0u8; 32];
		challenger.sampler().copy_to_slice(&mut second);
		let second_digest = reference.finalize_reset();
		reference.update(second_digest);
		assert_eq!(second_digest[..], second);

		// Observe the remainder after consuming one full block.
		challenger.observer().put_slice(&observable[128..]);
		reference.update([32, 0, 0, 0, 0, 0, 0, 0]);
		reference.update(&observable[128..]);

		let mut third = [0u8; 7];
		challenger.sampler().copy_to_slice(&mut third);
		let third_digest = reference.finalize_reset();
		assert_eq!(third_digest[..7], third);
	}
}