1use std::{cell::RefCell, marker::PhantomData, rc::Rc};
4
5use anyhow::{anyhow, Error};
6use binius_core::{
7 oracle::{MultilinearOracleSet, OracleId},
8 witness::{MultilinearExtensionIndex, MultilinearWitness},
9};
10use binius_field::{
11 as_packed_field::{PackScalar, PackedType},
12 underlier::WithUnderlier,
13 ExtensionField, PackedField, TowerField,
14};
15use binius_math::MultilinearExtension;
16use binius_utils::bail;
17use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod};
18
19use super::types::{F, U};
20
21pub struct Builder<'arena> {
22 bump: &'arena bumpalo::Bump,
23
24 oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
25
26 #[allow(clippy::type_complexity)]
27 entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
28}
29
30struct WitnessBuilderEntry<'arena> {
31 witness: Result<MultilinearWitness<'arena, PackedType<U, F>>, binius_math::Error>,
32 tower_level: usize,
33 nonzero_scalars_prefix: usize,
34 data: &'arena [U],
35}
36
37impl<'arena> Builder<'arena> {
38 pub fn new(
39 allocator: &'arena bumpalo::Bump,
40 oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
41 ) -> Self {
42 Self {
43 bump: allocator,
44 oracles,
45 entries: Rc::new(RefCell::new(Vec::new())),
46 }
47 }
48
49 pub fn new_column<FS: TowerField>(&self, id: OracleId) -> EntryBuilder<'arena, FS>
50 where
51 U: PackScalar<FS>,
52 F: ExtensionField<FS>,
53 {
54 let nonzero_scalars_prefix = 1 << self.oracles.borrow().n_vars(id);
55 self.new_column_with_nonzero_scalars_prefix(id, nonzero_scalars_prefix)
56 }
57
58 pub fn new_column_with_nonzero_scalars_prefix<FS: TowerField>(
59 &self,
60 id: OracleId,
61 nonzero_scalars_prefix: usize,
62 ) -> EntryBuilder<'arena, FS>
63 where
64 U: PackScalar<FS>,
65 F: ExtensionField<FS>,
66 {
67 let oracles = self.oracles.borrow();
68 let log_rows = oracles.n_vars(id);
69 let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
71 let data = bumpalo::vec![in self.bump; U::default(); len].into_bump_slice_mut();
72 EntryBuilder {
73 _marker: PhantomData,
74 entries: self.entries.clone(),
75 id,
76 log_rows,
77 nonzero_scalars_prefix,
78 data: Some(data),
79 }
80 }
81
82 pub fn new_column_with_default<FS: TowerField>(
83 &self,
84 id: OracleId,
85 default: FS,
86 ) -> EntryBuilder<'arena, FS>
87 where
88 U: PackScalar<FS>,
89 F: ExtensionField<FS>,
90 {
91 let oracles = self.oracles.borrow();
92 let log_rows = oracles.n_vars(id);
93 let nonzero_scalars_prefix = 1 << log_rows;
94 let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
95 let default = WithUnderlier::to_underlier(PackedType::<U, FS>::broadcast(default));
96 let data = bumpalo::vec![in self.bump; default; len].into_bump_slice_mut();
97 EntryBuilder {
98 _marker: PhantomData,
99 entries: self.entries.clone(),
100 id,
101 log_rows,
102 nonzero_scalars_prefix,
103 data: Some(data),
104 }
105 }
106
107 pub fn get<FS>(&self, id: OracleId) -> Result<WitnessEntry<'arena, FS>, Error>
108 where
109 FS: TowerField,
110 U: PackScalar<FS>,
111 F: ExtensionField<FS>,
112 {
113 let entries = self.entries.borrow();
114 let oracles = self.oracles.borrow();
115 if !oracles.is_valid_oracle_id(id) {
116 bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
117 }
118 let entry = entries
119 .get(id.index())
120 .and_then(|entry| entry.as_ref())
121 .ok_or_else(|| anyhow!("Witness for {} is missing", oracles.label(id)))?;
122
123 if entry.tower_level != FS::TOWER_LEVEL {
124 bail!(anyhow!(
125 "Provided tower level ({}) for {} does not match stored tower level {}.",
126 FS::TOWER_LEVEL,
127 oracles.label(id),
128 entry.tower_level
129 ));
130 }
131
132 Ok(WitnessEntry {
133 data: entry.data,
134 log_rows: oracles.n_vars(id),
135 nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
136 _marker: PhantomData,
137 })
138 }
139
140 pub fn set<FS: TowerField>(
141 &self,
142 oracle_id: OracleId,
143 entry: WitnessEntry<'arena, FS>,
144 ) -> Result<(), Error>
145 where
146 U: PackScalar<FS>,
147 F: ExtensionField<FS>,
148 {
149 let oracles = self.oracles.borrow();
150 if !oracles.is_valid_oracle_id(oracle_id) {
151 bail!(anyhow!("OracleId {oracle_id} does not exist in MultilinearOracleSet"));
152 }
153 let mut entries = self.entries.borrow_mut();
154 let oracle_index = oracle_id.index();
155 if oracle_index >= entries.len() {
156 entries.resize_with(oracle_index + 1, || None);
157 }
158 entries[oracle_index] = Some(WitnessBuilderEntry {
159 data: entry.data,
160 nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
161 tower_level: FS::TOWER_LEVEL,
162 witness: MultilinearExtension::new(entry.log_rows, entry.packed())
163 .map(|x| x.specialize_arc_dyn()),
164 });
165 Ok(())
166 }
167
168 pub fn build(self) -> Result<MultilinearExtensionIndex<'arena, PackedType<U, F>>, Error> {
169 let mut result = MultilinearExtensionIndex::new();
170 let entries = Rc::into_inner(self.entries)
171 .ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
172 .into_inner()
173 .into_iter()
174 .enumerate()
175 .filter_map(|(index, entry)| entry.map(|entry| Ok((OracleId::from_index(index), entry.witness?, entry.nonzero_scalars_prefix))))
176 .collect::<Result<Vec<_>, Error>>()?;
177 result.update_multilin_poly_with_nonzero_scalars_prefixes(entries)?;
178 Ok(result)
179 }
180}
181
182#[derive(Debug, Clone, Copy)]
183pub struct WitnessEntry<'arena, FS: TowerField>
184where
185 U: PackScalar<FS>,
186{
187 data: &'arena [U],
188 log_rows: usize,
189 nonzero_scalars_prefix: usize,
190 _marker: PhantomData<FS>,
191}
192
193impl<'arena, FS: TowerField> WitnessEntry<'arena, FS>
194where
195 U: PackScalar<FS>,
196{
197 #[inline]
198 pub fn packed(&self) -> &'arena [PackedType<U, FS>] {
199 WithUnderlier::from_underliers_ref(self.data)
200 }
201
202 #[inline]
203 pub const fn as_slice<T: Pod>(&self) -> &'arena [T] {
204 must_cast_slice(self.data)
205 }
206
207 pub const fn repacked<FE>(&self) -> WitnessEntry<'arena, FE>
208 where
209 FE: TowerField + ExtensionField<FS>,
210 U: PackScalar<FE>,
211 {
212 let log_extension_degree = <FE as ExtensionField<FS>>::LOG_DEGREE;
213 WitnessEntry {
214 data: self.data,
215 log_rows: self.log_rows - log_extension_degree,
216 nonzero_scalars_prefix: self
217 .nonzero_scalars_prefix
218 .div_ceil(1 << log_extension_degree),
219 _marker: PhantomData,
220 }
221 }
222
223 pub const fn low_rows(&self) -> usize {
224 self.log_rows
225 }
226}
227
228pub struct EntryBuilder<'arena, FS>
229where
230 FS: TowerField,
231 U: PackScalar<FS>,
232 F: ExtensionField<FS>,
233{
234 _marker: PhantomData<FS>,
235 #[allow(clippy::type_complexity)]
236 entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
237 id: OracleId,
238 log_rows: usize,
239 nonzero_scalars_prefix: usize,
240 data: Option<&'arena mut [U]>,
241}
242
243impl<FS> EntryBuilder<'_, FS>
244where
245 FS: TowerField,
246 U: PackScalar<FS>,
247 F: ExtensionField<FS>,
248{
249 #[inline]
250 pub fn packed(&mut self) -> &mut [PackedType<U, FS>] {
251 PackedType::<U, FS>::from_underliers_ref_mut(self.underliers())
252 }
253
254 #[inline]
255 pub fn as_mut_slice<T: Pod>(&mut self) -> &mut [T] {
256 must_cast_slice_mut(self.underliers())
257 }
258
259 #[inline]
260 const fn underliers(&mut self) -> &mut [U] {
261 self.data
262 .as_mut()
263 .expect("Should only be None after Drop::drop has run")
264 }
265}
266
267impl<FS> Drop for EntryBuilder<'_, FS>
268where
269 FS: TowerField,
270 U: PackScalar<FS>,
271 F: ExtensionField<FS>,
272{
273 fn drop(&mut self) {
274 let data = Option::take(&mut self.data).expect("data is always Some until this point");
275 let mut entries = self.entries.borrow_mut();
276 let oracle_index = self.id.index();
277 let nonzero_scalars_prefix = self.nonzero_scalars_prefix;
278 if oracle_index >= entries.len() {
279 entries.resize_with(oracle_index + 1, || None);
280 }
281 entries[oracle_index] = Some(WitnessBuilderEntry {
282 data,
283 nonzero_scalars_prefix,
284 tower_level: FS::TOWER_LEVEL,
285 witness: MultilinearExtension::new(
286 self.log_rows,
287 PackedType::<U, FS>::from_underliers_ref(data),
288 )
289 .map(|x| x.specialize_arc_dyn()),
290 })
291 }
292}