binius_core/protocols/evalcheck/
evalcheck.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{
4	ops::{Deref, Range},
5	slice,
6	sync::Arc,
7};
8
9use binius_field::{Field, TowerField};
10use bytes::{Buf, BufMut};
11
12use super::error::Error;
13use crate::{
14	oracle::OracleId,
15	transcript::{TranscriptReader, TranscriptWriter},
16};
17
18#[derive(Debug, Clone)]
19pub struct EvalcheckMultilinearClaim<F: Field> {
20	/// Virtual Polynomial Oracle for which the evaluation is claimed
21	pub id: OracleId,
22	/// Evaluation Point
23	pub eval_point: EvalPoint<F>,
24	/// Claimed Evaluation
25	pub eval: F,
26}
27
/// One-byte tags that identify each `EvalcheckProof` variant in the transcript.
///
/// These values are written by `serialize_evalcheck_proof` and decoded by
/// `EvalcheckNumerics::from`, so they form part of the serialized format. Each
/// discriminant is therefore written out explicitly.
#[repr(u8)]
#[derive(Debug)]
enum EvalcheckNumerics {
	Transparent = 1,
	Committed = 2,
	Shifted = 3,
	Packed = 4,
	Repeating = 5,
	LinearCombination = 6,
	ZeroPadded = 7,
	CompositeMLE = 8,
}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub enum EvalcheckProof<F: Field> {
43	Transparent,
44	Committed,
45	Shifted,
46	Packed,
47	Repeating(Box<EvalcheckProof<F>>),
48	LinearCombination {
49		subproofs: Vec<(F, EvalcheckProof<F>)>,
50	},
51	ZeroPadded(F, Box<EvalcheckProof<F>>),
52	CompositeMLE,
53}
54
55impl<F: Field> EvalcheckProof<F> {
56	pub fn isomorphic<FI: Field + From<F>>(self) -> EvalcheckProof<FI> {
57		match self {
58			Self::Transparent => EvalcheckProof::Transparent,
59			Self::Committed => EvalcheckProof::Committed,
60			Self::Shifted => EvalcheckProof::Shifted,
61			Self::Packed => EvalcheckProof::Packed,
62			Self::Repeating(proof) => EvalcheckProof::Repeating(Box::new(proof.isomorphic())),
63			Self::LinearCombination { subproofs } => EvalcheckProof::LinearCombination {
64				subproofs: subproofs
65					.into_iter()
66					.map(|(eval, proof)| (eval.into(), proof.isomorphic()))
67					.collect(),
68			},
69			Self::ZeroPadded(eval, proof) => {
70				EvalcheckProof::ZeroPadded(eval.into(), Box::new(proof.isomorphic()))
71			}
72			Self::CompositeMLE => EvalcheckProof::CompositeMLE,
73		}
74	}
75}
76
77impl EvalcheckNumerics {
78	const fn from(x: u8) -> Result<Self, Error> {
79		match x {
80			1 => Ok(Self::Transparent),
81			2 => Ok(Self::Committed),
82			3 => Ok(Self::Shifted),
83			4 => Ok(Self::Packed),
84			5 => Ok(Self::Repeating),
85			6 => Ok(Self::LinearCombination),
86			7 => Ok(Self::ZeroPadded),
87			8 => Ok(Self::CompositeMLE),
88			_ => Err(Error::EvalcheckSerializationError),
89		}
90	}
91}
92
93/// Serializes the `EvalcheckProof` into the transcript
94pub fn serialize_evalcheck_proof<B: BufMut, F: TowerField>(
95	transcript: &mut TranscriptWriter<B>,
96	evalcheck: &EvalcheckProof<F>,
97) {
98	match evalcheck {
99		EvalcheckProof::Transparent => {
100			transcript.write_bytes(&[EvalcheckNumerics::Transparent as u8]);
101		}
102		EvalcheckProof::Committed => {
103			transcript.write_bytes(&[EvalcheckNumerics::Committed as u8]);
104		}
105		EvalcheckProof::Shifted => {
106			transcript.write_bytes(&[EvalcheckNumerics::Shifted as u8]);
107		}
108		EvalcheckProof::Packed => {
109			transcript.write_bytes(&[EvalcheckNumerics::Packed as u8]);
110		}
111		EvalcheckProof::Repeating(inner) => {
112			transcript.write_bytes(&[EvalcheckNumerics::Repeating as u8]);
113			serialize_evalcheck_proof(transcript, inner);
114		}
115		EvalcheckProof::LinearCombination { subproofs } => {
116			transcript.write_bytes(&[EvalcheckNumerics::LinearCombination as u8]);
117			let len_u64 = subproofs.len() as u64;
118			transcript.write_bytes(&len_u64.to_le_bytes());
119			for (scalar, subproof) in subproofs {
120				transcript.write_scalar(*scalar);
121				serialize_evalcheck_proof(transcript, subproof)
122			}
123		}
124		EvalcheckProof::ZeroPadded(val, subproof) => {
125			transcript.write_bytes(&[EvalcheckNumerics::ZeroPadded as u8]);
126			transcript.write_scalar(*val);
127			serialize_evalcheck_proof(transcript, subproof);
128		}
129		EvalcheckProof::CompositeMLE => {
130			transcript.write_bytes(&[EvalcheckNumerics::CompositeMLE as u8]);
131		}
132	}
133}
134
135/// Deserializes the `EvalcheckProof` object from the given transcript.
136pub fn deserialize_evalcheck_proof<B: Buf, F: TowerField>(
137	transcript: &mut TranscriptReader<B>,
138) -> Result<EvalcheckProof<F>, Error> {
139	let mut ty = 0;
140	transcript.read_bytes(slice::from_mut(&mut ty))?;
141	let as_enum = EvalcheckNumerics::from(ty)?;
142
143	match as_enum {
144		EvalcheckNumerics::Transparent => Ok(EvalcheckProof::Transparent),
145		EvalcheckNumerics::Committed => Ok(EvalcheckProof::Committed),
146		EvalcheckNumerics::Shifted => Ok(EvalcheckProof::Shifted),
147		EvalcheckNumerics::Packed => Ok(EvalcheckProof::Packed),
148		EvalcheckNumerics::Repeating => {
149			let inner = deserialize_evalcheck_proof(transcript)?;
150			Ok(EvalcheckProof::Repeating(Box::new(inner)))
151		}
152		EvalcheckNumerics::LinearCombination => {
153			let mut len = [0u8; 8];
154			transcript.read_bytes(&mut len)?;
155			let len = u64::from_le_bytes(len) as usize;
156			let mut subproofs: Vec<(F, EvalcheckProof<F>)> = Vec::new();
157			for _ in 0..len {
158				let scalar = transcript.read_scalar()?;
159				let subproof = deserialize_evalcheck_proof(transcript)?;
160				subproofs.push((scalar, subproof));
161			}
162			Ok(EvalcheckProof::LinearCombination { subproofs })
163		}
164		EvalcheckNumerics::ZeroPadded => {
165			let scalar = transcript.read_scalar()?;
166			let subproof = deserialize_evalcheck_proof(transcript)?;
167			Ok(EvalcheckProof::ZeroPadded(scalar, Box::new(subproof)))
168		}
169		EvalcheckNumerics::CompositeMLE => Ok(EvalcheckProof::CompositeMLE),
170	}
171}
172
173pub struct EvalPointOracleIdMap<T: Clone, F: Field> {
174	data: Vec<Vec<(EvalPoint<F>, T)>>,
175}
176
177impl<T: Clone, F: Field> EvalPointOracleIdMap<T, F> {
178	pub fn new() -> Self {
179		Self {
180			data: Default::default(),
181		}
182	}
183
184	pub fn get(&self, id: OracleId, eval_point: &[F]) -> Option<&T> {
185		self.data
186			.get(id)?
187			.iter()
188			.find(|(ep, _)| **ep == *eval_point)
189			.map(|(_, val)| val)
190	}
191
192	pub fn insert(&mut self, id: OracleId, eval_point: EvalPoint<F>, val: T) {
193		if id >= self.data.len() {
194			self.data.resize(id + 1, Vec::new());
195		}
196
197		self.data[id].push((eval_point, val))
198	}
199
200	pub fn flatten(mut self) -> Vec<T> {
201		self.data.reverse();
202
203		std::mem::take(&mut self.data)
204			.into_iter()
205			.flatten()
206			.map(|(_, val)| val)
207			.collect::<Vec<_>>()
208	}
209}
210
211impl<T: Clone, F: Field> Default for EvalPointOracleIdMap<T, F> {
212	fn default() -> Self {
213		Self {
214			data: Default::default(),
215		}
216	}
217}
218
/// A cheaply clonable view into a shared evaluation point.
///
/// The coordinates live in a reference-counted buffer, and `range` selects the
/// window this point exposes. Cloning and `EvalPoint::slice` share the buffer
/// instead of copying it. Equality, `Deref`, and `EvalPoint::to_vec` all operate
/// on the selected window only.
#[derive(Debug, Clone)]
pub struct EvalPoint<F> {
	data: Arc<[F]>,
	range: Range<usize>,
}

