1mod error;
16
17use std::{iter::repeat_with, slice};
18
19use binius_field::{PackedField, TowerField};
20use binius_utils::{DeserializeBytes, SerializationMode, SerializeBytes};
21use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
22pub use error::Error;
23use tracing::warn;
24
25use crate::fiat_shamir::{CanSample, CanSampleBits, Challenger};
26
27#[derive(Debug)]
32pub struct ProverTranscript<Challenger> {
33 combined: FiatShamirBuf<BytesMut, Challenger>,
34 debug_assertions: bool,
35}
36
37#[derive(Debug, Clone)]
42pub struct VerifierTranscript<Challenger> {
43 combined: FiatShamirBuf<Bytes, Challenger>,
44 debug_assertions: bool,
45}
46
/// An inner byte buffer paired with a challenger.
///
/// The `Buf` / `BufMut` implementations in this file forward to `buffer` and additionally pass
/// every consumed or committed byte to `challenger.observer()`.
#[derive(Clone, Debug, Default)]
struct FiatShamirBuf<Inner, Challenger> {
    /// The underlying storage that is read from or written to.
    buffer: Inner,
    /// The challenger that observes the traffic going through `buffer`.
    challenger: Challenger,
}
52
53impl<Inner: Buf, Challenger_: Challenger> Buf for FiatShamirBuf<Inner, Challenger_> {
54 fn remaining(&self) -> usize {
55 self.buffer.remaining()
56 }
57
58 fn chunk(&self) -> &[u8] {
59 self.buffer.chunk()
60 }
61
62 fn advance(&mut self, cnt: usize) {
63 assert!(cnt <= self.buffer.remaining());
64 let readable = self.buffer.chunk();
66 assert!(cnt <= readable.len());
68 self.challenger.observer().put_slice(&readable[..cnt]);
69 self.buffer.advance(cnt);
70 }
71}
72
73unsafe impl<Inner: BufMut, Challenger_: Challenger> BufMut for FiatShamirBuf<Inner, Challenger_> {
74 fn remaining_mut(&self) -> usize {
75 self.buffer.remaining_mut()
76 }
77
78 unsafe fn advance_mut(&mut self, cnt: usize) {
79 assert!(cnt <= self.buffer.remaining_mut());
80 let written = self.buffer.chunk_mut();
81 assert!(cnt <= written.len());
83
84 let written: &[u8] = slice::from_raw_parts(written.as_mut_ptr(), cnt);
87
88 self.challenger.observer().put_slice(written);
89 self.buffer.advance_mut(cnt);
90 }
91
92 fn chunk_mut(&mut self) -> &mut UninitSlice {
93 self.buffer.chunk_mut()
94 }
95}
96
97impl<Challenger_: Default + Challenger> ProverTranscript<Challenger_> {
98 pub fn new() -> Self {
103 Self {
104 combined: Default::default(),
105 debug_assertions: cfg!(debug_assertions),
106 }
107 }
108
109 pub fn into_verifier(self) -> VerifierTranscript<Challenger_> {
110 VerifierTranscript::new(self.finalize())
111 }
112}
113
114impl<Challenger_: Default + Challenger> Default for ProverTranscript<Challenger_> {
115 fn default() -> Self {
116 Self::new()
117 }
118}
119
120impl<Challenger_: Challenger> ProverTranscript<Challenger_> {
121 pub fn finalize(self) -> Vec<u8> {
122 self.combined.buffer.to_vec()
123 }
124
125 pub const fn set_debug(&mut self, debug: bool) {
130 self.debug_assertions = debug;
131 }
132
133 pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
138 where
139 'a: 'b,
140 {
141 TranscriptWriter {
142 buffer: self.combined.challenger.observer(),
143 debug_assertions: self.debug_assertions,
144 }
145 }
146
147 pub fn decommitment(&mut self) -> TranscriptWriter<impl BufMut> {
155 TranscriptWriter {
156 buffer: &mut self.combined.buffer,
157 debug_assertions: self.debug_assertions,
158 }
159 }
160
161 pub fn message<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut>
165 where
166 'a: 'b,
167 {
168 TranscriptWriter {
169 buffer: &mut self.combined,
170 debug_assertions: self.debug_assertions,
171 }
172 }
173}
174
175impl<Challenger_: Default + Challenger> VerifierTranscript<Challenger_> {
176 pub fn new(vec: Vec<u8>) -> Self {
177 Self {
178 combined: FiatShamirBuf {
179 challenger: Challenger_::default(),
180 buffer: Bytes::from(vec),
181 },
182 debug_assertions: cfg!(debug_assertions),
183 }
184 }
185}
186
187impl<Challenger_: Challenger> VerifierTranscript<Challenger_> {
188 pub fn finalize(self) -> Result<(), Error> {
189 if self.combined.buffer.has_remaining() {
190 return Err(Error::TranscriptNotEmpty {
191 remaining: self.combined.buffer.remaining(),
192 });
193 }
194 Ok(())
195 }
196
197 pub const fn set_debug(&mut self, debug: bool) {
198 self.debug_assertions = debug;
199 }
200
201 pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
206 where
207 'a: 'b,
208 {
209 TranscriptWriter {
210 buffer: self.combined.challenger.observer(),
211 debug_assertions: self.debug_assertions,
212 }
213 }
214
215 pub fn decommitment(&mut self) -> TranscriptReader<impl Buf + '_> {
219 TranscriptReader {
220 buffer: &mut self.combined.buffer,
221 debug_assertions: self.debug_assertions,
222 }
223 }
224
225 pub fn message<'a, 'b>(&'a mut self) -> TranscriptReader<'b, impl Buf>
229 where
230 'a: 'b,
231 {
232 TranscriptReader {
233 buffer: &mut self.combined,
234 debug_assertions: self.debug_assertions,
235 }
236 }
237}
238
239impl<Challenger> Drop for VerifierTranscript<Challenger> {
241 fn drop(&mut self) {
242 if self.combined.buffer.has_remaining() {
243 warn!(
244 "Transcript reader is not fully read out: {:?} bytes left",
245 self.combined.buffer.remaining()
246 )
247 }
248 }
249}
250
251pub struct TranscriptReader<'a, B: Buf> {
252 buffer: &'a mut B,
253 debug_assertions: bool,
254}
255
256impl<B: Buf> TranscriptReader<'_, B> {
257 pub const fn buffer(&mut self) -> &mut B {
258 self.buffer
259 }
260
261 pub fn read<T: DeserializeBytes>(&mut self) -> Result<T, Error> {
262 let mode = SerializationMode::CanonicalTower;
263 T::deserialize(self.buffer(), mode).map_err(Into::into)
264 }
265
266 pub fn read_vec<T: DeserializeBytes>(&mut self, n: usize) -> Result<Vec<T>, Error> {
267 let mode = SerializationMode::CanonicalTower;
268 let mut buffer = self.buffer();
269 repeat_with(move || T::deserialize(&mut buffer, mode).map_err(Into::into))
270 .take(n)
271 .collect()
272 }
273
274 pub fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), Error> {
275 let buffer = self.buffer();
276 if buffer.remaining() < buf.len() {
277 return Err(Error::NotEnoughBytes);
278 }
279 buffer.copy_to_slice(buf);
280 Ok(())
281 }
282
283 pub fn read_scalar<F: TowerField>(&mut self) -> Result<F, Error> {
284 let mut out = F::default();
285 self.read_scalar_slice_into(slice::from_mut(&mut out))?;
286 Ok(out)
287 }
288
289 pub fn read_scalar_slice_into<F: TowerField>(&mut self, buf: &mut [F]) -> Result<(), Error> {
290 let mut buffer = self.buffer();
291 for elem in buf {
292 let mode = SerializationMode::CanonicalTower;
293 *elem = DeserializeBytes::deserialize(&mut buffer, mode)?;
294 }
295 Ok(())
296 }
297
298 pub fn read_scalar_slice<F: TowerField>(&mut self, len: usize) -> Result<Vec<F>, Error> {
299 let mut elems = vec![F::default(); len];
300 self.read_scalar_slice_into(&mut elems)?;
301 Ok(elems)
302 }
303
304 pub fn read_packed<P: PackedField<Scalar: TowerField>>(&mut self) -> Result<P, Error> {
305 P::try_from_fn(|_| self.read_scalar())
306 }
307
308 pub fn read_packed_slice<P: PackedField<Scalar: TowerField>>(
309 &mut self,
310 len: usize,
311 ) -> Result<Vec<P>, Error> {
312 let mut packed = Vec::with_capacity(len);
313 for _ in 0..len {
314 packed.push(self.read_packed()?);
315 }
316 Ok(packed)
317 }
318
319 pub fn read_debug(&mut self, msg: &str) {
320 if self.debug_assertions {
321 let msg_bytes = msg.as_bytes();
322 let mut buffer = vec![0; msg_bytes.len()];
323 assert!(self.read_bytes(&mut buffer).is_ok());
324 assert_eq!(msg_bytes, buffer);
325 }
326 }
327}
328
329pub struct TranscriptWriter<'a, B: BufMut> {
330 buffer: &'a mut B,
331 debug_assertions: bool,
332}
333
334impl<B: BufMut> TranscriptWriter<'_, B> {
335 pub const fn buffer(&mut self) -> &mut B {
336 self.buffer
337 }
338
339 pub fn write<T: SerializeBytes>(&mut self, value: &T) {
340 self.proof_size_event_wrapper(|buffer| {
341 value
342 .serialize(buffer, SerializationMode::CanonicalTower)
343 .expect("TODO: propagate error");
344 });
345 }
346
347 pub fn write_slice<T: SerializeBytes>(&mut self, values: &[T]) {
348 self.proof_size_event_wrapper(|buffer| {
349 for value in values {
350 value
351 .serialize(&mut *buffer, SerializationMode::CanonicalTower)
352 .expect("TODO: propagate error");
353 }
354 });
355 }
356
357 pub fn write_bytes(&mut self, data: &[u8]) {
358 self.proof_size_event_wrapper(|buffer| {
359 buffer.put_slice(data);
360 });
361 }
362
363 pub fn write_scalar<F: TowerField>(&mut self, f: F) {
364 self.write_scalar_slice(slice::from_ref(&f));
365 }
366
367 pub fn write_scalar_iter<F: TowerField>(&mut self, it: impl IntoIterator<Item = F>) {
368 self.proof_size_event_wrapper(move |buffer| {
369 for elem in it {
370 SerializeBytes::serialize(&elem, &mut *buffer, SerializationMode::CanonicalTower)
371 .expect("TODO: propagate error");
372 }
373 });
374 }
375
376 pub fn write_scalar_slice<F: TowerField>(&mut self, elems: &[F]) {
377 self.write_scalar_iter(elems.iter().copied());
378 }
379
380 pub fn write_packed<P: PackedField<Scalar: TowerField>>(&mut self, packed: P) {
381 self.write_scalar_iter(packed.into_iter());
382 }
383
384 pub fn write_packed_iter<P: PackedField<Scalar: TowerField>>(
385 &mut self,
386 it: impl IntoIterator<Item = P>,
387 ) {
388 self.write_scalar_iter(it.into_iter().flat_map(|packed| packed.into_iter()));
389 }
390
391 pub fn write_packed_slice<P: PackedField<Scalar: TowerField>>(&mut self, packed_slice: &[P]) {
392 self.write_scalar_iter(P::iter_slice(packed_slice));
393 }
394
395 pub fn write_debug(&mut self, msg: &str) {
396 if self.debug_assertions {
397 self.write_bytes(msg.as_bytes())
398 }
399 }
400
401 fn proof_size_event_wrapper<F: FnOnce(&mut B)>(&mut self, f: F) {
402 let buffer = self.buffer();
403 let start_bytes = buffer.remaining_mut();
404 f(buffer);
405 let end_bytes = buffer.remaining_mut();
406 tracing::event!(name: "incremental_proof_size", tracing::Level::INFO, counter=true, incremental=true, value=start_bytes - end_bytes);
407 }
408}
409
410impl<F, Challenger_> CanSample<F> for VerifierTranscript<Challenger_>
411where
412 F: TowerField,
413 Challenger_: Challenger,
414{
415 fn sample(&mut self) -> F {
416 let mode = SerializationMode::CanonicalTower;
417 DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode)
418 .expect("challenger has infinite buffer")
419 }
420}
421
422impl<F, Challenger_> CanSample<F> for ProverTranscript<Challenger_>
423where
424 F: TowerField,
425 Challenger_: Challenger,
426{
427 fn sample(&mut self) -> F {
428 let mode = SerializationMode::CanonicalTower;
429 DeserializeBytes::deserialize(self.combined.challenger.sampler(), mode)
430 .expect("challenger has infinite buffer")
431 }
432}
433
434fn sample_bits_reader<Reader: Buf>(mut reader: Reader, bits: usize) -> u32 {
435 let bits = bits.min(u32::BITS as usize);
436
437 let bytes_to_sample: usize = std::mem::size_of::<u32>();
438
439 let mut bytes = [0u8; std::mem::size_of::<u32>()];
440
441 reader.copy_to_slice(&mut bytes[..bytes_to_sample]);
442
443 let unmasked = u32::from_le_bytes(bytes);
444 let mask = 1u32.checked_shl(bits as u32);
445 let mask = match mask {
446 Some(x) => x - 1,
447 None => u32::MAX,
448 };
449 mask & unmasked
450}
451
452impl<Challenger_> CanSampleBits<u32> for VerifierTranscript<Challenger_>
453where
454 Challenger_: Challenger,
455{
456 fn sample_bits(&mut self, bits: usize) -> u32 {
457 sample_bits_reader(self.combined.challenger.sampler(), bits)
458 }
459}
460
461impl<Challenger_> CanSampleBits<u32> for ProverTranscript<Challenger_>
462where
463 Challenger_: Challenger,
464{
465 fn sample_bits(&mut self, bits: usize) -> u32 {
466 sample_bits_reader(self.combined.challenger.sampler(), bits)
467 }
468}
469
470pub fn read_u64<B: Buf>(transcript: &mut TranscriptReader<B>) -> Result<u64, Error> {
472 let mut as_bytes = [0; size_of::<u64>()];
473 transcript.read_bytes(&mut as_bytes)?;
474 Ok(u64::from_le_bytes(as_bytes))
475}
476
477pub fn write_u64<B: BufMut>(transcript: &mut TranscriptWriter<B>, n: u64) {
478 transcript.write_bytes(&n.to_le_bytes());
479}
480
#[cfg(test)]
mod tests {
    use binius_field::{
        AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField8b, BinaryField128b,
        BinaryField128bPolyval, BinaryField32b, BinaryField64b, BinaryField8b,
    };
    use binius_hash::groestl::Groestl256;
    use rand::{thread_rng, RngCore};

