mod error;
use std::{iter::repeat_with, slice};
use binius_field::{deserialize_canonical, serialize_canonical, PackedField, TowerField};
use binius_utils::serialization::{DeserializeBytes, SerializeBytes};
use bytes::{buf::UninitSlice, Buf, BufMut, Bytes, BytesMut};
pub use error::Error;
use tracing::warn;
use crate::fiat_shamir::{CanSample, CanSampleBits, Challenger};
/// Prover side of a Fiat-Shamir transcript.
///
/// Proof bytes are accumulated in a growable buffer that is paired with a challenger. The
/// `observe`, `decommitment` and `message` methods select whether written bytes are recorded,
/// observed by the challenger, or both.
#[derive(Debug)]
pub struct ProverTranscript<Challenger> {
    /// Proof bytes together with the challenger that observes (some of) them.
    combined: FiatShamirBuf<BytesMut, Challenger>,
    /// Whether `write_debug` markers are emitted by writers created from this transcript.
    debug_assertions: bool,
}
/// Verifier side of a Fiat-Shamir transcript.
///
/// Reads back the proof bytes produced by a prover transcript. Bytes read through `message`
/// are also fed to this transcript's own challenger.
#[derive(Debug)]
pub struct VerifierTranscript<Challenger> {
    /// Remaining proof bytes together with the challenger that observes (some of) them.
    combined: FiatShamirBuf<Bytes, Challenger>,
    /// Whether `read_debug` markers are consumed and checked by readers created from this
    /// transcript.
    debug_assertions: bool,
}
/// A byte buffer coupled with a Fiat-Shamir challenger.
///
/// Bytes consumed through this wrapper (via `Buf::advance`) or committed through it (via
/// `BufMut::advance_mut`) are also fed to the challenger's observer.
#[derive(Debug, Default)]
struct FiatShamirBuf<Inner, Challenger> {
    /// The underlying byte storage (proof bytes).
    buffer: Inner,
    /// Challenger that observes bytes passing through this wrapper.
    challenger: Challenger,
}
/// Reading side of [`FiatShamirBuf`]: every consumed byte is observed by the challenger.
impl<Inner: Buf, Challenger_: Challenger> Buf for FiatShamirBuf<Inner, Challenger_> {
    fn remaining(&self) -> usize {
        self.buffer.remaining()
    }

    fn chunk(&self) -> &[u8] {
        self.buffer.chunk()
    }

    /// Consumes `cnt` bytes from the inner buffer, feeding each consumed byte to the
    /// challenger's observer before it is skipped.
    ///
    /// The inner buffer may be non-contiguous, so the consumed bytes are observed one chunk at
    /// a time instead of assuming they all lie in the current chunk.
    ///
    /// # Panics
    ///
    /// Panics if `cnt` exceeds [`Buf::remaining`].
    fn advance(&mut self, mut cnt: usize) {
        assert!(cnt <= self.buffer.remaining());
        // Do-while shape: even `cnt == 0` performs exactly one (empty) observation, matching the
        // behaviour for contiguous buffers.
        loop {
            let chunk = self.buffer.chunk();
            let step = cnt.min(chunk.len());
            // `Buf::chunk` may only be empty when nothing remains, which the assertion above
            // excludes while `cnt > 0`; refuse to spin forever on a misbehaving `Inner`.
            assert!(
                step > 0 || cnt == 0,
                "inner buffer returned an empty chunk while bytes remain"
            );
            self.challenger.observer().put_slice(&chunk[..step]);
            self.buffer.advance(step);
            cnt -= step;
            if cnt == 0 {
                break;
            }
        }
    }
}
// Writing side of `FiatShamirBuf`: every committed byte is observed by the challenger.
//
// SAFETY: every method forwards to `Inner`'s own `BufMut` implementation, so `chunk_mut`
// returns exactly the memory `Inner` hands out, and `advance_mut` only advances `Inner` by the
// `cnt` the caller passed, under the same contract.
unsafe impl<Inner: BufMut, Challenger_: Challenger> BufMut for FiatShamirBuf<Inner, Challenger_> {
    /// Remaining writable capacity of the inner buffer.
    fn remaining_mut(&self) -> usize {
        self.buffer.remaining_mut()
    }

