binius_core/protocols/evalcheck/
evalcheck.rs1use std::{
4 ops::{Deref, Range},
5 slice,
6 sync::Arc,
7};
8
9use binius_field::{Field, TowerField};
10use bytes::{Buf, BufMut};
11
12use super::error::Error;
13use crate::{
14 oracle::OracleId,
15 transcript::{TranscriptReader, TranscriptWriter},
16};
17
18#[derive(Debug, Clone)]
19pub struct EvalcheckMultilinearClaim<F: Field> {
20 pub id: OracleId,
22 pub eval_point: EvalPoint<F>,
24 pub eval: F,
26}
27
/// One-byte tags identifying each `EvalcheckProof` variant in the serialized
/// transcript encoding.
///
/// Discriminants start at 1 and are spelled out so the wire format is visible at
/// a glance; `EvalcheckNumerics::from` decodes exactly these values.
#[repr(u8)]
#[derive(Debug)]
enum EvalcheckNumerics {
    Transparent = 1,
    Committed = 2,
    Shifted = 3,
    Packed = 4,
    Repeating = 5,
    LinearCombination = 6,
    ZeroPadded = 7,
    CompositeMLE = 8,
}
40
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub enum EvalcheckProof<F: Field> {
43 Transparent,
44 Committed,
45 Shifted,
46 Packed,
47 Repeating(Box<EvalcheckProof<F>>),
48 LinearCombination {
49 subproofs: Vec<(F, EvalcheckProof<F>)>,
50 },
51 ZeroPadded(F, Box<EvalcheckProof<F>>),
52 CompositeMLE,
53}
54
55impl<F: Field> EvalcheckProof<F> {
56 pub fn isomorphic<FI: Field + From<F>>(self) -> EvalcheckProof<FI> {
57 match self {
58 Self::Transparent => EvalcheckProof::Transparent,
59 Self::Committed => EvalcheckProof::Committed,
60 Self::Shifted => EvalcheckProof::Shifted,
61 Self::Packed => EvalcheckProof::Packed,
62 Self::Repeating(proof) => EvalcheckProof::Repeating(Box::new(proof.isomorphic())),
63 Self::LinearCombination { subproofs } => EvalcheckProof::LinearCombination {
64 subproofs: subproofs
65 .into_iter()
66 .map(|(eval, proof)| (eval.into(), proof.isomorphic()))
67 .collect(),
68 },
69 Self::ZeroPadded(eval, proof) => {
70 EvalcheckProof::ZeroPadded(eval.into(), Box::new(proof.isomorphic()))
71 }
72 Self::CompositeMLE => EvalcheckProof::CompositeMLE,
73 }
74 }
75}
76
77impl EvalcheckNumerics {
78 const fn from(x: u8) -> Result<Self, Error> {
79 match x {
80 1 => Ok(Self::Transparent),
81 2 => Ok(Self::Committed),
82 3 => Ok(Self::Shifted),
83 4 => Ok(Self::Packed),
84 5 => Ok(Self::Repeating),
85 6 => Ok(Self::LinearCombination),
86 7 => Ok(Self::ZeroPadded),
87 8 => Ok(Self::CompositeMLE),
88 _ => Err(Error::EvalcheckSerializationError),
89 }
90 }
91}
92
93pub fn serialize_evalcheck_proof<B: BufMut, F: TowerField>(
95 transcript: &mut TranscriptWriter<B>,
96 evalcheck: &EvalcheckProof<F>,
97) {
98 match evalcheck {
99 EvalcheckProof::Transparent => {
100 transcript.write_bytes(&[EvalcheckNumerics::Transparent as u8]);
101 }
102 EvalcheckProof::Committed => {
103 transcript.write_bytes(&[EvalcheckNumerics::Committed as u8]);
104 }
105 EvalcheckProof::Shifted => {
106 transcript.write_bytes(&[EvalcheckNumerics::Shifted as u8]);
107 }
108 EvalcheckProof::Packed => {
109 transcript.write_bytes(&[EvalcheckNumerics::Packed as u8]);
110 }
111 EvalcheckProof::Repeating(inner) => {
112 transcript.write_bytes(&[EvalcheckNumerics::Repeating as u8]);
113 serialize_evalcheck_proof(transcript, inner);
114 }
115 EvalcheckProof::LinearCombination { subproofs } => {
116 transcript.write_bytes(&[EvalcheckNumerics::LinearCombination as u8]);
117 let len_u64 = subproofs.len() as u64;
118 transcript.write_bytes(&len_u64.to_le_bytes());
119 for (scalar, subproof) in subproofs {
120 transcript.write_scalar(*scalar);
121 serialize_evalcheck_proof(transcript, subproof)
122 }
123 }
124 EvalcheckProof::ZeroPadded(val, subproof) => {
125 transcript.write_bytes(&[EvalcheckNumerics::ZeroPadded as u8]);
126 transcript.write_scalar(*val);
127 serialize_evalcheck_proof(transcript, subproof);
128 }
129 EvalcheckProof::CompositeMLE => {
130 transcript.write_bytes(&[EvalcheckNumerics::CompositeMLE as u8]);
131 }
132 }
133}
134
135pub fn deserialize_evalcheck_proof<B: Buf, F: TowerField>(
137 transcript: &mut TranscriptReader<B>,
138) -> Result<EvalcheckProof<F>, Error> {
139 let mut ty = 0;
140 transcript.read_bytes(slice::from_mut(&mut ty))?;
141 let as_enum = EvalcheckNumerics::from(ty)?;
142
143 match as_enum {
144 EvalcheckNumerics::Transparent => Ok(EvalcheckProof::Transparent),
145 EvalcheckNumerics::Committed => Ok(EvalcheckProof::Committed),
146 EvalcheckNumerics::Shifted => Ok(EvalcheckProof::Shifted),
147 EvalcheckNumerics::Packed => Ok(EvalcheckProof::Packed),
148 EvalcheckNumerics::Repeating => {
149 let inner = deserialize_evalcheck_proof(transcript)?;
150 Ok(EvalcheckProof::Repeating(Box::new(inner)))
151 }
152 EvalcheckNumerics::LinearCombination => {
153 let mut len = [0u8; 8];
154 transcript.read_bytes(&mut len)?;
155 let len = u64::from_le_bytes(len) as usize;
156 let mut subproofs: Vec<(F, EvalcheckProof<F>)> = Vec::new();
157 for _ in 0..len {
158 let scalar = transcript.read_scalar()?;
159 let subproof = deserialize_evalcheck_proof(transcript)?;
160 subproofs.push((scalar, subproof));
161 }
162 Ok(EvalcheckProof::LinearCombination { subproofs })
163 }
164 EvalcheckNumerics::ZeroPadded => {
165 let scalar = transcript.read_scalar()?;
166 let subproof = deserialize_evalcheck_proof(transcript)?;
167 Ok(EvalcheckProof::ZeroPadded(scalar, Box::new(subproof)))
168 }
169 EvalcheckNumerics::CompositeMLE => Ok(EvalcheckProof::CompositeMLE),
170 }
171}
172
173pub struct EvalPointOracleIdMap<T: Clone, F: Field> {
174 data: Vec<Vec<(EvalPoint<F>, T)>>,
175}
176
177impl<T: Clone, F: Field> EvalPointOracleIdMap<T, F> {
178 pub fn new() -> Self {
179 Self {
180 data: Default::default(),
181 }
182 }
183
184 pub fn get(&self, id: OracleId, eval_point: &[F]) -> Option<&T> {
185 self.data
186 .get(id)?
187 .iter()
188 .find(|(ep, _)| **ep == *eval_point)
189 .map(|(_, val)| val)
190 }
191
192 pub fn insert(&mut self, id: OracleId, eval_point: EvalPoint<F>, val: T) {
193 if id >= self.data.len() {
194 self.data.resize(id + 1, Vec::new());
195 }
196
197 self.data[id].push((eval_point, val))
198 }
199
200 pub fn flatten(mut self) -> Vec<T> {
201 self.data.reverse();
202
203 std::mem::take(&mut self.data)
204 .into_iter()
205 .flatten()
206 .map(|(_, val)| val)
207 .collect::<Vec<_>>()
208 }
209}
210
211impl<T: Clone, F: Field> Default for EvalPointOracleIdMap<T, F> {
212 fn default() -> Self {
213 Self {
214 data: Default::default(),
215 }
216 }
217}
218
/// A cheaply clonable view into a shared buffer of coordinates.
///
/// The visible point is `data[range]`; clones and sub-slices share the same `Arc`
/// allocation instead of copying coordinates.
#[derive(Debug, Clone)]
pub struct EvalPoint<F> {
    /// Shared backing storage; may be longer than the visible point.
    data: Arc<[F]>,
    /// Window of `data` that forms this point.
    range: Range<usize>,
}
224
225impl<F: Field> PartialEq for EvalPoint<F> {
226 fn eq(&self, other: &Self) -> bool {
227 self.data[self.range.clone()] == other.data[other.range.clone()]
228 }
229}
230
231impl<F: Clone> EvalPoint<F> {
232 pub fn slice(&self, range: Range<usize>) -> Self {
233 assert!(self.range.len() >= range.len());
234
235 let new_range = self.range.start + range.start..self.range.start + range.end;
236
237 Self {
238 data: self.data.clone(),
239 range: new_range,
240 }
241 }
242
243 pub fn to_vec(&self) -> Vec<F> {
244 self.data.to_vec()
245 }
246}
247
248impl<F> From<Vec<F>> for EvalPoint<F> {
249 fn from(data: Vec<F>) -> Self {
250 let range = 0..data.len();
251 Self {
252 data: data.into(),
253 range,
254 }
255 }
256}
257
258impl<F: Clone> From<&[F]> for EvalPoint<F> {
259 fn from(data: &[F]) -> Self {
260 let range = 0..data.len();
261 Self {
262 data: data.into(),
263 range,
264 }
265 }
266}
267
268impl<F> Deref for EvalPoint<F> {
269 type Target = [F];
270
271 fn deref(&self) -> &Self::Target {
272 &self.data[self.range.clone()]
273 }
274}