binius_core/protocols/evalcheck/
evalcheck.rs

1// Copyright 2023-2025 Irreducible Inc.
2
3use std::{
4	hash::Hash,
5	ops::{Deref, Range},
6	slice,
7	sync::Arc,
8};
9
10use binius_field::{Field, TowerField};
11use bytes::{Buf, BufMut};
12
13use super::error::Error;
14use crate::{
15	oracle::OracleId,
16	transcript::{TranscriptReader, TranscriptWriter},
17};
18
/// This struct represents a claim to be verified through the evalcheck protocol.
///
/// The claim asserts that the oracle identified by `id`, evaluated at `eval_point`,
/// equals `eval`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalcheckMultilinearClaim<F: Field> {
	/// Virtual polynomial oracle for which the evaluation is claimed.
	pub id: OracleId,
	/// Evaluation point.
	pub eval_point: EvalPoint<F>,
	/// Claimed evaluation.
	pub eval: F,
}
29
/// One-byte wire tags identifying each `EvalcheckProof` variant in the transcript.
///
/// The numeric values are part of the serialization format and must never be reordered
/// or reused; `0` is deliberately left unassigned.
#[repr(u8)]
#[derive(Debug)]
enum EvalcheckNumerics {
	Transparent = 1,
	Committed = 2,
	Shifted = 3,
	Packed = 4,
	Repeating = 5,
	LinearCombination = 6,
	ZeroPadded = 7,
	CompositeMLE = 8,
	Projected = 9,
	DuplicateClaim = 10,
}
44
/// The proof output of a given claim which may recursively contain proofs of subclaims arising from a given claim.
///
/// Each variant corresponds one-to-one to a one-byte tag (`EvalcheckNumerics`) used by
/// [`serialize_evalcheck_proof`] and [`deserialize_evalcheck_proof`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvalcheckProof<F: Field> {
	/// No payload; serialized as the tag alone.
	Transparent,
	/// No payload; serialized as the tag alone.
	Committed,
	/// No payload; serialized as the tag alone.
	Shifted,
	/// No payload; serialized as the tag alone.
	Packed,
	/// Wraps a single nested subproof.
	Repeating(Box<EvalcheckProof<F>>),
	/// One `(scalar, subproof)` entry per sub-claim.
	///
	/// The wire format carries no presence flag for the scalar: the deserializer reads a
	/// scalar exactly when the subproof is *not* `DuplicateClaim`. The scalar must therefore
	/// be `None` if and only if the subproof is `DuplicateClaim`, or the transcript will not
	/// round-trip.
	LinearCombination {
		subproofs: Vec<(Option<F>, EvalcheckProof<F>)>,
	},
	/// A field scalar together with a nested subproof (serialized scalar-first).
	ZeroPadded(F, Box<EvalcheckProof<F>>),
	/// No payload; serialized as the tag alone.
	CompositeMLE,
	/// Wraps a single nested subproof.
	Projected(Box<EvalcheckProof<F>>),
	/// Carries only a `usize` index instead of a full subproof.
	/// NOTE(review): judging by the name, the index points at an equivalent claim handled
	/// elsewhere — confirm what it indexes against the prover/verifier.
	DuplicateClaim(usize),
}
61
62impl EvalcheckNumerics {
63	const fn from(x: u8) -> Result<Self, Error> {
64		match x {
65			1 => Ok(Self::Transparent),
66			2 => Ok(Self::Committed),
67			3 => Ok(Self::Shifted),
68			4 => Ok(Self::Packed),
69			5 => Ok(Self::Repeating),
70			6 => Ok(Self::LinearCombination),
71			7 => Ok(Self::ZeroPadded),
72			8 => Ok(Self::CompositeMLE),
73			9 => Ok(Self::Projected),
74			10 => Ok(Self::DuplicateClaim),
75			_ => Err(Error::EvalcheckSerializationError),
76		}
77	}
78}
79
80/// Serializes the `EvalcheckProof` into the transcript
81pub fn serialize_evalcheck_proof<B: BufMut, F: TowerField>(
82	transcript: &mut TranscriptWriter<B>,
83	evalcheck: &EvalcheckProof<F>,
84) {
85	match evalcheck {
86		EvalcheckProof::Transparent => {
87			transcript.write_bytes(&[EvalcheckNumerics::Transparent as u8]);
88		}
89		EvalcheckProof::Committed => {
90			transcript.write_bytes(&[EvalcheckNumerics::Committed as u8]);
91		}
92		EvalcheckProof::Shifted => {
93			transcript.write_bytes(&[EvalcheckNumerics::Shifted as u8]);
94		}
95		EvalcheckProof::Packed => {
96			transcript.write_bytes(&[EvalcheckNumerics::Packed as u8]);
97		}
98		EvalcheckProof::Repeating(inner) => {
99			transcript.write_bytes(&[EvalcheckNumerics::Repeating as u8]);
100			serialize_evalcheck_proof(transcript, inner);
101		}
102		EvalcheckProof::LinearCombination { subproofs } => {
103			transcript.write_bytes(&[EvalcheckNumerics::LinearCombination as u8]);
104			let len_u64 = subproofs.len() as u64;
105			transcript.write_bytes(&len_u64.to_le_bytes());
106			for (scalar, subproof) in subproofs {
107				serialize_evalcheck_proof(transcript, subproof);
108				if let Some(scalar) = scalar {
109					transcript.write_scalar(*scalar);
110				}
111			}
112		}
113		EvalcheckProof::ZeroPadded(val, subproof) => {
114			transcript.write_bytes(&[EvalcheckNumerics::ZeroPadded as u8]);
115			transcript.write_scalar(*val);
116			serialize_evalcheck_proof(transcript, subproof);
117		}
118		EvalcheckProof::CompositeMLE => {
119			transcript.write_bytes(&[EvalcheckNumerics::CompositeMLE as u8]);
120		}
121		EvalcheckProof::DuplicateClaim(index) => {
122			transcript.write_bytes(&[EvalcheckNumerics::DuplicateClaim as u8]);
123			transcript.write(index);
124		}
125		EvalcheckProof::Projected(subproof) => {
126			transcript.write_bytes(&[EvalcheckNumerics::Projected as u8]);
127			serialize_evalcheck_proof(transcript, subproof);
128		}
129	}
130}
131
132/// Deserializes the `EvalcheckProof` object from the given transcript.
133pub fn deserialize_evalcheck_proof<B: Buf, F: TowerField>(
134	transcript: &mut TranscriptReader<B>,
135) -> Result<EvalcheckProof<F>, Error> {
136	let mut ty = 0;
137	transcript.read_bytes(slice::from_mut(&mut ty))?;
138	let as_enum = EvalcheckNumerics::from(ty)?;
139
140	match as_enum {
141		EvalcheckNumerics::Transparent => Ok(EvalcheckProof::Transparent),
142		EvalcheckNumerics::Committed => Ok(EvalcheckProof::Committed),
143		EvalcheckNumerics::Shifted => Ok(EvalcheckProof::Shifted),
144		EvalcheckNumerics::Packed => Ok(EvalcheckProof::Packed),
145		EvalcheckNumerics::Repeating => {
146			let inner = deserialize_evalcheck_proof(transcript)?;
147			Ok(EvalcheckProof::Repeating(Box::new(inner)))
148		}
149		EvalcheckNumerics::LinearCombination => {
150			let mut len = [0u8; 8];
151			transcript.read_bytes(&mut len)?;
152			let len = u64::from_le_bytes(len) as usize;
153			let mut subproofs = Vec::new();
154			for _ in 0..len {
155				let subproof = deserialize_evalcheck_proof(transcript)?;
156
157				let scalar = if let EvalcheckProof::DuplicateClaim(_) = subproof {
158					None
159				} else {
160					Some(transcript.read_scalar()?)
161				};
162
163				subproofs.push((scalar, subproof));
164			}
165			Ok(EvalcheckProof::LinearCombination { subproofs })
166		}
167		EvalcheckNumerics::ZeroPadded => {
168			let scalar = transcript.read_scalar()?;
169			let subproof = deserialize_evalcheck_proof(transcript)?;
170			Ok(EvalcheckProof::ZeroPadded(scalar, Box::new(subproof)))
171		}
172		EvalcheckNumerics::CompositeMLE => Ok(EvalcheckProof::CompositeMLE),
173		EvalcheckNumerics::DuplicateClaim => {
174			let index = transcript.read()?;
175			Ok(EvalcheckProof::DuplicateClaim(index))
176		}
177		EvalcheckNumerics::Projected => {
178			let subproof = deserialize_evalcheck_proof(transcript)?;
179			Ok(EvalcheckProof::Projected(Box::new(subproof)))
180		}
181	}
182}
183
/// Data structure for efficiently querying and inserting evaluations of claims.
///
/// Equivalent to a `HashMap<(OracleId, EvalPoint<F>), T>` but uses vectors of vectors to store the data.
/// This data structure is more memory efficient for a small number of evaluation points and
/// OracleIds which are grouped together.
///
/// Unlike a `HashMap`, inserting a key that is already present appends a second entry rather
/// than replacing the first; lookups return the earliest inserted match.
pub struct EvalPointOracleIdMap<T: Clone, F: Field> {
	/// `data[id]` holds the `(eval_point, value)` pairs inserted for oracle `id`, in insertion
	/// order. The outer vector is grown on demand, so unused ids have empty buckets.
	data: Vec<Vec<(EvalPoint<F>, T)>>,
}
192
193impl<T: Clone, F: Field> EvalPointOracleIdMap<T, F> {
194	pub fn new() -> Self {
195		Self {
196			data: Default::default(),
197		}
198	}
199
200	/// Query the first value found for an evaluation point for a given oracle id.
201	///
202	/// Returns `None` if no value is found.
203	pub fn get(&self, id: OracleId, eval_point: &[F]) -> Option<&T> {
204		self.data
205			.get(id)?
206			.iter()
207			.find(|(ep, _)| **ep == *eval_point)
208			.map(|(_, val)| val)
209	}
210
211	/// Insert a new evaluation point for a given oracle id.
212	///
213	/// We do not replace existing values.
214	pub fn insert(&mut self, id: OracleId, eval_point: EvalPoint<F>, val: T) {
215		if id >= self.data.len() {
216			self.data.resize(id + 1, Vec::new());
217		}
218
219		self.data[id].push((eval_point, val))
220	}
221
222	/// Flatten the data structure into a vector of values.
223	pub fn flatten(mut self) -> Vec<T> {
224		self.data.reverse();
225
226		std::mem::take(&mut self.data)
227			.into_iter()
228			.flatten()
229			.map(|(_, val)| val)
230			.collect::<Vec<_>>()
231	}
232
233	pub fn clear(&mut self) {
234		self.data.clear()
235	}
236}
237
238impl<T: Clone, F: Field> Default for EvalPointOracleIdMap<T, F> {
239	fn default() -> Self {
240		Self {
241			data: Default::default(),
242		}
243	}
244}
245
/// A wrapper struct for evaluation points.
///
/// An `EvalPoint` is the window `data[range]` into a reference-counted buffer of field
/// elements. Equality, hashing and `Deref` consider only the elements inside the window,
/// while cloning or slicing shares the buffer through the `Arc` instead of copying it.
/// The derived `Debug` output shows the whole buffer together with the range.
#[derive(Debug, Clone, Eq)]
pub struct EvalPoint<F: Field> {
	/// Shared backing buffer; may extend beyond the window described by `range`.
	data: Arc<[F]>,
	/// Window into `data` that makes up this evaluation point.
	range: Range<usize>,
}
252
253impl<F: Field> PartialEq for EvalPoint<F> {
254	fn eq(&self, other: &Self) -> bool {
255		self.data[self.range.clone()] == other.data[other.range.clone()]
256	}
257}
258
259impl<F: Field> Hash for EvalPoint<F> {
260	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
261		self.data[self.range.clone()].hash(state)
262	}
263}
264
265impl<F: Field> EvalPoint<F> {
266	pub fn slice(&self, range: Range<usize>) -> Self {
267		assert!(self.range.len() >= range.len());
268
269		let new_range = self.range.start + range.start..self.range.start + range.end;
270
271		Self {
272			data: self.data.clone(),
273			range: new_range,
274		}
275	}
276
277	pub fn to_vec(&self) -> Vec<F> {
278		self.data.to_vec()
279	}
280}
281
282impl<F: Field> From<Vec<F>> for EvalPoint<F> {
283	fn from(data: Vec<F>) -> Self {
284		let range = 0..data.len();
285		Self {
286			data: data.into(),
287			range,
288		}
289	}
290}
291
292impl<F: Field> From<&[F]> for EvalPoint<F> {
293	fn from(data: &[F]) -> Self {
294		let range = 0..data.len();
295		Self {
296			data: data.into(),
297			range,
298		}
299	}
300}
301
302impl<F: Field> Deref for EvalPoint<F> {
303	type Target = [F];
304
305	fn deref(&self) -> &Self::Target {
306		&self.data[self.range.clone()]
307	}
308}