1mod error;
16
17use std::{fs::File, io::Write, iter::repeat_with, slice};
18
19use binius_field::{PackedField, TowerField};
20use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes};
21use bytes::{Buf, BufMut, Bytes, BytesMut, buf::UninitSlice};
22pub use error::Error;
23use tracing::warn;
24
25use crate::fiat_shamir::{CanSample, CanSampleBits, Challenger};
26
27#[derive(Debug)]
33pub struct ProverTranscript<Challenger> {
34 combined: FiatShamirBuf<BytesMut, Challenger>,
35 debug_assertions: bool,
36}
37
38#[derive(Debug, Clone)]
44pub struct VerifierTranscript<Challenger> {
45 combined: FiatShamirBuf<Bytes, Challenger>,
46 debug_assertions: bool,
47}
48
/// A byte buffer coupled with a Fiat-Shamir challenger.
///
/// Its `Buf`/`BufMut` implementations forward to `buffer` while feeding every consumed or
/// committed byte to the challenger's observer.
#[derive(Debug, Default, Clone)]
struct FiatShamirBuf<B, C> {
	buffer: B,
	challenger: C,
}
54
55impl<Inner: Buf, Challenger_: Challenger> Buf for FiatShamirBuf<Inner, Challenger_> {
56 fn remaining(&self) -> usize {
57 self.buffer.remaining()
58 }
59
60 fn chunk(&self) -> &[u8] {
61 self.buffer.chunk()
62 }
63
64 fn advance(&mut self, cnt: usize) {
65 assert!(cnt <= self.buffer.remaining());
66 let readable = self.buffer.chunk();
68 assert!(cnt <= readable.len());
70 self.challenger.observer().put_slice(&readable[..cnt]);
71 self.buffer.advance(cnt);
72 }
73}
74
75unsafe impl<Inner: BufMut, Challenger_: Challenger> BufMut for FiatShamirBuf<Inner, Challenger_> {
76 fn remaining_mut(&self) -> usize {
77 self.buffer.remaining_mut()
78 }
79
80 unsafe fn advance_mut(&mut self, cnt: usize) {
81 assert!(cnt <= self.buffer.remaining_mut());
82 let written = self.buffer.chunk_mut();
83 assert!(cnt <= written.len());
86
87 let written: &[u8] = unsafe { slice::from_raw_parts(written.as_mut_ptr(), cnt) };
90
91 self.challenger.observer().put_slice(written);
92 unsafe {
93 self.buffer.advance_mut(cnt);
94 }
95 }
96
97 fn chunk_mut(&mut self) -> &mut UninitSlice {
98 self.buffer.chunk_mut()
99 }
100}
101
102impl<Challenger_: Default + Challenger> ProverTranscript<Challenger_> {
103 pub fn new() -> Self {
108 Self {
109 combined: Default::default(),
110 debug_assertions: cfg!(debug_assertions),
111 }
112 }
113
114 pub fn into_verifier(self) -> VerifierTranscript<Challenger_> {
115 let transcript = self.finalize();
116
117 VerifierTranscript::new(transcript)
118 }
119}
120
121impl<Challenger_: Default + Challenger> Default for ProverTranscript<Challenger_> {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
127impl<Challenger_: Challenger> ProverTranscript<Challenger_> {
128 pub fn finalize(self) -> Vec<u8> {
129 let transcript = self.combined.buffer.to_vec();
130
131 if let Ok(path) = std::env::var("BINIUS_DUMP_PROOF") {
133 let path = if cfg!(test) {
134 let current_thread = std::thread::current();
137 let test_name = current_thread.name().unwrap_or("unknown");
138 let path = if let Some(stripped) = path.strip_prefix("./") {
141 format!("../../{stripped}",)
142 } else {
143 path
144 };
145 std::fs::create_dir_all(&path)
146 .unwrap_or_else(|_| panic!("Failed to create directories for path: {path}",));
147 format!("{path}/{test_name}.bin")
148 } else {
149 path
150 };
151
152 let mut file = File::create(&path)
153 .unwrap_or_else(|_| panic!("Failed to create proof dump file: {path}"));
154 file.write_all(&transcript)
155 .expect("Failed to write proof to dump file");
156 }
157 transcript
158 }
159
160 pub const fn set_debug(&mut self, debug: bool) {
165 self.debug_assertions = debug;
166 }
167
168 pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
173 where
174 'a: 'b,
175 {
176 TranscriptWriter {
177 buffer: self.combined.challenger.observer(),
178 debug_assertions: self.debug_assertions,
179 }
180 }
181
182 pub fn decommitment(&mut self) -> TranscriptWriter<impl BufMut> {
191 TranscriptWriter {
192 buffer: &mut self.combined.buffer,
193 debug_assertions: self.debug_assertions,
194 }
195 }
196
197 pub fn message<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut>
201 where
202 'a: 'b,
203 {
204 TranscriptWriter {
205 buffer: &mut self.combined,
206 debug_assertions: self.debug_assertions,
207 }
208 }
209}
210
211impl<Challenger_: Default + Challenger> VerifierTranscript<Challenger_> {
212 pub fn new(vec: Vec<u8>) -> Self {
213 Self {
214 combined: FiatShamirBuf {
215 challenger: Challenger_::default(),
216 buffer: Bytes::from(vec),
217 },
218 debug_assertions: cfg!(debug_assertions),
219 }
220 }
221}
222
223impl<Challenger_: Challenger> VerifierTranscript<Challenger_> {
224 pub fn finalize(self) -> Result<(), Error> {
225 if self.combined.buffer.has_remaining() {
226 return Err(Error::TranscriptNotEmpty {
227 remaining: self.combined.buffer.remaining(),
228 });
229 }
230 Ok(())
231 }
232
233 pub const fn set_debug(&mut self, debug: bool) {
234 self.debug_assertions = debug;
235 }
236
237 pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
242 where
243 'a: 'b,
244 {
245 TranscriptWriter {
246 buffer: self.combined.challenger.observer(),
247 debug_assertions: self.debug_assertions,
248 }
249 }
250
251 pub fn decommitment(&mut self) -> TranscriptReader<impl Buf + '_> {
257 TranscriptReader {
258 buffer: &mut self.combined.buffer,
259 debug_assertions: self.debug_assertions,
260 }
261 }
262
263 pub fn message<'a, 'b>(&'a mut self) -> TranscriptReader<'b, impl Buf>
267 where
268 'a: 'b,
269 {
270 TranscriptReader {
271 buffer: &mut self.combined,
272 debug_assertions: self.debug_assertions,
273 }
274 }
275}
276
277impl<Challenger> Drop for VerifierTranscript<Challenger> {
279 fn drop(&mut self) {
280 if self.combined.buffer.has_remaining() {
281 warn!(
282 "Transcript reader is not fully read out: {:?} bytes left",
283 self.combined.buffer.remaining()
284 )
285 }
286 }
287}
288
289pub struct TranscriptReader<'a, B: Buf> {
290 buffer: &'a mut B,
291 debug_assertions: bool,
292}
293
294impl<B: Buf> TranscriptReader<'_, B> {
295 pub const fn buffer(&mut self) -> &mut B {
296 self.buffer
297 }
298
299 pub fn read<T: DeserializeBytes>(&mut self) -> Result<T, Error> {
300 let mode = SerializationMode::CanonicalTower;
301 T::deserialize(self.buffer(), mode).map_err(Into::into)
302 }
303
304 pub fn read_vec<T: DeserializeBytes>(&mut self, n: usize) -> Result<Vec<T>, Error> {
305 let mode = SerializationMode::CanonicalTower;
306 let mut buffer = self.buffer();
307 repeat_with(move || T::deserialize(&mut buffer, mode).map_err(Into::into))
308 .take(n)
309 .collect()
310 }
311
312 pub fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), Error> {
313 let buffer = self.buffer();
314 if buffer.remaining() < buf.len() {
315 return Err(Error::NotEnoughBytes);
316 }
317 buffer.copy_to_slice(buf);
318 Ok(())
319 }
320
321 pub fn read_scalar<F: TowerField>(&mut self) -> Result<F, Error> {
322 let mut out = F::default();
323 self.read_scalar_slice_into(slice::from_mut(&mut out))?;
324 Ok(out)
325 }
326
327 pub fn read_scalar_slice_into<F: TowerField>(&mut self, buf: &mut [F]) -> Result<(), Error> {
328 let mut buffer = self.buffer();
329 for elem in buf {
330 let mode = SerializationMode::CanonicalTower;
331 *elem = DeserializeBytes::deserialize(&mut buffer, mode)?;
332 }
333 Ok(())
334 }
335
336 pub fn read_scalar_slice<F: TowerField>(&mut self, len: usize) -> Result<Vec<F>, Error> {
337 let mut elems = vec![F::default(); len];
338 self.read_scalar_slice_into(&mut elems)?;
339 Ok(elems)
340 }
341
342 pub fn read_packed<P: PackedField<Scalar: TowerField>>(&mut self) -> Result<P, Error> {
343 P::try_from_fn(|_| self.read_scalar())
344 }
345
346 pub fn read_packed_slice<P: PackedField<Scalar: TowerField>>(
347 &mut self,
348 len: usize,
349 ) -> Result<Vec<P>, Error> {
350 let mut packed = Vec::with_capacity(len);
351 for _ in 0..len {
352 packed.push(self.read_packed()?);
353 }
354 Ok(packed)
355 }
356
357 pub fn read_debug(&mut self, msg: &str) {
358 if self.debug_assertions {
359 let msg_bytes = msg.as_bytes();
360 let mut buffer = vec![0; msg_bytes.len()];
361 assert!(self.read_bytes(&mut buffer).is_ok());
362 assert_eq!(msg_bytes, buffer);
363 }
364 }
365}
366
367pub struct TranscriptWriter<'a, B: BufMut> {
368 buffer: &'a mut B,
369 debug_assertions: bool,
370}
371
372impl<B: BufMut> TranscriptWriter<'_, B> {
373 pub const fn buffer(&mut self) -> &mut B {
374 self.buffer
375 }
376
377 pub fn write<T: SerializeBytes>(&mut self, value: &T) {
378 self.proof_size_event_wrapper(|buffer| {
379 value
380 .serialize(buffer, SerializationMode::CanonicalTower)
381 .expect("TODO: propagate error");
382 });
383 }
384
385 pub fn write_slice<T: SerializeBytes>(&mut self, values: &[T]) {
386 self.proof_size_event_wrapper(|buffer| {
387 for value in values {
388 value
389 .serialize(&mut *buffer, SerializationMode::CanonicalTower)
390 .expect("TODO: propagate error");
391 }
392 });
393 }
394
395 pub fn write_bytes(&mut self, data: &[u8]) {
396 self.proof_size_event_wrapper(|buffer| {
397 buffer.put_slice(data);
398 });
399 }
400
401 pub fn write_scalar<F: TowerField>(&mut self, f: F) {
402 self.write_scalar_slice(slice::from_ref(&f));
403 }
404
405 pub fn write_scalar_iter<F: TowerField>(&mut self, it: impl IntoIterator<Item = F>) {
406 self.proof_size_event_wrapper(move |buffer| {
407 for elem in it {
408 SerializeBytes::serialize(&elem, &mut *buffer, SerializationMode::CanonicalTower)
409 .expect("TODO: propagate error");
410 }
411 });
412 }
413
414 pub fn write_scalar_slice<F: TowerField>(&mut self, elems: &[F]) {
415 self.write_scalar_iter(elems.iter().copied());
416 }
417
418 pub fn write_packed<P: PackedField<Scalar: TowerField>>(&mut self, packed: P) {
419 self.write_scalar_iter(packed.into_iter());
420 }
421
422 pub fn write_packed_iter<P: PackedField<Scalar: TowerField>>(
423 &mut self,
424 it: impl IntoIterator<Item = P>,
425 ) {
426 self.write_scalar_iter(it.into_iter().flat_map(|packed| packed.into_iter()));
427 }
428
429 pub fn write_packed_slice<P: PackedField<Scalar: TowerField>>(&mut self, packed_slice: &[P]) {
430 self.write_scalar_iter(P::iter_slice(packed_slice));
431 }
432
433 pub fn write_debug(&mut self, msg: &str) {
434 if self.debug_assertions {
435 self.write_bytes(msg.as_bytes())
436 }
437 }
438
439 fn proof_size_event_wrapper<F: FnOnce(&mut B)>(&mut self, f: F) {
440 let buffer = self.buffer();
441 let start_bytes = buffer.remaining_mut();
442 f(buffer);
443 let end_bytes = buffer.remaining_mut();
444 tracing::event!(name: "incremental_proof_size", tracing::Level::INFO, counter=true, incremental=true, value=start_bytes - end_bytes);
445 }
446}
447
448impl<F, Challenger_> CanSample<F> for VerifierTranscript<Challenger_>
449where
450 F: TowerField,
451 Challenger_: Challenger,
452{
453 fn sample(&mut self) -> F {
454 let mode = SerializationMode::CanonicalTower;
455 DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode)
456 .expect("challenger has infinite buffer")
457 }
458}
459
460impl<F, Challenger_> CanSample<F> for ProverTranscript<Challenger_>
461where
462 F: TowerField,
463 Challenger_: Challenger,
464{
465 fn sample(&mut self) -> F {
466 let mode = SerializationMode::CanonicalTower;
467 DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode)
468 .expect("challenger has infinite buffer")
469 }
470}
471
472fn sample_bits_reader<Reader: Buf>(mut reader: Reader, bits: usize) -> u32 {
473 let bits = bits.min(u32::BITS as usize);
474
475 let bytes_to_sample: usize = std::mem::size_of::<u32>();
476
477 let mut bytes = [0u8; std::mem::size_of::<u32>()];
478
479 reader.copy_to_slice(&mut bytes[..bytes_to_sample]);
480
481 let unmasked = u32::from_le_bytes(bytes);
482 let mask = 1u32.checked_shl(bits as u32);
483 let mask = match mask {
484 Some(x) => x - 1,
485 None => u32::MAX,
486 };
487 mask & unmasked
488}
489
490impl<Challenger_> CanSampleBits<u32> for VerifierTranscript<Challenger_>
491where
492 Challenger_: Challenger,
493{
494 fn sample_bits(&mut self, bits: usize) -> u32 {
495 sample_bits_reader(self.combined.challenger.sampler(), bits)
496 }
497}
498
499impl<Challenger_> CanSampleBits<u32> for ProverTranscript<Challenger_>
500where
501 Challenger_: Challenger,
502{
503 fn sample_bits(&mut self, bits: usize) -> u32 {
504 sample_bits_reader(self.combined.challenger.sampler(), bits)
505 }
506}
507
508pub fn read_u64<B: Buf>(transcript: &mut TranscriptReader<B>) -> Result<u64, Error> {
510 let mut as_bytes = [0; size_of::<u64>()];
511 transcript.read_bytes(&mut as_bytes)?;
512 Ok(u64::from_le_bytes(as_bytes))
513}
514
515pub fn write_u64<B: BufMut>(transcript: &mut TranscriptWriter<B>, n: u64) {
516 transcript.write_bytes(&n.to_le_bytes());
517}
518
#[cfg(test)]
mod tests {
	use binius_field::{
		AESTowerField8b, AESTowerField16b, AESTowerField32b, AESTowerField128b, BinaryField8b,
		BinaryField32b, BinaryField64b, BinaryField128b, BinaryField128bPolyval,
	};
	use binius_hash::groestl::Groestl256;
	use rand::{RngCore, thread_rng};

