1mod error;
16
17use std::{iter::repeat_with, slice};
18
19use binius_field::{PackedField, TowerField};
20use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes};
21use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
22pub use error::Error;
23use tracing::warn;
24
25use crate::fiat_shamir::{CanSample, CanSampleBits, Challenger};
26
27#[derive(Debug)]
32pub struct ProverTranscript<Challenger> {
33 combined: FiatShamirBuf<BytesMut, Challenger>,
34 debug_assertions: bool,
35}
36
37#[derive(Debug)]
42pub struct VerifierTranscript<Challenger> {
43 combined: FiatShamirBuf<Bytes, Challenger>,
44 debug_assertions: bool,
45}
46
/// A byte buffer bundled with a challenger.
///
/// The [`Buf`] and [`BufMut`] implementations of this type forward to `buffer` and additionally
/// hand every byte that is advanced over to the challenger's observer.
#[derive(Debug, Default)]
struct FiatShamirBuf<Inner, Challenger> {
    /// The underlying byte storage that is read from or written to.
    buffer: Inner,
    /// The challenger that observes bytes moved through this wrapper.
    challenger: Challenger,
}
52
53impl<Inner: Buf, Challenger_: Challenger> Buf for FiatShamirBuf<Inner, Challenger_> {
54 fn remaining(&self) -> usize {
55 self.buffer.remaining()
56 }
57
58 fn chunk(&self) -> &[u8] {
59 self.buffer.chunk()
60 }
61
62 fn advance(&mut self, cnt: usize) {
63 assert!(cnt <= self.buffer.remaining());
64 let readable = self.buffer.chunk();
66 assert!(cnt <= readable.len());
68 self.challenger.observer().put_slice(&readable[..cnt]);
69 self.buffer.advance(cnt);
70 }
71}
72
73unsafe impl<Inner: BufMut, Challenger_: Challenger> BufMut for FiatShamirBuf<Inner, Challenger_> {
74 fn remaining_mut(&self) -> usize {
75 self.buffer.remaining_mut()
76 }
77
78 unsafe fn advance_mut(&mut self, cnt: usize) {
79 assert!(cnt <= self.buffer.remaining_mut());
80 let written = self.buffer.chunk_mut();
81 assert!(cnt <= written.len());
83
84 let written: &[u8] = slice::from_raw_parts(written.as_mut_ptr(), cnt);
87
88 self.challenger.observer().put_slice(written);
89 self.buffer.advance_mut(cnt);
90 }
91
92 fn chunk_mut(&mut self) -> &mut UninitSlice {
93 self.buffer.chunk_mut()
94 }
95}
96
97impl<Challenger_: Default + Challenger> ProverTranscript<Challenger_> {
98 pub fn new() -> Self {
103 Self {
104 combined: Default::default(),
105 debug_assertions: cfg!(debug_assertions),
106 }
107 }
108
109 pub fn into_verifier(self) -> VerifierTranscript<Challenger_> {
110 VerifierTranscript::new(self.finalize())
111 }
112}
113
114impl<Challenger_: Default + Challenger> Default for ProverTranscript<Challenger_> {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl<Challenger_: Challenger> ProverTranscript<Challenger_> {
121 pub fn finalize(self) -> Vec<u8> {
122 self.combined.buffer.to_vec()
123 }
124
125 pub const fn set_debug(&mut self, debug: bool) {
130 self.debug_assertions = debug;
131 }
132
133 pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
138 where
139 'a: 'b,
140 {
141 TranscriptWriter {
142 buffer: self.combined.challenger.observer(),
143 debug_assertions: self.debug_assertions,
144 }
145 }
146
147 pub fn decommitment(&mut self) -> TranscriptWriter<impl BufMut> {
155 TranscriptWriter {
156 buffer: &mut self.combined.buffer,
157 debug_assertions: self.debug_assertions,
158 }
159 }
160
161 pub fn message<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut>
165 where
166 'a: 'b,
167 {
168 TranscriptWriter {
169 buffer: &mut self.combined,
170 debug_assertions: self.debug_assertions,
171 }
172 }
173}
174
175impl<Challenger_: Default + Challenger> VerifierTranscript<Challenger_> {
176 pub fn new(vec: Vec<u8>) -> Self {
177 Self {
178 combined: FiatShamirBuf {
179 challenger: Challenger_::default(),
180 buffer: Bytes::from(vec),
181 },
182 debug_assertions: cfg!(debug_assertions),
183 }
184 }
185}
186
187impl<Challenger_: Challenger> VerifierTranscript<Challenger_> {
188 pub fn finalize(self) -> Result<(), Error> {
189 if self.combined.buffer.has_remaining() {
190 return Err(Error::TranscriptNotEmpty {
191 remaining: self.combined.buffer.remaining(),
192 });
193 }
194 Ok(())
195 }
196
197 pub const fn set_debug(&mut self, debug: bool) {
198 self.debug_assertions = debug;
199 }
200
201 pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
206 where
207 'a: 'b,
208 {
209 TranscriptWriter {
210 buffer: self.combined.challenger.observer(),
211 debug_assertions: self.debug_assertions,
212 }
213 }
214
215 pub fn decommitment(&mut self) -> TranscriptReader<impl Buf + '_> {
219 TranscriptReader {
220 buffer: &mut self.combined.buffer,
221 debug_assertions: self.debug_assertions,
222 }
223 }
224
225 pub fn message<'a, 'b>(&'a mut self) -> TranscriptReader<'b, impl Buf>
229 where
230 'a: 'b,
231 {
232 TranscriptReader {
233 buffer: &mut self.combined,
234 debug_assertions: self.debug_assertions,
235 }
236 }
237}
238
239impl<Challenger> Drop for VerifierTranscript<Challenger> {
241 fn drop(&mut self) {
242 if self.combined.buffer.has_remaining() {
243 warn!(
244 "Transcript reader is not fully read out: {:?} bytes left",
245 self.combined.buffer.remaining()
246 )
247 }
248 }
249}
250
251pub struct TranscriptReader<'a, B: Buf> {
252 buffer: &'a mut B,
253 debug_assertions: bool,
254}
255
256impl<B: Buf> TranscriptReader<'_, B> {
257 pub const fn buffer(&mut self) -> &mut B {
258 self.buffer
259 }
260
261 pub fn read<T: DeserializeBytes>(&mut self) -> Result<T, Error> {
262 let mode = SerializationMode::CanonicalTower;
263 T::deserialize(self.buffer(), mode).map_err(Into::into)
264 }
265
266 pub fn read_vec<T: DeserializeBytes>(&mut self, n: usize) -> Result<Vec<T>, Error> {
267 let mode = SerializationMode::CanonicalTower;
268 let mut buffer = self.buffer();
269 repeat_with(move || T::deserialize(&mut buffer, mode).map_err(Into::into))
270 .take(n)
271 .collect()
272 }
273
274 pub fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), Error> {
275 let buffer = self.buffer();
276 if buffer.remaining() < buf.len() {
277 return Err(Error::NotEnoughBytes);
278 }
279 buffer.copy_to_slice(buf);
280 Ok(())
281 }
282
283 pub fn read_scalar<F: TowerField>(&mut self) -> Result<F, Error> {
284 let mut out = F::default();
285 self.read_scalar_slice_into(slice::from_mut(&mut out))?;
286 Ok(out)
287 }
288
289 pub fn read_scalar_slice_into<F: TowerField>(&mut self, buf: &mut [F]) -> Result<(), Error> {
290 let mut buffer = self.buffer();
291 for elem in buf {
292 let mode = SerializationMode::CanonicalTower;
293 *elem = DeserializeBytes::deserialize(&mut buffer, mode)?;
294 }
295 Ok(())
296 }
297
298 pub fn read_scalar_slice<F: TowerField>(&mut self, len: usize) -> Result<Vec<F>, Error> {
299 let mut elems = vec![F::default(); len];
300 self.read_scalar_slice_into(&mut elems)?;
301 Ok(elems)
302 }
303
304 pub fn read_packed<P: PackedField<Scalar: TowerField>>(&mut self) -> Result<P, Error> {
305 P::try_from_fn(|_| self.read_scalar())
306 }
307
308 pub fn read_packed_slice<P: PackedField<Scalar: TowerField>>(
309 &mut self,
310 len: usize,
311 ) -> Result<Vec<P>, Error> {
312 let mut packed = Vec::with_capacity(len);
313 for _ in 0..len {
314 packed.push(self.read_packed()?);
315 }
316 Ok(packed)
317 }
318
319 pub fn read_debug(&mut self, msg: &str) {
320 if self.debug_assertions {
321 let msg_bytes = msg.as_bytes();
322 let mut buffer = vec![0; msg_bytes.len()];
323 assert!(self.read_bytes(&mut buffer).is_ok());
324 assert_eq!(msg_bytes, buffer);
325 }
326 }
327}
328
329pub struct TranscriptWriter<'a, B: BufMut> {
330 buffer: &'a mut B,
331 debug_assertions: bool,
332}
333
334impl<B: BufMut> TranscriptWriter<'_, B> {
335 pub const fn buffer(&mut self) -> &mut B {
336 self.buffer
337 }
338
339 pub fn write<T: SerializeBytes>(&mut self, value: &T) {
340 self.proof_size_event_wrapper(|buffer| {
341 value
342 .serialize(buffer, SerializationMode::CanonicalTower)
343 .expect("TODO: propagate error");
344 });
345 }
346
347 pub fn write_slice<T: SerializeBytes>(&mut self, values: &[T]) {
348 self.proof_size_event_wrapper(|buffer| {
349 for value in values {
350 value
351 .serialize(&mut *buffer, SerializationMode::CanonicalTower)
352 .expect("TODO: propagate error");
353 }
354 });
355 }
356
357 pub fn write_bytes(&mut self, data: &[u8]) {
358 self.proof_size_event_wrapper(|buffer| {
359 buffer.put_slice(data);
360 });
361 }
362
363 pub fn write_scalar<F: TowerField>(&mut self, f: F) {
364 self.write_scalar_slice(slice::from_ref(&f));
365 }
366
367 pub fn write_scalar_slice<F: TowerField>(&mut self, elems: &[F]) {
368 self.proof_size_event_wrapper(|buffer| {
369 for elem in elems {
370 SerializeBytes::serialize(elem, &mut *buffer, SerializationMode::CanonicalTower)
371 .expect("TODO: propagate error");
372 }
373 });
374 }
375
376 pub fn write_packed<P: PackedField<Scalar: TowerField>>(&mut self, packed: P) {
377 for scalar in packed.iter() {
378 self.write_scalar(scalar);
379 }
380 }
381
382 pub fn write_packed_slice<P: PackedField<Scalar: TowerField>>(&mut self, packed_slice: &[P]) {
383 for &packed in packed_slice {
384 self.write_packed(packed)
385 }
386 }
387
388 pub fn write_debug(&mut self, msg: &str) {
389 if self.debug_assertions {
390 self.write_bytes(msg.as_bytes())
391 }
392 }
393
394 fn proof_size_event_wrapper<F: Fn(&mut B)>(&mut self, f: F) {
395 let buffer = self.buffer();
396 let start_bytes = buffer.remaining_mut();
397 f(buffer);
398 let end_bytes = buffer.remaining_mut();
399 tracing::event!(name: "proof_size", tracing::Level::INFO, counter=true, incremental=true, value=start_bytes - end_bytes);
400 }
401}
402
403impl<F, Challenger_> CanSample<F> for VerifierTranscript<Challenger_>
404where
405 F: TowerField,
406 Challenger_: Challenger,
407{
408 fn sample(&mut self) -> F {
409 let mode = SerializationMode::CanonicalTower;
410 DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode)
411 .expect("challenger has infinite buffer")
412 }
413}
414
415impl<F, Challenger_> CanSample<F> for ProverTranscript<Challenger_>
416where
417 F: TowerField,
418 Challenger_: Challenger,
419{
420 fn sample(&mut self) -> F {
421 let mode = SerializationMode::CanonicalTower;
422 DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode)
423 .expect("challenger has infinite buffer")
424 }
425}
426
427fn sample_bits_reader<Reader: Buf>(mut reader: Reader, bits: usize) -> usize {
428 let bits = bits.min(usize::BITS as usize);
429
430 let bytes_to_sample = bits.div_ceil(8);
431
432 let mut bytes = [0u8; std::mem::size_of::<usize>()];
433
434 reader.copy_to_slice(&mut bytes[..bytes_to_sample]);
435
436 let unmasked = usize::from_le_bytes(bytes);
437 let mask = 1usize.checked_shl(bits as u32);
438 let mask = match mask {
439 Some(x) => x - 1,
440 None => usize::MAX,
441 };
442 mask & unmasked
443}
444
445impl<Challenger_> CanSampleBits<usize> for VerifierTranscript<Challenger_>
446where
447 Challenger_: Challenger,
448{
449 fn sample_bits(&mut self, bits: usize) -> usize {
450 sample_bits_reader(self.combined.challenger.sampler(), bits)
451 }
452}
453
454impl<Challenger_> CanSampleBits<usize> for ProverTranscript<Challenger_>
455where
456 Challenger_: Challenger,
457{
458 fn sample_bits(&mut self, bits: usize) -> usize {
459 sample_bits_reader(self.combined.challenger.sampler(), bits)
460 }
461}
462
463pub fn read_u64<B: Buf>(transcript: &mut TranscriptReader<B>) -> Result<u64, Error> {
465 let mut as_bytes = [0; size_of::<u64>()];
466 transcript.read_bytes(&mut as_bytes)?;
467 Ok(u64::from_le_bytes(as_bytes))
468}
469
470pub fn write_u64<B: BufMut>(transcript: &mut TranscriptWriter<B>, n: u64) {
471 transcript.write_bytes(&n.to_le_bytes());
472}
473
#[cfg(test)]
mod tests {
    use binius_field::{
        AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField8b, BinaryField128b,
        BinaryField128bPolyval, BinaryField32b, BinaryField64b, BinaryField8b,
    };
    use groestl_crypto::Groestl256;
    use rand::{thread_rng, RngCore};

