1use std::{cell::RefCell, marker::PhantomData, rc::Rc};
4
5use anyhow::{anyhow, Error};
6use binius_core::{
7 oracle::{MultilinearOracleSet, OracleId},
8 witness::{MultilinearExtensionIndex, MultilinearWitness},
9};
10use binius_field::{
11 as_packed_field::{PackScalar, PackedType},
12 underlier::WithUnderlier,
13 ExtensionField, PackedField, TowerField,
14};
15use binius_math::MultilinearExtension;
16use binius_utils::bail;
17use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod};
18
19use super::types::{F, U};
20
/// Collects witness columns, indexed by [`OracleId`], and assembles them into a
/// [`MultilinearExtensionIndex`] via [`Builder::build`].
pub struct Builder<'arena> {
    /// Bump arena from which every column buffer is allocated.
    bump: &'arena bumpalo::Bump,

    /// Shared oracle set, consulted for each column's `n_vars`, label and id validity.
    oracles: Rc<RefCell<MultilinearOracleSet<F>>>,

    /// Witness table indexed by `OracleId`; a slot is `None` until a witness is stored for it.
    /// The table is shared (via `Rc`) with every outstanding `EntryBuilder`.
    #[allow(clippy::type_complexity)]
    entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
}
29
/// One stored witness column, together with the metadata needed to hand it back out
/// through `Builder::get` or to feed it into `Builder::build`.
struct WitnessBuilderEntry<'arena> {
    /// Result of constructing the multilinear extension over `data`. A construction error
    /// is kept here and only propagated when `Builder::build` runs.
    witness: Result<MultilinearWitness<'arena, PackedType<U, F>>, binius_math::Error>,
    /// `TOWER_LEVEL` of the field the column was written as; `Builder::get` rejects
    /// requests made with a different level.
    tower_level: usize,
    /// Forwarded unchanged to `update_multilin_poly_with_nonzero_scalars_prefixes`.
    // NOTE(review): presumably the count of leading scalars that may be nonzero — confirm.
    nonzero_scalars_prefix: usize,
    /// Raw underlier storage backing the column.
    data: &'arena [U],
}
36
37impl<'arena> Builder<'arena> {
38 pub fn new(
39 allocator: &'arena bumpalo::Bump,
40 oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
41 ) -> Self {
42 Self {
43 bump: allocator,
44 oracles,
45 entries: Rc::new(RefCell::new(Vec::new())),
46 }
47 }
48
49 pub fn new_column<FS: TowerField>(&self, id: OracleId) -> EntryBuilder<'arena, FS>
50 where
51 U: PackScalar<FS>,
52 F: ExtensionField<FS>,
53 {
54 let nonzero_scalars_prefix = 1 << self.oracles.borrow().n_vars(id);
55 self.new_column_with_nonzero_scalars_prefix(id, nonzero_scalars_prefix)
56 }
57
58 pub fn new_column_with_nonzero_scalars_prefix<FS: TowerField>(
59 &self,
60 id: OracleId,
61 nonzero_scalars_prefix: usize,
62 ) -> EntryBuilder<'arena, FS>
63 where
64 U: PackScalar<FS>,
65 F: ExtensionField<FS>,
66 {
67 let oracles = self.oracles.borrow();
68 let log_rows = oracles.n_vars(id);
69 let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
71 let data = bumpalo::vec![in self.bump; U::default(); len].into_bump_slice_mut();
72 EntryBuilder {
73 _marker: PhantomData,
74 entries: self.entries.clone(),
75 id,
76 log_rows,
77 nonzero_scalars_prefix,
78 data: Some(data),
79 }
80 }
81
82 pub fn new_column_with_default<FS: TowerField>(
83 &self,
84 id: OracleId,
85 default: FS,
86 ) -> EntryBuilder<'arena, FS>
87 where
88 U: PackScalar<FS>,
89 F: ExtensionField<FS>,
90 {
91 let oracles = self.oracles.borrow();
92 let log_rows = oracles.n_vars(id);
93 let nonzero_scalars_prefix = 1 << log_rows;
94 let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
95 let default = WithUnderlier::to_underlier(PackedType::<U, FS>::broadcast(default));
96 let data = bumpalo::vec![in self.bump; default; len].into_bump_slice_mut();
97 EntryBuilder {
98 _marker: PhantomData,
99 entries: self.entries.clone(),
100 id,
101 log_rows,
102 nonzero_scalars_prefix,
103 data: Some(data),
104 }
105 }
106
107 pub fn get<FS>(&self, id: OracleId) -> Result<WitnessEntry<'arena, FS>, Error>
108 where
109 FS: TowerField,
110 U: PackScalar<FS>,
111 F: ExtensionField<FS>,
112 {
113 let entries = self.entries.borrow();
114 let oracles = self.oracles.borrow();
115 if !oracles.is_valid_oracle_id(id) {
116 bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
117 }
118 let entry = entries
119 .get(id)
120 .and_then(|entry| entry.as_ref())
121 .ok_or_else(|| anyhow!("Witness for {} is missing", oracles.label(id)))?;
122
123 if entry.tower_level != FS::TOWER_LEVEL {
124 bail!(anyhow!(
125 "Provided tower level ({}) for {} does not match stored tower level {}.",
126 FS::TOWER_LEVEL,
127 oracles.label(id),
128 entry.tower_level
129 ));
130 }
131
132 Ok(WitnessEntry {
133 data: entry.data,
134 log_rows: oracles.n_vars(id),
135 nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
136 _marker: PhantomData,
137 })
138 }
139
140 pub fn set<FS: TowerField>(
141 &self,
142 id: OracleId,
143 entry: WitnessEntry<'arena, FS>,
144 ) -> Result<(), Error>
145 where
146 U: PackScalar<FS>,
147 F: ExtensionField<FS>,
148 {
149 let oracles = self.oracles.borrow();
150 if !oracles.is_valid_oracle_id(id) {
151 bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
152 }
153 let mut entries = self.entries.borrow_mut();
154 if id >= entries.len() {
155 entries.resize_with(id + 1, || None);
156 }
157 entries[id] = Some(WitnessBuilderEntry {
158 data: entry.data,
159 nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
160 tower_level: FS::TOWER_LEVEL,
161 witness: MultilinearExtension::new(entry.log_rows, entry.packed())
162 .map(|x| x.specialize_arc_dyn()),
163 });
164 Ok(())
165 }
166
167 pub fn build(self) -> Result<MultilinearExtensionIndex<'arena, PackedType<U, F>>, Error> {
168 let mut result = MultilinearExtensionIndex::new();
169 let entries = Rc::into_inner(self.entries)
170 .ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
171 .into_inner()
172 .into_iter()
173 .enumerate()
174 .filter_map(|(id, entry)| entry.map(|entry| Ok((id, entry.witness?, entry.nonzero_scalars_prefix))))
175 .collect::<Result<Vec<_>, Error>>()?;
176 result.update_multilin_poly_with_nonzero_scalars_prefixes(entries)?;
177 Ok(result)
178 }
179}
180
/// Read-only, copyable view of a stored witness column, interpreted as scalars of `FS`.
#[derive(Debug, Clone, Copy)]
pub struct WitnessEntry<'arena, FS: TowerField>
where
    U: PackScalar<FS>,
{
    /// Raw underlier storage backing the column.
    data: &'arena [U],
    /// Number of variables of the column: `Builder::get` sets it from the oracle's `n_vars`,
    /// and `repacked` lowers it by the extension's `LOG_DEGREE`.
    log_rows: usize,
    /// Carried along with the column and stored again by `Builder::set`.
    // NOTE(review): presumably the count of leading scalars that may be nonzero — confirm.
    nonzero_scalars_prefix: usize,
    /// Ties the view to the scalar field `FS` without storing a value of it.
    _marker: PhantomData<FS>,
}
191
192impl<'arena, FS: TowerField> WitnessEntry<'arena, FS>
193where
194 U: PackScalar<FS>,
195{
196 #[inline]
197 pub fn packed(&self) -> &'arena [PackedType<U, FS>] {
198 WithUnderlier::from_underliers_ref(self.data)
199 }
200
201 #[inline]
202 pub const fn as_slice<T: Pod>(&self) -> &'arena [T] {
203 must_cast_slice(self.data)
204 }
205
206 pub const fn repacked<FE>(&self) -> WitnessEntry<'arena, FE>
207 where
208 FE: TowerField + ExtensionField<FS>,
209 U: PackScalar<FE>,
210 {
211 let log_extension_degree = <FE as ExtensionField<FS>>::LOG_DEGREE;
212 WitnessEntry {
213 data: self.data,
214 log_rows: self.log_rows - log_extension_degree,
215 nonzero_scalars_prefix: self
216 .nonzero_scalars_prefix
217 .div_ceil(1 << log_extension_degree),
218 _marker: PhantomData,
219 }
220 }
221
222 pub const fn low_rows(&self) -> usize {
223 self.log_rows
224 }
225}
226
/// Write handle for a single witness column of scalar type `FS`.
///
/// The buffer is filled through the mutable accessors; when the handle is dropped, the
/// column is stored into the shared witness table under `id`.
pub struct EntryBuilder<'arena, FS>
where
    FS: TowerField,
    U: PackScalar<FS>,
    F: ExtensionField<FS>,
{
    /// Ties the handle to the scalar field `FS` without storing a value of it.
    _marker: PhantomData<FS>,
    /// Witness table shared with the `Builder` that created this handle.
    #[allow(clippy::type_complexity)]
    entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
    /// Oracle the column belongs to; also the index of its slot in `entries`.
    id: OracleId,
    /// Number of variables of the column, as reported by the oracle set at creation time.
    log_rows: usize,
    /// Stored alongside the column when the handle is dropped.
    // NOTE(review): presumably the count of leading scalars that may be nonzero — confirm.
    nonzero_scalars_prefix: usize,
    /// Column buffer; `Some` for the whole lifetime of the handle and taken out in `drop`.
    data: Option<&'arena mut [U]>,
}
241
242impl<FS> EntryBuilder<'_, FS>
243where
244 FS: TowerField,
245 U: PackScalar<FS>,
246 F: ExtensionField<FS>,
247{
248 #[inline]
249 pub fn packed(&mut self) -> &mut [PackedType<U, FS>] {
250 PackedType::<U, FS>::from_underliers_ref_mut(self.underliers())
251 }
252
253 #[inline]
254 pub fn as_mut_slice<T: Pod>(&mut self) -> &mut [T] {
255 must_cast_slice_mut(self.underliers())
256 }
257
258 #[inline]
259 const fn underliers(&mut self) -> &mut [U] {
260 self.data
261 .as_mut()
262 .expect("Should only be None after Drop::drop has run")
263 }
264}
265
266impl<FS> Drop for EntryBuilder<'_, FS>
267where
268 FS: TowerField,
269 U: PackScalar<FS>,
270 F: ExtensionField<FS>,
271{
272 fn drop(&mut self) {
273 let data = Option::take(&mut self.data).expect("data is always Some until this point");
274 let mut entries = self.entries.borrow_mut();
275 let id = self.id;
276 let nonzero_scalars_prefix = self.nonzero_scalars_prefix;
277 if id >= entries.len() {
278 entries.resize_with(id + 1, || None);
279 }
280 entries[id] = Some(WitnessBuilderEntry {
281 data,
282 nonzero_scalars_prefix,
283 tower_level: FS::TOWER_LEVEL,
284 witness: MultilinearExtension::new(
285 self.log_rows,
286 PackedType::<U, FS>::from_underliers_ref(data),
287 )
288 .map(|x| x.specialize_arc_dyn()),
289 })
290 }
291}