	use super::*;
	use crate::fiat_shamir::HasherChallenger;

	/// Interleaves message writes with challenge sampling and checks that the verifier reads back
	/// the same values and re-derives the same challenges.
	#[test]
	fn test_transcript_interactions() {
		let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let mut writable = prover_transcript.message();

		writable.write_scalar(BinaryField8b::new(0x96));
		writable.write_scalar(BinaryField32b::new(0xDEADBEEF));
		writable.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));
		let sampled_fanpaar1: BinaryField128b = prover_transcript.sample();

		let mut writable = prover_transcript.message();

		writable.write_scalar(AESTowerField8b::new(0x52));
		writable.write_scalar(AESTowerField32b::new(0x12345678));
		writable.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));

		let sampled_aes1: AESTowerField16b = prover_transcript.sample();

		prover_transcript
			.message()
			.write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));
		let sampled_polyval1: BinaryField128bPolyval = prover_transcript.sample();

		let mut verifier_transcript = prover_transcript.into_verifier();
		let mut readable = verifier_transcript.message();

		let fp_8: BinaryField8b = readable.read_scalar().unwrap();
		let fp_32: BinaryField32b = readable.read_scalar().unwrap();
		let fp_128: BinaryField128b = readable.read_scalar().unwrap();

		assert_eq!(fp_8.val(), 0x96);
		assert_eq!(fp_32.val(), 0xDEADBEEF);
		assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);

		let sampled_fanpaar1_res: BinaryField128b = verifier_transcript.sample();

		assert_eq!(sampled_fanpaar1_res, sampled_fanpaar1);

		let mut readable = verifier_transcript.message();

		let aes_8: AESTowerField8b = readable.read_scalar().unwrap();
		let aes_32: AESTowerField32b = readable.read_scalar().unwrap();
		let aes_128: AESTowerField128b = readable.read_scalar().unwrap();

		assert_eq!(aes_8.val(), 0x52);
		assert_eq!(aes_32.val(), 0x12345678);
		assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);

		let sampled_aes_res: AESTowerField16b = verifier_transcript.sample();

		assert_eq!(sampled_aes_res, sampled_aes1);

		let polyval_128: BinaryField128bPolyval =
			verifier_transcript.message().read_scalar().unwrap();
		assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

		let sampled_polyval_res: BinaryField128bPolyval = verifier_transcript.sample();
		assert_eq!(sampled_polyval_res, sampled_polyval1);

		verifier_transcript.finalize().unwrap();
	}

	/// Round-trips values written through the decommitment (unobserved) channel.
	#[test]
	fn test_advising() {
		let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let mut advice_writer = prover_transcript.decommitment();

		advice_writer.write_scalar(BinaryField8b::new(0x96));
		advice_writer.write_scalar(BinaryField32b::new(0xDEADBEEF));
		advice_writer.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));

		advice_writer.write_scalar(AESTowerField8b::new(0x52));
		advice_writer.write_scalar(AESTowerField32b::new(0x12345678));
		advice_writer.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));

		advice_writer.write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

		let mut verifier_transcript = prover_transcript.into_verifier();
		let mut advice_reader = verifier_transcript.decommitment();

		let fp_8: BinaryField8b = advice_reader.read_scalar().unwrap();
		let fp_32: BinaryField32b = advice_reader.read_scalar().unwrap();
		let fp_128: BinaryField128b = advice_reader.read_scalar().unwrap();

		assert_eq!(fp_8.val(), 0x96);
		assert_eq!(fp_32.val(), 0xDEADBEEF);
		assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);

		let aes_8: AESTowerField8b = advice_reader.read_scalar().unwrap();
		let aes_32: AESTowerField32b = advice_reader.read_scalar().unwrap();
		let aes_128: AESTowerField128b = advice_reader.read_scalar().unwrap();

		assert_eq!(aes_8.val(), 0x52);
		assert_eq!(aes_32.val(), 0x12345678);
		assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);

		let polyval_128: BinaryField128bPolyval = advice_reader.read_scalar().unwrap();
		assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

		verifier_transcript.finalize().unwrap();
	}

	/// Checks that `message()` and `observe()` both drive the challenger exactly like a
	/// standalone challenger fed the same bytes, and that only `message()` data reaches the proof.
	#[test]
	fn test_challenger_and_observing() {
		const NUM_SAMPLING: usize = 32;

		let mut taped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let mut untaped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		let mut challenger = HasherChallenger::<Groestl256>::default();

		let mut random_bytes = [0u8; NUM_SAMPLING * 8];
		thread_rng().fill_bytes(&mut random_bytes);
		let mut sampled_arrays = [[0u8; 8]; NUM_SAMPLING];

		for (chunk, sampled) in random_bytes.chunks_exact(8).zip(&mut sampled_arrays) {
			let value = BinaryField64b::new(u64::from_le_bytes(
				chunk.try_into().expect("chunks_exact yields 8-byte chunks"),
			));
			taped_transcript.message().write_scalar(value);
			untaped_transcript.observe().write_scalar(value);
			challenger.observer().put_slice(chunk);

			let sampled_out_transcript1: BinaryField64b = taped_transcript.sample();
			let sampled_out_transcript2: BinaryField64b = untaped_transcript.sample();
			let mut challenger_out = [0u8; 8];
			challenger.sampler().copy_to_slice(&mut challenger_out);
			assert_eq!(challenger_out, sampled_out_transcript1.val().to_le_bytes());
			assert_eq!(challenger_out, sampled_out_transcript2.val().to_le_bytes());
			*sampled = challenger_out;
		}

		let mut taped_transcript = taped_transcript.into_verifier();

		// Observed-only writes must not end up in the proof.
		assert!(untaped_transcript.finalize().is_empty());

		for array in sampled_arrays {
			let _: BinaryField64b = taped_transcript.message().read_scalar().unwrap();
			let sampled_out_transcript: BinaryField64b = taped_transcript.sample();

			assert_eq!(array, sampled_out_transcript.val().to_le_bytes());
		}

		taped_transcript.finalize().unwrap();
	}

	/// Debug markers written by the prover are accepted by a verifier expecting the same text.
	#[test]
	fn test_transcript_debug() {
		let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		// Force markers on so the test exercises them in release builds too.
		transcript.set_debug(true);

		transcript.message().write_debug("test_transcript_debug");
		let mut verifier = transcript.into_verifier();
		verifier.set_debug(true);
		verifier.message().read_debug("test_transcript_debug");
	}

	/// A mismatching debug marker must panic. Markers are forced on explicitly: otherwise
	/// `write_debug`/`read_debug` are no-ops under `cargo test --release` and the expected panic
	/// never happens.
	#[test]
	#[should_panic]
	fn test_transcript_debug_fail() {
		let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
		transcript.set_debug(true);

		transcript.message().write_debug("test_transcript_debug");
		let mut verifier = transcript.into_verifier();
		verifier.set_debug(true);
		verifier
			.message()
			.read_debug("test_transcript_debug_should_fail");
	}
}