    /// Commits `cnt` freshly written bytes. They are first fed to the challenger's observer,
    /// then the inner buffer's cursor is advanced past them.
    ///
    /// # Safety
    ///
    /// Same contract as [`BufMut::advance_mut`]: the caller must have initialized the first
    /// `cnt` bytes of the slice returned by `chunk_mut`.
    unsafe fn advance_mut(&mut self, cnt: usize) {
        assert!(cnt <= self.buffer.remaining_mut());
        // The just-written bytes sit at the start of the inner buffer's current `chunk_mut`, so
        // they must be read here, *before* `advance_mut` moves the cursor past them.
        let written = self.buffer.chunk_mut();
        assert!(cnt <= written.len());
        // SAFETY: `cnt <= written.len()` is checked above, and the caller guarantees (per the
        // `advance_mut` contract) that these `cnt` bytes have been initialized.
        let written: &[u8] = slice::from_raw_parts(written.as_mut_ptr(), cnt);
        self.challenger.observer().put_slice(written);
        self.buffer.advance_mut(cnt);
    }

    /// Writable region of the inner buffer; bytes placed here are observed on `advance_mut`.
    fn chunk_mut(&mut self) -> &mut UninitSlice {
        self.buffer.chunk_mut()
    }
}
impl<Challenger_: Default + Challenger> ProverTranscript<Challenger_> {
    /// Creates an empty prover transcript with a default-constructed challenger.
    ///
    /// Debug markers (`write_debug`) are enabled iff the crate is built with
    /// `debug_assertions`; use `set_debug` to override.
    pub fn new() -> Self {
        Self {
            combined: Default::default(),
            debug_assertions: cfg!(debug_assertions),
        }
    }

    /// Consumes the prover transcript and returns a verifier transcript that reads back the
    /// recorded proof bytes, starting from a fresh default-constructed challenger.
    ///
    /// The prover's debug-marker setting is carried over. Markers written (or skipped) by
    /// `write_debug` are then read (or skipped) symmetrically by `read_debug`, even when
    /// `set_debug` overrode the build default on the prover.
    pub fn into_verifier(self) -> VerifierTranscript<Challenger_> {
        let debug_assertions = self.debug_assertions;
        let mut verifier = VerifierTranscript::new(self.finalize());
        verifier.set_debug(debug_assertions);
        verifier
    }
}
impl<Challenger_: Default + Challenger> Default for ProverTranscript<Challenger_> {
fn default() -> Self {
Self::new()
}
}
impl<Challenger_: Challenger> ProverTranscript<Challenger_> {
    /// Consumes the transcript and returns a copy of every proof byte written through
    /// `message` or `decommitment`. Bytes written through `observe` are not included.
    pub fn finalize(self) -> Vec<u8> {
        let recorded = &self.combined.buffer[..];
        recorded.to_vec()
    }

    /// Enables or disables emission of `write_debug` markers for subsequently created writers.
    pub fn set_debug(&mut self, debug: bool) {
        self.debug_assertions = debug;
    }

    /// Returns a writer whose bytes go only to the challenger's observer. They are not
    /// recorded in the proof.
    pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
    where
        'a: 'b,
    {
        let debug_assertions = self.debug_assertions;
        let buffer = self.combined.challenger.observer();
        TranscriptWriter {
            buffer,
            debug_assertions,
        }
    }

    /// Returns a writer that appends directly to the proof bytes, bypassing the challenger.
    pub fn decommitment(&mut self) -> TranscriptWriter<impl BufMut> {
        let debug_assertions = self.debug_assertions;
        let buffer = &mut self.combined.buffer;
        TranscriptWriter {
            buffer,
            debug_assertions,
        }
    }