    use super::*;
    use crate::fiat_shamir::HasherChallenger;

    /// Values written through `message()` are read back unchanged, and the prover and the
    /// verifier sample the same challenges at matching points of the interaction.
    #[test]
    fn test_transcripting() {
        let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut writable = prover_transcript.message();

        writable.write_scalar(BinaryField8b::new(0x96));
        writable.write_scalar(BinaryField32b::new(0xDEADBEEF));
        writable.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));
        let sampled_fanpaar1: BinaryField128b = prover_transcript.sample();

        let mut writable = prover_transcript.message();

        writable.write_scalar(AESTowerField8b::new(0x52));
        writable.write_scalar(AESTowerField32b::new(0x12345678));
        writable.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));

        let sampled_aes1: AESTowerField16b = prover_transcript.sample();

        prover_transcript
            .message()
            .write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));
        let sampled_polyval1: BinaryField128bPolyval = prover_transcript.sample();

        let mut verifier_transcript = prover_transcript.into_verifier();
        let mut readable = verifier_transcript.message();

        let fp_8: BinaryField8b = readable.read_scalar().unwrap();
        let fp_32: BinaryField32b = readable.read_scalar().unwrap();
        let fp_128: BinaryField128b = readable.read_scalar().unwrap();

        assert_eq!(fp_8.val(), 0x96);
        assert_eq!(fp_32.val(), 0xDEADBEEF);
        assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);

        let sampled_fanpaar1_res: BinaryField128b = verifier_transcript.sample();

        assert_eq!(sampled_fanpaar1_res, sampled_fanpaar1);

        let mut readable = verifier_transcript.message();

        let aes_8: AESTowerField8b = readable.read_scalar().unwrap();
        let aes_32: AESTowerField32b = readable.read_scalar().unwrap();
        let aes_128: AESTowerField128b = readable.read_scalar().unwrap();

        assert_eq!(aes_8.val(), 0x52);
        assert_eq!(aes_32.val(), 0x12345678);
        assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);

        let sampled_aes_res: AESTowerField16b = verifier_transcript.sample();

        assert_eq!(sampled_aes_res, sampled_aes1);

        let polyval_128: BinaryField128bPolyval =
            verifier_transcript.message().read_scalar().unwrap();
        assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

        let sampled_polyval_res: BinaryField128bPolyval = verifier_transcript.sample();
        assert_eq!(sampled_polyval_res, sampled_polyval1);

        verifier_transcript.finalize().unwrap();
    }

    /// Values written through `decommitment()` are read back unchanged through the verifier's
    /// `decommitment()` reader.
    #[test]
    fn test_advicing() {
        let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut advice_writer = prover_transcript.decommitment();

        advice_writer.write_scalar(BinaryField8b::new(0x96));
        advice_writer.write_scalar(BinaryField32b::new(0xDEADBEEF));
        advice_writer.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));

        advice_writer.write_scalar(AESTowerField8b::new(0x52));
        advice_writer.write_scalar(AESTowerField32b::new(0x12345678));
        advice_writer.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));

        advice_writer.write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

        let mut verifier_transcript = prover_transcript.into_verifier();
        let mut advice_reader = verifier_transcript.decommitment();

        let fp_8: BinaryField8b = advice_reader.read_scalar().unwrap();
        let fp_32: BinaryField32b = advice_reader.read_scalar().unwrap();
        let fp_128: BinaryField128b = advice_reader.read_scalar().unwrap();

        assert_eq!(fp_8.val(), 0x96);
        assert_eq!(fp_32.val(), 0xDEADBEEF);
        assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);

        let aes_8: AESTowerField8b = advice_reader.read_scalar().unwrap();
        let aes_32: AESTowerField32b = advice_reader.read_scalar().unwrap();
        let aes_128: AESTowerField128b = advice_reader.read_scalar().unwrap();

        assert_eq!(aes_8.val(), 0x52);
        assert_eq!(aes_32.val(), 0x12345678);
        assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);

        let polyval_128: BinaryField128bPolyval = advice_reader.read_scalar().unwrap();
        assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));

        verifier_transcript.finalize().unwrap();
    }

    /// A transcript fed through `message()`, a transcript fed through `observe()` and a bare
    /// challenger fed the same bytes all sample identical challenges; only the `message()`
    /// transcript records the bytes.
    #[test]
    fn test_challenger_and_observing() {
        let mut taped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut untaped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut challenger = HasherChallenger::<Groestl256>::default();

        const NUM_SAMPLING: usize = 32;
        let mut random_bytes = [0u8; NUM_SAMPLING * 8];
        thread_rng().fill_bytes(&mut random_bytes);
        let mut sampled_arrays = [[0u8; 8]; NUM_SAMPLING];

        for (chunk, sampled) in random_bytes.chunks_exact(8).zip(&mut sampled_arrays) {
            // Convert the 8-byte chunk straight into an array; no intermediate `Vec` needed.
            let word = u64::from_le_bytes(chunk.try_into().unwrap());

            taped_transcript
                .message()
                .write_scalar(BinaryField64b::new(word));
            untaped_transcript
                .observe()
                .write_scalar(BinaryField64b::new(word));
            challenger.observer().put_slice(chunk);

            let sampled_out_transcript1: BinaryField64b = taped_transcript.sample();
            let sampled_out_transcript2: BinaryField64b = untaped_transcript.sample();
            let mut challenger_out = [0u8; 8];
            challenger.sampler().copy_to_slice(&mut challenger_out);
            assert_eq!(challenger_out, sampled_out_transcript1.val().to_le_bytes());
            assert_eq!(challenger_out, sampled_out_transcript2.val().to_le_bytes());
            *sampled = challenger_out;
        }

        let mut taped_transcript = taped_transcript.into_verifier();

        assert!(untaped_transcript.finalize().is_empty());

        for array in sampled_arrays {
            let _: BinaryField64b = taped_transcript.message().read_scalar().unwrap();
            let sampled_out_transcript: BinaryField64b = taped_transcript.sample();

            assert_eq!(array, sampled_out_transcript.val().to_le_bytes());
        }

        taped_transcript.finalize().unwrap();
    }

    /// A debug marker written by the prover is accepted by the verifier.
    ///
    /// The debug flag is switched on explicitly so that the markers are really written and
    /// checked even when the tests are built without `debug_assertions` (e.g. `--release`).
    #[test]
    fn test_transcript_debug() {
        let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        transcript.set_debug(true);

        transcript.message().write_debug("test_transcript_debug");

        let mut verifier = transcript.into_verifier();
        verifier.set_debug(true);
        verifier.message().read_debug("test_transcript_debug");
    }

    /// A mismatching debug marker makes `read_debug` panic.
    ///
    /// The debug flag is switched on explicitly: with the flag off both `write_debug` and
    /// `read_debug` are no-ops, so without this the expected panic would never happen in a
    /// build without `debug_assertions`.
    #[test]
    #[should_panic]
    fn test_transcript_debug_fail() {
        let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        transcript.set_debug(true);

        transcript.message().write_debug("test_transcript_debug");

        let mut verifier = transcript.into_verifier();
        verifier.set_debug(true);
        verifier
            .message()
            .read_debug("test_transcript_debug_should_fail");
    }
}