    use super::*;
    use crate::fiat_shamir::HasherChallenger;

    /// Values written through `message()` are read back unchanged, and prover and verifier
    /// sample identical challenges at the same points of the transcript.
    #[test]
    fn test_transcripting() {
        let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut writable = prover_transcript.message();

        writable.write_scalar(BinaryField8b::new(0x96));
        writable.write_scalar(BinaryField32b::new(0xDEADBEEF));
        writable.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));
        let sampled_fanpaar1: BinaryField128b = prover_transcript.sample();

        let mut writable = prover_transcript.message();

        writable.write_scalar(AESTowerField8b::new(0x52));
        writable.write_scalar(AESTowerField32b::new(0x12345678));
        writable.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));

        let sampled_aes1: AESTowerField16b = prover_transcript.sample();

        prover_transcript
            .message()
            .write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));
        let sampled_polyval1: BinaryField128bPolyval = prover_transcript.sample();

        let mut verifier_transcript = prover_transcript.into_verifier();
        let mut readable = verifier_transcript.message();

        let fp_8: BinaryField8b = readable.read_scalar().unwrap();
        let fp_32: BinaryField32b = readable.read_scalar().unwrap();
        let fp_128: BinaryField128b = readable.read_scalar().unwrap();

        assert_eq!(fp_8.val(), 0x96);
        assert_eq!(fp_32.val(), 0xDEADBEEF);
        assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);

        let sampled_fanpaar1_res: BinaryField128b = verifier_transcript.sample();

        assert_eq!(sampled_fanpaar1_res, sampled_fanpaar1);

        let mut readable = verifier_transcript.message();

        let aes_8: AESTowerField8b = readable.read_scalar().unwrap();
        let aes_32: AESTowerField32b = readable.read_scalar().unwrap();
        let aes_128: AESTowerField128b = readable.read_scalar().unwrap();

        assert_eq!(aes_8.val(), 0x52);
        assert_eq!(aes_32.val(), 0x12345678);
        assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);

        let sampled_aes_res: AESTowerField16b = verifier_transcript.sample();

        assert_eq!(sampled_aes_res, sampled_aes1);

        let polyval_128: BinaryField128bPolyval =
            verifier_transcript.message().read_scalar().unwrap();
        assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

        let sampled_polyval_res: BinaryField128bPolyval = verifier_transcript.sample();
        assert_eq!(sampled_polyval_res, sampled_polyval1);

        verifier_transcript.finalize().unwrap();
    }

    /// Values written through `decommitment()` are read back unchanged through the verifier's
    /// `decommitment()` reader.
    #[test]
    fn test_advicing() {
        let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut advice_writer = prover_transcript.decommitment();

        advice_writer.write_scalar(BinaryField8b::new(0x96));
        advice_writer.write_scalar(BinaryField32b::new(0xDEADBEEF));
        advice_writer.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));

        advice_writer.write_scalar(AESTowerField8b::new(0x52));
        advice_writer.write_scalar(AESTowerField32b::new(0x12345678));
        advice_writer.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));

        advice_writer.write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

        let mut verifier_transcript = prover_transcript.into_verifier();
        let mut advice_reader = verifier_transcript.decommitment();

        let fp_8: BinaryField8b = advice_reader.read_scalar().unwrap();
        let fp_32: BinaryField32b = advice_reader.read_scalar().unwrap();
        let fp_128: BinaryField128b = advice_reader.read_scalar().unwrap();

        assert_eq!(fp_8.val(), 0x96);
        assert_eq!(fp_32.val(), 0xDEADBEEF);
        assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);

        let aes_8: AESTowerField8b = advice_reader.read_scalar().unwrap();
        let aes_32: AESTowerField32b = advice_reader.read_scalar().unwrap();
        let aes_128: AESTowerField128b = advice_reader.read_scalar().unwrap();

        assert_eq!(aes_8.val(), 0x52);
        assert_eq!(aes_32.val(), 0x12345678);
        assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);

        let polyval_128: BinaryField128bPolyval = advice_reader.read_scalar().unwrap();
        assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

        verifier_transcript.finalize().unwrap();
    }

    /// `message()` and `observe()` drive the challenger exactly like feeding the same bytes to
    /// a bare challenger, and only `message()` records bytes in the proof.
    #[test]
    fn test_challenger_and_observing() {
        let mut taped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut untaped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut challenger = HasherChallenger::<Groestl256>::default();

        const NUM_SAMPLING: usize = 32;
        let mut random_bytes = [0u8; NUM_SAMPLING * 8];
        thread_rng().fill_bytes(&mut random_bytes);
        let mut sampled_arrays = [[0u8; 8]; NUM_SAMPLING];

        for i in 0..NUM_SAMPLING {
            let chunk = &random_bytes[i * 8..i * 8 + 8];
            // The slice is exactly eight bytes long, so the array conversion cannot fail.
            let word = u64::from_le_bytes(chunk.try_into().unwrap());

            taped_transcript
                .message()
                .write_scalar(BinaryField64b::new(word));
            untaped_transcript
                .observe()
                .write_scalar(BinaryField64b::new(word));
            challenger.observer().put_slice(chunk);

            let sampled_out_transcript1: BinaryField64b = taped_transcript.sample();
            let sampled_out_transcript2: BinaryField64b = untaped_transcript.sample();
            let mut challenger_out = [0u8; 8];
            challenger.sampler().copy_to_slice(&mut challenger_out);
            assert_eq!(challenger_out, sampled_out_transcript1.val().to_le_bytes());
            assert_eq!(challenger_out, sampled_out_transcript2.val().to_le_bytes());
            sampled_arrays[i] = challenger_out;
        }

        let mut taped_transcript = taped_transcript.into_verifier();

        // Observed-only data never reaches the proof buffer.
        assert!(untaped_transcript.finalize().is_empty());

        for array in sampled_arrays {
            let _: BinaryField64b = taped_transcript.message().read_scalar().unwrap();
            let sampled_out_transcript: BinaryField64b = taped_transcript.sample();

            assert_eq!(array, sampled_out_transcript.val().to_le_bytes());
        }

        taped_transcript.finalize().unwrap();
    }

    /// A debug marker written by the prover is accepted by the verifier.
    ///
    /// The debug flag is enabled explicitly on both sides so the marker is really written and
    /// checked regardless of the build profile.
    #[test]
    fn test_transcript_debug() {
        let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        transcript.set_debug(true);

        transcript.message().write_debug("test_transcript_debug");

        let mut verifier = transcript.into_verifier();
        verifier.set_debug(true);
        verifier.message().read_debug("test_transcript_debug");

        // The marker must have been consumed completely.
        verifier.finalize().unwrap();
    }

    /// A mismatching debug marker makes `read_debug` panic.
    ///
    /// The debug flag is enabled explicitly on both sides; otherwise the check would be skipped
    /// when debug assertions are disabled and this `should_panic` test would fail.
    #[test]
    #[should_panic]
    fn test_transcript_debug_fail() {
        let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        transcript.set_debug(true);

        transcript.message().write_debug("test_transcript_debug");

        let mut verifier = transcript.into_verifier();
        verifier.set_debug(true);
        verifier
            .message()
            .read_debug("test_transcript_debug_should_fail");
    }
}