impl<F: PartialEq> PartialEq for EvalPoint<F> {
	fn eq(&self, other: &Self) -> bool {
		// Compare only the visible coordinates; the shared buffers may differ outside them.
		self.data[self.range.clone()] == other.data[other.range.clone()]
	}
}

impl<F: Clone> EvalPoint<F> {
	/// Returns a sub-view over `range`, where `range` is relative to this view.
	///
	/// The underlying buffer is shared, not copied.
	///
	/// # Panics
	///
	/// Panics if `range` is reversed or extends past the end of this view.
	pub fn slice(&self, range: Range<usize>) -> Self {
		// Check the end bound, not just the length: a short range placed too far right
		// would otherwise expose coordinates outside this view.
		assert!(
			range.start <= range.end && range.end <= self.range.len(),
			"EvalPoint::slice: range {range:?} out of bounds for view of length {}",
			self.range.len()
		);

		let new_range = self.range.start + range.start..self.range.start + range.end;

		Self {
			data: self.data.clone(),
			range: new_range,
		}
	}

	/// Copies the coordinates of this view into a new vector.
	pub fn to_vec(&self) -> Vec<F> {
		// Copy only the selected window; the shared buffer may hold other coordinates.
		self.data[self.range.clone()].to_vec()
	}
}

impl<F> From<Vec<F>> for EvalPoint<F> {
	fn from(data: Vec<F>) -> Self {
		let range = 0..data.len();
		Self {
			data: data.into(),
			range,
		}
	}
}

impl<F: Clone> From<&[F]> for EvalPoint<F> {
	fn from(data: &[F]) -> Self {
		let range = 0..data.len();
		Self {
			data: data.into(),
			range,
		}
	}
}

impl<F> Deref for EvalPoint<F> {
	type Target = [F];

	fn deref(&self) -> &Self::Target {
		&self.data[self.range.clone()]
	}
}