    /// Returns a writer that appends to the proof bytes and also feeds them to the challenger.
    pub fn message<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut>
    where
        'a: 'b,
    {
        let debug_assertions = self.debug_assertions;
        let buffer = &mut self.combined;
        TranscriptWriter {
            buffer,
            debug_assertions,
        }
    }
}
impl<Challenger_: Default + Challenger> VerifierTranscript<Challenger_> {
pub fn new(vec: Vec<u8>) -> Self {
Self {
combined: FiatShamirBuf {
challenger: Challenger_::default(),
buffer: Bytes::from(vec),
},
debug_assertions: cfg!(debug_assertions),
}
}
}
impl<Challenger_: Challenger> VerifierTranscript<Challenger_> {
pub fn finalize(self) -> Result<(), Error> {
if self.combined.buffer.has_remaining() {
return Err(Error::TranscriptNotEmpty {
remaining: self.combined.buffer.remaining(),
});
}
Ok(())
}
pub fn set_debug(&mut self, debug: bool) {
self.debug_assertions = debug;
}
pub fn observe<'a, 'b>(&'a mut self) -> TranscriptWriter<'b, impl BufMut + 'b>
where
'a: 'b,
{
TranscriptWriter {
buffer: self.combined.challenger.observer(),
debug_assertions: self.debug_assertions,
}
}
pub fn decommitment(&mut self) -> TranscriptReader<impl Buf + '_> {
TranscriptReader {
buffer: &mut self.combined.buffer,
debug_assertions: self.debug_assertions,
}
}
pub fn message<'a, 'b>(&'a mut self) -> TranscriptReader<'b, impl Buf>
where
'a: 'b,
{
TranscriptReader {
buffer: &mut self.combined,
debug_assertions: self.debug_assertions,
}
}
}
impl<Challenger> Drop for VerifierTranscript<Challenger> {
    /// Logs a warning when the transcript is dropped with unread proof bytes. Leftover bytes
    /// may indicate that the prover and the verifier disagree on the transcript layout.
    fn drop(&mut self) {
        let remaining = self.combined.buffer.remaining();
        if remaining == 0 {
            return;
        }
        warn!(
            "Transcript reader is not fully read out: {:?} bytes left",
            remaining
        );
    }
}
/// Typed reader over a region of a verifier transcript.
pub struct TranscriptReader<'a, B: Buf> {
    /// Buffer that reads are served from. Depending on how the reader was created, it may
    /// also feed a challenger.
    buffer: &'a mut B,
    /// When true, `read_debug` consumes and checks debug markers.
    debug_assertions: bool,
}
impl<B: Buf> TranscriptReader<'_, B> {
    /// Returns the underlying buffer that reads are served from.
    pub fn buffer(&mut self) -> &mut B {
        self.buffer
    }

    /// Deserializes a single value of type `T`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `T::deserialize`, converted into [`Error`].
    pub fn read<T: DeserializeBytes>(&mut self) -> Result<T, Error> {
        T::deserialize(self.buffer()).map_err(Into::into)
    }

    /// Deserializes `n` consecutive values of type `T`, stopping at the first error.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `T::deserialize`, converted into [`Error`].
    pub fn read_vec<T: DeserializeBytes>(&mut self, n: usize) -> Result<Vec<T>, Error> {
        let mut buffer = self.buffer();
        repeat_with(move || T::deserialize(&mut buffer).map_err(Into::into))
            .take(n)
            .collect()
    }

    /// Fills `buf` with the next `buf.len()` bytes.
    ///
    /// # Errors
    ///
    /// Returns `Error::NotEnoughBytes`, without consuming anything, if fewer than `buf.len()`
    /// bytes remain.
    pub fn read_bytes(&mut self, buf: &mut [u8]) -> Result<(), Error> {
        let buffer = self.buffer();
        if buffer.remaining() < buf.len() {
            return Err(Error::NotEnoughBytes);
        }
        buffer.copy_to_slice(buf);
        Ok(())
    }

    /// Reads one field element in canonical serialization.
    pub fn read_scalar<F: TowerField>(&mut self) -> Result<F, Error> {
        let mut out = F::default();
        self.read_scalar_slice_into(slice::from_mut(&mut out))?;
        Ok(out)
    }

    /// Overwrites each element of `buf`, in order, with a canonically deserialized field
    /// element.
    ///
    /// On error, earlier elements may already have been overwritten and their bytes consumed.
    pub fn read_scalar_slice_into<F: TowerField>(&mut self, buf: &mut [F]) -> Result<(), Error> {
        let mut buffer = self.buffer();
        for elem in buf {
            *elem = deserialize_canonical(&mut buffer)?;
        }
        Ok(())
    }

    /// Reads `len` canonically serialized field elements into a new vector.
    pub fn read_scalar_slice<F: TowerField>(&mut self, len: usize) -> Result<Vec<F>, Error> {
        let mut elems = vec![F::default(); len];
        self.read_scalar_slice_into(&mut elems)?;
        Ok(elems)
    }

    /// Reads one packed element by reading a scalar for each lane via `P::try_from_fn`.
    pub fn read_packed<P: PackedField<Scalar: TowerField>>(&mut self) -> Result<P, Error> {
        P::try_from_fn(|_| self.read_scalar())
    }

    /// Reads `len` packed elements.
    pub fn read_packed_slice<P: PackedField<Scalar: TowerField>>(
        &mut self,
        len: usize,
    ) -> Result<Vec<P>, Error> {
        let mut packed = Vec::with_capacity(len);
        for _ in 0..len {
            packed.push(self.read_packed()?);
        }
        Ok(packed)
    }

    /// When debug markers are enabled, reads `msg.len()` bytes and checks that they equal
    /// `msg`. Otherwise this is a no-op.
    ///
    /// # Panics
    ///
    /// Panics, naming the expected marker, if not enough bytes remain or the bytes read
    /// differ from `msg`.
    pub fn read_debug(&mut self, msg: &str) {
        if !self.debug_assertions {
            return;
        }
        let expected = msg.as_bytes();
        let mut actual = vec![0; expected.len()];
        // Surface the underlying error instead of a bare `assert!(is_ok())`.
        if let Err(err) = self.read_bytes(&mut actual) {
            panic!(
                "transcript debug marker {msg:?} could not be read ({} bytes expected): {err:?}",
                expected.len()
            );
        }
        assert_eq!(
            expected,
            actual.as_slice(),
            "transcript debug marker mismatch: expected {msg:?}, found {:?}",
            String::from_utf8_lossy(&actual)
        );
    }
}
/// Typed writer over a region of a prover (or observer) transcript.
pub struct TranscriptWriter<'a, B: BufMut> {
    /// Buffer that writes go to. Depending on how the writer was created, it may also feed
    /// a challenger.
    buffer: &'a mut B,
    /// When true, `write_debug` emits debug markers.
    debug_assertions: bool,
}
impl<B: BufMut> TranscriptWriter<'_, B> {
pub fn buffer(&mut self) -> &mut B {
self.buffer
}
pub fn write<T: SerializeBytes>(&mut self, value: &T) {
value
.serialize(self.buffer())
.expect("TODO: propagate error")
}
pub fn write_slice<T: SerializeBytes>(&mut self, values: &[T]) {
let mut buffer = self.buffer();
for value in values {
value.serialize(&mut buffer).expect("TODO: propagate error")
}
}
pub fn write_bytes(&mut self, data: &[u8]) {
self.buffer().put_slice(data);
}
pub fn write_scalar<F: TowerField>(&mut self, f: F) {
self.write_scalar_slice(slice::from_ref(&f));
}
pub fn write_scalar_slice<F: TowerField>(&mut self, elems: &[F]) {
let mut buffer = self.buffer();
for elem in elems {
serialize_canonical(*elem, &mut buffer).expect("TODO: propagate error");
}
}
pub fn write_packed<P: PackedField<Scalar: TowerField>>(&mut self, packed: P) {
for scalar in packed.iter() {
self.write_scalar(scalar);
}
}
pub fn write_packed_slice<P: PackedField<Scalar: TowerField>>(&mut self, packed_slice: &[P]) {
for &packed in packed_slice {
self.write_packed(packed)
}
}
pub fn write_debug(&mut self, msg: &str) {
if self.debug_assertions {
self.write_bytes(msg.as_bytes())
}
}
}
impl<F, Challenger_> CanSample<F> for VerifierTranscript<Challenger_>
where
    F: TowerField,
    Challenger_: Challenger,
{
    /// Draws a field element by canonically deserializing bytes taken from the challenger's
    /// sampler.
    fn sample(&mut self) -> F {
        let sampler = self.combined.challenger.sampler();
        let sampled: Result<F, _> = deserialize_canonical(sampler);
        sampled.expect("challenger has infinite buffer")
    }
}
impl<F, Challenger_> CanSample<F> for ProverTranscript<Challenger_>
where
    F: TowerField,
    Challenger_: Challenger,
{
    /// Draws a field element by canonically deserializing bytes taken from the challenger's
    /// sampler.
    fn sample(&mut self) -> F {
        let sampler = self.combined.challenger.sampler();
        let sampled: Result<F, _> = deserialize_canonical(sampler);
        sampled.expect("challenger has infinite buffer")
    }
}
/// Reads `bits` bits (capped at `usize::BITS`) from `reader`.
///
/// Consumes exactly `ceil(bits / 8)` bytes, interprets them little-endian, and keeps only the
/// low `bits` bits of the result.
///
/// # Panics
///
/// Panics (via `Buf::copy_to_slice`) if `reader` has fewer bytes remaining than are needed.
fn sample_bits_reader<Reader: Buf>(mut reader: Reader, bits: usize) -> usize {
    const WORD_BITS: usize = usize::BITS as usize;
    let bits = usize::min(bits, WORD_BITS);
    let mut le_bytes = [0u8; std::mem::size_of::<usize>()];
    reader.copy_to_slice(&mut le_bytes[..bits.div_ceil(8)]);
    // A full-width request keeps every bit; shifting by WORD_BITS would overflow.
    let mask = if bits == WORD_BITS {
        usize::MAX
    } else {
        (1usize << bits) - 1
    };
    usize::from_le_bytes(le_bytes) & mask
}
impl<Challenger_> CanSampleBits<usize> for VerifierTranscript<Challenger_>
where
    Challenger_: Challenger,
{
    /// Samples `bits` bits (capped at `usize::BITS`) from the challenger's sampler.
    fn sample_bits(&mut self, bits: usize) -> usize {
        let sampler = self.combined.challenger.sampler();
        sample_bits_reader(sampler, bits)
    }
}
impl<Challenger_> CanSampleBits<usize> for ProverTranscript<Challenger_>
where
    Challenger_: Challenger,
{
    /// Samples `bits` bits (capped at `usize::BITS`) from the challenger's sampler.
    fn sample_bits(&mut self, bits: usize) -> usize {
        let sampler = self.combined.challenger.sampler();
        sample_bits_reader(sampler, bits)
    }
}
/// Reads a `u64` stored as 8 little-endian bytes.
///
/// # Errors
///
/// Returns `Error::NotEnoughBytes` if fewer than 8 bytes remain.
pub fn read_u64<B: Buf>(transcript: &mut TranscriptReader<B>) -> Result<u64, Error> {
    let mut le_bytes = [0u8; std::mem::size_of::<u64>()];
    transcript
        .read_bytes(&mut le_bytes)
        .map(|()| u64::from_le_bytes(le_bytes))
}
/// Writes `n` as 8 little-endian bytes (the inverse of `read_u64`).
pub fn write_u64<B: BufMut>(transcript: &mut TranscriptWriter<B>, n: u64) {
    let le_bytes = n.to_le_bytes();
    transcript.write_bytes(&le_bytes);
}
#[cfg(test)]
mod tests {
    use binius_field::{
        AESTowerField128b, AESTowerField16b, AESTowerField32b, AESTowerField8b, BinaryField128b,
        BinaryField128bPolyval, BinaryField32b, BinaryField64b, BinaryField8b,
    };
    use groestl_crypto::Groestl256;
    use rand::{thread_rng, RngCore};
    use super::*;
    use crate::fiat_shamir::HasherChallenger;

