binius_core/protocols/evalcheck/
evalcheck.rs1use std::{
4 hash::Hash,
5 ops::{Deref, Range},
6 slice,
7 sync::Arc,
8};
9
10use binius_field::{Field, TowerField};
11use bytes::{Buf, BufMut};
12
13use super::error::Error;
14use crate::{
15 oracle::OracleId,
16 transcript::{TranscriptReader, TranscriptWriter},
17};
18
19#[derive(Debug, Clone, PartialEq, Eq)]
21pub struct EvalcheckMultilinearClaim<F: Field> {
22 pub id: OracleId,
24 pub eval_point: EvalPoint<F>,
26 pub eval: F,
28}
29
/// One-byte tags written ahead of each [`EvalcheckProof`] node in the transcript.
///
/// These values are part of the serialized format, so each discriminant is spelled out
/// explicitly rather than relying on implicit sequential numbering.
#[repr(u8)]
#[derive(Debug)]
enum EvalcheckNumerics {
    Transparent = 1,
    Committed = 2,
    Shifted = 3,
    Packed = 4,
    Repeating = 5,
    LinearCombination = 6,
    ZeroPadded = 7,
    CompositeMLE = 8,
    Projected = 9,
    DuplicateClaim = 10,
}
44
45#[derive(Debug, Clone, PartialEq, Eq)]
47pub enum EvalcheckProof<F: Field> {
48 Transparent,
49 Committed,
50 Shifted,
51 Packed,
52 Repeating(Box<EvalcheckProof<F>>),
53 LinearCombination {
54 subproofs: Vec<(Option<F>, EvalcheckProof<F>)>,
55 },
56 ZeroPadded(F, Box<EvalcheckProof<F>>),
57 CompositeMLE,
58 Projected(Box<EvalcheckProof<F>>),
59 DuplicateClaim(usize),
60}
61
62impl EvalcheckNumerics {
63 const fn from(x: u8) -> Result<Self, Error> {
64 match x {
65 1 => Ok(Self::Transparent),
66 2 => Ok(Self::Committed),
67 3 => Ok(Self::Shifted),
68 4 => Ok(Self::Packed),
69 5 => Ok(Self::Repeating),
70 6 => Ok(Self::LinearCombination),
71 7 => Ok(Self::ZeroPadded),
72 8 => Ok(Self::CompositeMLE),
73 9 => Ok(Self::Projected),
74 10 => Ok(Self::DuplicateClaim),
75 _ => Err(Error::EvalcheckSerializationError),
76 }
77 }
78}
79
80pub fn serialize_evalcheck_proof<B: BufMut, F: TowerField>(
82 transcript: &mut TranscriptWriter<B>,
83 evalcheck: &EvalcheckProof<F>,
84) {
85 match evalcheck {
86 EvalcheckProof::Transparent => {
87 transcript.write_bytes(&[EvalcheckNumerics::Transparent as u8]);
88 }
89 EvalcheckProof::Committed => {
90 transcript.write_bytes(&[EvalcheckNumerics::Committed as u8]);
91 }
92 EvalcheckProof::Shifted => {
93 transcript.write_bytes(&[EvalcheckNumerics::Shifted as u8]);
94 }
95 EvalcheckProof::Packed => {
96 transcript.write_bytes(&[EvalcheckNumerics::Packed as u8]);
97 }
98 EvalcheckProof::Repeating(inner) => {
99 transcript.write_bytes(&[EvalcheckNumerics::Repeating as u8]);
100 serialize_evalcheck_proof(transcript, inner);
101 }
102 EvalcheckProof::LinearCombination { subproofs } => {
103 transcript.write_bytes(&[EvalcheckNumerics::LinearCombination as u8]);
104 let len_u64 = subproofs.len() as u64;
105 transcript.write_bytes(&len_u64.to_le_bytes());
106 for (scalar, subproof) in subproofs {
107 serialize_evalcheck_proof(transcript, subproof);
108 if let Some(scalar) = scalar {
109 transcript.write_scalar(*scalar);
110 }
111 }
112 }
113 EvalcheckProof::ZeroPadded(val, subproof) => {
114 transcript.write_bytes(&[EvalcheckNumerics::ZeroPadded as u8]);
115 transcript.write_scalar(*val);
116 serialize_evalcheck_proof(transcript, subproof);
117 }
118 EvalcheckProof::CompositeMLE => {
119 transcript.write_bytes(&[EvalcheckNumerics::CompositeMLE as u8]);
120 }
121 EvalcheckProof::DuplicateClaim(index) => {
122 transcript.write_bytes(&[EvalcheckNumerics::DuplicateClaim as u8]);
123 transcript.write(index);
124 }
125 EvalcheckProof::Projected(subproof) => {
126 transcript.write_bytes(&[EvalcheckNumerics::Projected as u8]);
127 serialize_evalcheck_proof(transcript, subproof);
128 }
129 }
130}
131
132pub fn deserialize_evalcheck_proof<B: Buf, F: TowerField>(
134 transcript: &mut TranscriptReader<B>,
135) -> Result<EvalcheckProof<F>, Error> {
136 let mut ty = 0;
137 transcript.read_bytes(slice::from_mut(&mut ty))?;
138 let as_enum = EvalcheckNumerics::from(ty)?;
139
140 match as_enum {
141 EvalcheckNumerics::Transparent => Ok(EvalcheckProof::Transparent),
142 EvalcheckNumerics::Committed => Ok(EvalcheckProof::Committed),
143 EvalcheckNumerics::Shifted => Ok(EvalcheckProof::Shifted),
144 EvalcheckNumerics::Packed => Ok(EvalcheckProof::Packed),
145 EvalcheckNumerics::Repeating => {
146 let inner = deserialize_evalcheck_proof(transcript)?;
147 Ok(EvalcheckProof::Repeating(Box::new(inner)))
148 }
149 EvalcheckNumerics::LinearCombination => {
150 let mut len = [0u8; 8];
151 transcript.read_bytes(&mut len)?;
152 let len = u64::from_le_bytes(len) as usize;
153 let mut subproofs = Vec::new();
154 for _ in 0..len {
155 let subproof = deserialize_evalcheck_proof(transcript)?;
156
157 let scalar = if let EvalcheckProof::DuplicateClaim(_) = subproof {
158 None
159 } else {
160 Some(transcript.read_scalar()?)
161 };
162
163 subproofs.push((scalar, subproof));
164 }
165 Ok(EvalcheckProof::LinearCombination { subproofs })
166 }
167 EvalcheckNumerics::ZeroPadded => {
168 let scalar = transcript.read_scalar()?;
169 let subproof = deserialize_evalcheck_proof(transcript)?;
170 Ok(EvalcheckProof::ZeroPadded(scalar, Box::new(subproof)))
171 }
172 EvalcheckNumerics::CompositeMLE => Ok(EvalcheckProof::CompositeMLE),
173 EvalcheckNumerics::DuplicateClaim => {
174 let index = transcript.read()?;
175 Ok(EvalcheckProof::DuplicateClaim(index))
176 }
177 EvalcheckNumerics::Projected => {
178 let subproof = deserialize_evalcheck_proof(transcript)?;
179 Ok(EvalcheckProof::Projected(Box::new(subproof)))
180 }
181 }
182}
183
184pub struct EvalPointOracleIdMap<T: Clone, F: Field> {
190 data: Vec<Vec<(EvalPoint<F>, T)>>,
191}
192
193impl<T: Clone, F: Field> EvalPointOracleIdMap<T, F> {
194 pub fn new() -> Self {
195 Self {
196 data: Default::default(),
197 }
198 }
199
200 pub fn get(&self, id: OracleId, eval_point: &[F]) -> Option<&T> {
204 self.data
205 .get(id)?
206 .iter()
207 .find(|(ep, _)| **ep == *eval_point)
208 .map(|(_, val)| val)
209 }
210
211 pub fn insert(&mut self, id: OracleId, eval_point: EvalPoint<F>, val: T) {
215 if id >= self.data.len() {
216 self.data.resize(id + 1, Vec::new());
217 }
218
219 self.data[id].push((eval_point, val))
220 }
221
222 pub fn flatten(mut self) -> Vec<T> {
224 self.data.reverse();
225
226 std::mem::take(&mut self.data)
227 .into_iter()
228 .flatten()
229 .map(|(_, val)| val)
230 .collect::<Vec<_>>()
231 }
232
233 pub fn clear(&mut self) {
234 self.data.clear()
235 }
236}
237
238impl<T: Clone, F: Field> Default for EvalPointOracleIdMap<T, F> {
239 fn default() -> Self {
240 Self {
241 data: Default::default(),
242 }
243 }
244}
245
246#[derive(Debug, Clone, Eq)]
248pub struct EvalPoint<F: Field> {
249 data: Arc<[F]>,
250 range: Range<usize>,
251}
252
253impl<F: Field> PartialEq for EvalPoint<F> {
254 fn eq(&self, other: &Self) -> bool {
255 self.data[self.range.clone()] == other.data[other.range.clone()]
256 }
257}
258
259impl<F: Field> Hash for EvalPoint<F> {
260 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
261 self.data[self.range.clone()].hash(state)
262 }
263}
264
265impl<F: Field> EvalPoint<F> {
266 pub fn slice(&self, range: Range<usize>) -> Self {
267 assert!(self.range.len() >= range.len());
268
269 let new_range = self.range.start + range.start..self.range.start + range.end;
270
271 Self {
272 data: self.data.clone(),
273 range: new_range,
274 }
275 }
276
277 pub fn to_vec(&self) -> Vec<F> {
278 self.data.to_vec()
279 }
280}
281
282impl<F: Field> From<Vec<F>> for EvalPoint<F> {
283 fn from(data: Vec<F>) -> Self {
284 let range = 0..data.len();
285 Self {
286 data: data.into(),
287 range,
288 }
289 }
290}
291
292impl<F: Field> From<&[F]> for EvalPoint<F> {
293 fn from(data: &[F]) -> Self {
294 let range = 0..data.len();
295 Self {
296 data: data.into(),
297 range,
298 }
299 }
300}
301
302impl<F: Field> Deref for EvalPoint<F> {
303 type Target = [F];
304
305 fn deref(&self) -> &Self::Target {
306 &self.data[self.range.clone()]
307 }
308}