binius_core/protocols/evalcheck/
evalcheck.rs1use std::{
4 hash::Hash,
5 ops::{Deref, Range},
6 sync::Arc,
7};
8
9use binius_field::Field;
10use bytes::{Buf, BufMut};
11
12use super::error::Error;
13use crate::{
14 oracle::OracleId,
15 transcript::{TranscriptReader, TranscriptWriter},
16};
17
18#[derive(Debug, Clone, PartialEq, Eq)]
20pub struct EvalcheckMultilinearClaim<F: Field> {
21 pub id: OracleId,
23 pub eval_point: EvalPoint<F>,
25 pub eval: F,
27}
28
/// Wire tags written to the transcript ahead of each evalcheck hint.
///
/// The discriminants are part of the serialized format; they must stay in
/// sync with the decoding in `EvalcheckNumerics::from`.
#[repr(u32)]
#[derive(Debug)]
enum EvalcheckNumerics {
    /// Tag `1`: a new claim (no payload follows).
    NewClaim = 1,
    /// Tag `2`: a duplicate claim (a `u32` payload follows).
    DuplicateClaim = 2,
}
35
/// Hint recorded in the transcript for an evalcheck claim.
///
/// On the wire it is a `u32` tag, followed by the `u32` payload for
/// `DuplicateClaim`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EvalcheckHint {
    /// The claim is new.
    NewClaim,
    /// The claim repeats an earlier one. The payload identifies it
    /// (presumably an index into previously seen claims — confirm against
    /// the prover/verifier).
    DuplicateClaim(u32),
}
49
50impl EvalcheckNumerics {
51 const fn from(x: u32) -> Result<Self, Error> {
52 match x {
53 1 => Ok(Self::NewClaim),
54 2 => Ok(Self::DuplicateClaim),
55 _ => Err(Error::EvalcheckSerializationError),
56 }
57 }
58}
59
60pub fn serialize_evalcheck_proof<B: BufMut>(
62 transcript: &mut TranscriptWriter<B>,
63 evalcheck: &EvalcheckHint,
64) {
65 match evalcheck {
66 EvalcheckHint::NewClaim => {
67 transcript.write(&(EvalcheckNumerics::NewClaim as u32));
68 }
69 EvalcheckHint::DuplicateClaim(index) => {
70 transcript.write(&(EvalcheckNumerics::DuplicateClaim as u32));
71 transcript.write(index);
72 }
73 }
74}
75
76pub fn deserialize_evalcheck_proof<B: Buf>(
78 transcript: &mut TranscriptReader<B>,
79) -> Result<EvalcheckHint, Error> {
80 let mut bytes = [0; size_of::<u32>()];
81 transcript.read_bytes(&mut bytes)?;
82 let as_enum = EvalcheckNumerics::from(u32::from_le_bytes(bytes))?;
83
84 match as_enum {
85 EvalcheckNumerics::NewClaim => Ok(EvalcheckHint::NewClaim),
86 EvalcheckNumerics::DuplicateClaim => {
87 let index = transcript.read()?;
88 Ok(EvalcheckHint::DuplicateClaim(index))
89 }
90 }
91}
92
93#[derive(Clone, Debug)]
99pub struct EvalPointOracleIdMap<T: Clone, F: Field> {
100 data: Vec<Vec<(EvalPoint<F>, T)>>,
101}
102
103impl<T: Clone, F: Field> EvalPointOracleIdMap<T, F> {
104 pub fn new() -> Self {
105 Self {
106 data: Default::default(),
107 }
108 }
109
110 pub fn get(&self, id: OracleId, eval_point: &[F]) -> Option<&T> {
114 self.data
115 .get(id.index())?
116 .iter()
117 .find(|(ep, _)| **ep == *eval_point)
118 .map(|(_, val)| val)
119 }
120
121 pub fn insert(&mut self, id: OracleId, eval_point: EvalPoint<F>, val: T) {
123 if id.index() >= self.data.len() {
124 self.data.resize(id.index() + 1, Vec::new());
125 }
126
127 if self.get(id, &eval_point).is_none() {
128 self.data[id.index()].push((eval_point, val))
129 }
130 }
131
132 pub fn contains(&self, id: OracleId, eval_point: &EvalPoint<F>) -> bool {
134 self.get(id, eval_point).is_some()
135 }
136
137 pub fn flatten(mut self) -> Vec<T> {
139 self.data.reverse();
140
141 std::mem::take(&mut self.data)
142 .into_iter()
143 .flatten()
144 .map(|(_, val)| val)
145 .collect::<Vec<_>>()
146 }
147
148 pub fn clear(&mut self) {
149 self.data.clear()
150 }
151}
152
153impl<T: Clone, F: Field> Default for EvalPointOracleIdMap<T, F> {
154 fn default() -> Self {
155 Self {
156 data: Default::default(),
157 }
158 }
159}
160
161#[derive(Debug, Clone, Eq)]
163pub struct EvalPoint<F: Field> {
164 data: Arc<[F]>,
165 range: Range<usize>,
166}
167
168impl<F: Field> PartialEq for EvalPoint<F> {
169 fn eq(&self, other: &Self) -> bool {
170 self.data[self.range.clone()] == other.data[other.range.clone()]
171 }
172}
173
174impl<F: Field> Hash for EvalPoint<F> {
175 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
176 self.data[self.range.clone()].hash(state)
177 }
178}
179
180impl<F: Field> EvalPoint<F> {
181 pub fn slice(&self, range: Range<usize>) -> Self {
182 assert!(self.range.len() >= range.len());
183
184 let new_range = self.range.start + range.start..self.range.start + range.end;
185
186 Self {
187 data: self.data.clone(),
188 range: new_range,
189 }
190 }
191
192 pub fn to_vec(&self) -> Vec<F> {
193 self.data[self.range.clone()].to_vec()
194 }
195
196 pub fn try_get_prefix(&self, suffix: &Self) -> Option<Self> {
197 let suffix_len = suffix.len();
198 let self_len = self.len();
199
200 if suffix_len > self_len {
201 return None;
202 }
203
204 if self.slice(self_len - suffix_len..self_len) == *suffix {
205 return Some(self.slice(0..self_len - suffix_len));
206 }
207 None
208 }
209}
210
211impl<F: Field> From<Vec<F>> for EvalPoint<F> {
212 fn from(data: Vec<F>) -> Self {
213 let range = 0..data.len();
214 Self {
215 data: data.into(),
216 range,
217 }
218 }
219}
220
221impl<F: Field> From<&[F]> for EvalPoint<F> {
222 fn from(data: &[F]) -> Self {
223 let range = 0..data.len();
224 Self {
225 data: data.into(),
226 range,
227 }
228 }
229}
230
231impl<F: Field> Deref for EvalPoint<F> {
232 type Target = [F];
233
234 fn deref(&self) -> &Self::Target {
235 &self.data[self.range.clone()]
236 }
237}