    // Round-trips scalars of several tower fields through `message()` writes interleaved with
    // challenge sampling. Checks that the verifier reads back the same values and, by replaying
    // the same operations in the same order, derives the same challenges.
    #[test]
    fn test_transcripting() {
        let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut writable = prover_transcript.message();
        writable.write_scalar(BinaryField8b::new(0x96));
        writable.write_scalar(BinaryField32b::new(0xDEADBEEF));
        writable.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));
        let sampled_fanpaar1: BinaryField128b = prover_transcript.sample();
        let mut writable = prover_transcript.message();
        writable.write_scalar(AESTowerField8b::new(0x52));
        writable.write_scalar(AESTowerField32b::new(0x12345678));
        writable.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));
        let sampled_aes1: AESTowerField16b = prover_transcript.sample();
        prover_transcript
            .message()
            .write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));
        let sampled_polyval1: BinaryField128bPolyval = prover_transcript.sample();
        // Verifier side: reads and samples must happen in exactly the prover's order.
        let mut verifier_transcript = prover_transcript.into_verifier();
        let mut readable = verifier_transcript.message();
        let fp_8: BinaryField8b = readable.read_scalar().unwrap();
        let fp_32: BinaryField32b = readable.read_scalar().unwrap();
        let fp_128: BinaryField128b = readable.read_scalar().unwrap();
        assert_eq!(fp_8.val(), 0x96);
        assert_eq!(fp_32.val(), 0xDEADBEEF);
        assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);
        let sampled_fanpaar1_res: BinaryField128b = verifier_transcript.sample();
        assert_eq!(sampled_fanpaar1_res, sampled_fanpaar1);
        let mut readable = verifier_transcript.message();
        let aes_8: AESTowerField8b = readable.read_scalar().unwrap();
        let aes_32: AESTowerField32b = readable.read_scalar().unwrap();
        let aes_128: AESTowerField128b = readable.read_scalar().unwrap();
        assert_eq!(aes_8.val(), 0x52);
        assert_eq!(aes_32.val(), 0x12345678);
        assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);
        let sampled_aes_res: AESTowerField16b = verifier_transcript.sample();
        assert_eq!(sampled_aes_res, sampled_aes1);
        let polyval_128: BinaryField128bPolyval =
            verifier_transcript.message().read_scalar().unwrap();
        assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));
        let sampled_polyval_res: BinaryField128bPolyval = verifier_transcript.sample();
        assert_eq!(sampled_polyval_res, sampled_polyval1);
        // Every proof byte must have been consumed.
        verifier_transcript.finalize().unwrap();
    }

    // Round-trips scalars through the `decommitment()` channel, which records bytes without the
    // challenger observing them.
    #[test]
    fn test_advicing() {
        let mut prover_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut advice_writer = prover_transcript.decommitment();
        advice_writer.write_scalar(BinaryField8b::new(0x96));
        advice_writer.write_scalar(BinaryField32b::new(0xDEADBEEF));
        advice_writer.write_scalar(BinaryField128b::new(0x55669900112233550000CCDDFFEEAABB));
        advice_writer.write_scalar(AESTowerField8b::new(0x52));
        advice_writer.write_scalar(AESTowerField32b::new(0x12345678));
        advice_writer.write_scalar(AESTowerField128b::new(0xDDDDBBBBCCCCAAAA2222999911117777));
        advice_writer.write_scalar(BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));
        let mut verifier_transcript = prover_transcript.into_verifier();
        let mut advice_reader = verifier_transcript.decommitment();
        let fp_8: BinaryField8b = advice_reader.read_scalar().unwrap();
        let fp_32: BinaryField32b = advice_reader.read_scalar().unwrap();
        let fp_128: BinaryField128b = advice_reader.read_scalar().unwrap();
        assert_eq!(fp_8.val(), 0x96);
        assert_eq!(fp_32.val(), 0xDEADBEEF);
        assert_eq!(fp_128.val(), 0x55669900112233550000CCDDFFEEAABB);
        let aes_8: AESTowerField8b = advice_reader.read_scalar().unwrap();
        let aes_32: AESTowerField32b = advice_reader.read_scalar().unwrap();
        let aes_128: AESTowerField128b = advice_reader.read_scalar().unwrap();
        assert_eq!(aes_8.val(), 0x52);
        assert_eq!(aes_32.val(), 0x12345678);
        assert_eq!(aes_128.val(), 0xDDDDBBBBCCCCAAAA2222999911117777);
        let polyval_128: BinaryField128bPolyval = advice_reader.read_scalar().unwrap();
        assert_eq!(polyval_128, BinaryField128bPolyval::new(0xFFFF12345678DDDDEEEE87654321AAAA));
        verifier_transcript.finalize().unwrap();
    }

    // Checks that writing a scalar through `message()` (recorded in the proof) or `observe()`
    // (not recorded) drives the challenger exactly like observing the same 8 raw bytes on a
    // standalone `HasherChallenger`. Also checks that the recorded transcript replays the same
    // challenges on the verifier side.
    #[test]
    fn test_challenger_and_observing() {
        let mut taped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut untaped_transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        let mut challenger = HasherChallenger::<Groestl256>::default();
        const NUM_SAMPLING: usize = 32;
        let mut random_bytes = [0u8; NUM_SAMPLING * 8];
        thread_rng().fill_bytes(&mut random_bytes);
        let mut sampled_arrays = [[0u8; 8]; NUM_SAMPLING];
        for i in 0..NUM_SAMPLING {
            taped_transcript
                .message()
                .write_scalar(BinaryField64b::new(u64::from_le_bytes(
                    random_bytes[i * 8..i * 8 + 8].to_vec().try_into().unwrap(),
                )));
            untaped_transcript
                .observe()
                .write_scalar(BinaryField64b::new(u64::from_le_bytes(
                    random_bytes[i * 8..i * 8 + 8].to_vec().try_into().unwrap(),
                )));
            challenger
                .observer()
                .put_slice(&random_bytes[i * 8..i * 8 + 8]);
            let sampled_out_transcript1: BinaryField64b = taped_transcript.sample();
            let sampled_out_transcript2: BinaryField64b = untaped_transcript.sample();
            let mut challenger_out = [0u8; 8];
            challenger.sampler().copy_to_slice(&mut challenger_out);
            assert_eq!(challenger_out, sampled_out_transcript1.val().to_le_bytes());
            assert_eq!(challenger_out, sampled_out_transcript2.val().to_le_bytes());
            sampled_arrays[i] = challenger_out;
        }
        let mut taped_transcript = taped_transcript.into_verifier();
        // `observe()` writes never reach the proof bytes.
        assert!(untaped_transcript.finalize().is_empty());
        for array in sampled_arrays.into_iter() {
            let _: BinaryField64b = taped_transcript.message().read_scalar().unwrap();
            let sampled_out_transcript: BinaryField64b = taped_transcript.sample();
            assert_eq!(array, sampled_out_transcript.val().to_le_bytes());
        }
        taped_transcript.finalize().unwrap();
    }

    // A debug marker written by the prover is accepted when the verifier expects the same text.
    #[test]
    fn test_transcript_debug() {
        let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        transcript.message().write_debug("test_transcript_debug");
        transcript
            .into_verifier()
            .message()
            .read_debug("test_transcript_debug");
    }

    // A mismatching debug marker must panic. NOTE(review): this only holds when debug markers
    // are enabled, i.e. in builds with `debug_assertions`.
    #[test]
    #[should_panic]
    fn test_transcript_debug_fail() {
        let mut transcript = ProverTranscript::<HasherChallenger<Groestl256>>::new();
        transcript.message().write_debug("test_transcript_debug");
        transcript
            .into_verifier()
            .message()
            .read_debug("test_transcript_debug_should_fail");
    }
}