1use std::{cell::RefCell, marker::PhantomData, rc::Rc};
4
5use anyhow::{anyhow, Error};
6use binius_core::{
7 oracle::{MultilinearOracleSet, OracleId},
8 witness::{MultilinearExtensionIndex, MultilinearWitness},
9};
10use binius_field::{
11 as_packed_field::{PackScalar, PackedType},
12 underlier::WithUnderlier,
13 ExtensionField, PackedField, TowerField,
14};
15use binius_math::MultilinearExtension;
16use binius_utils::bail;
17use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod};
18
19use super::types::{F, U};
20
/// Incrementally assembles the witness for the oracles of a [`MultilinearOracleSet`].
///
/// Columns are written either through [`EntryBuilder`] handles (`new_column*`) or via
/// `set`, and are turned into a [`MultilinearExtensionIndex`] by `build`.
pub struct Builder<'arena> {
    // Arena in which column buffers are allocated; buffers borrow from it for `'arena`.
    bump: &'arena bumpalo::Bump,

    // Oracle set consulted (read-only) for each column's variable count, label and id validity.
    oracles: Rc<RefCell<MultilinearOracleSet<F>>>,

    // Witness slots indexed by `OracleId` (`None` = nothing written yet). The same table is
    // cloned into every live `EntryBuilder`, which writes its column here when dropped; the
    // nested shared-ownership type is intentional, hence the clippy allowance.
    #[allow(clippy::type_complexity)]
    entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
}
29
/// A column recorded in a [`Builder`], together with what is needed to read it back.
struct WitnessBuilderEntry<'arena> {
    // The column as a multilinear over `PackedType<U, F>`, or the error produced while
    // constructing it; such errors are only surfaced by `Builder::build`.
    witness: Result<MultilinearWitness<'arena, PackedType<U, F>>, binius_math::Error>,
    // `TOWER_LEVEL` of the field the column was written in; `Builder::get` rejects reads
    // requested at any other level.
    tower_level: usize,
    // Raw underlier buffer the witness was built from; handed back by `Builder::get`.
    data: &'arena [U],
}
35
36impl<'arena> Builder<'arena> {
37 pub fn new(
38 allocator: &'arena bumpalo::Bump,
39 oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
40 ) -> Self {
41 Self {
42 bump: allocator,
43 oracles,
44 entries: Rc::new(RefCell::new(Vec::new())),
45 }
46 }
47
48 pub fn new_column<FS: TowerField>(&self, id: OracleId) -> EntryBuilder<'arena, FS>
49 where
50 U: PackScalar<FS>,
51 F: ExtensionField<FS>,
52 {
53 let oracles = self.oracles.borrow();
54 let log_rows = oracles.n_vars(id);
55 let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
56 let data = bumpalo::vec![in self.bump; U::default(); len].into_bump_slice_mut();
57 EntryBuilder {
58 _marker: PhantomData,
59 log_rows,
60 id,
61 data: Some(data),
62 entries: self.entries.clone(),
63 }
64 }
65
66 pub fn new_column_with_default<FS: TowerField>(
67 &self,
68 id: OracleId,
69 default: FS,
70 ) -> EntryBuilder<'arena, FS>
71 where
72 U: PackScalar<FS>,
73 F: ExtensionField<FS>,
74 {
75 let oracles = self.oracles.borrow();
76 let log_rows = oracles.n_vars(id);
77 let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
78 let default = WithUnderlier::to_underlier(PackedType::<U, FS>::broadcast(default));
79 let data = bumpalo::vec![in self.bump; default; len].into_bump_slice_mut();
80 EntryBuilder {
81 _marker: PhantomData,
82 log_rows,
83 id,
84 data: Some(data),
85 entries: self.entries.clone(),
86 }
87 }
88
89 pub fn get<FS>(&self, id: OracleId) -> Result<WitnessEntry<'arena, FS>, Error>
90 where
91 FS: TowerField,
92 U: PackScalar<FS>,
93 F: ExtensionField<FS>,
94 {
95 let entries = self.entries.borrow();
96 let oracles = self.oracles.borrow();
97 if !oracles.is_valid_oracle_id(id) {
98 bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
99 }
100 let entry = entries
101 .get(id)
102 .and_then(|entry| entry.as_ref())
103 .ok_or_else(|| anyhow!("Witness for {} is missing", oracles.label(id)))?;
104
105 if entry.tower_level != FS::TOWER_LEVEL {
106 bail!(anyhow!(
107 "Provided tower level ({}) for {} does not match stored tower level {}.",
108 FS::TOWER_LEVEL,
109 oracles.label(id),
110 entry.tower_level
111 ));
112 }
113
114 Ok(WitnessEntry {
115 data: entry.data,
116 log_rows: oracles.n_vars(id),
117 _marker: PhantomData,
118 })
119 }
120
121 pub fn set<FS: TowerField>(
122 &self,
123 id: OracleId,
124 entry: WitnessEntry<'arena, FS>,
125 ) -> Result<(), Error>
126 where
127 U: PackScalar<FS>,
128 F: ExtensionField<FS>,
129 {
130 let oracles = self.oracles.borrow();
131 if !oracles.is_valid_oracle_id(id) {
132 bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
133 }
134 let mut entries = self.entries.borrow_mut();
135 if id >= entries.len() {
136 entries.resize_with(id + 1, || None);
137 }
138 entries[id] = Some(WitnessBuilderEntry {
139 data: entry.data,
140 tower_level: FS::TOWER_LEVEL,
141 witness: MultilinearExtension::new(entry.log_rows, entry.packed())
142 .map(|x| x.specialize_arc_dyn()),
143 });
144 Ok(())
145 }
146
147 pub fn build(self) -> Result<MultilinearExtensionIndex<'arena, U, F>, Error> {
148 let mut result = MultilinearExtensionIndex::new();
149 let entries = Rc::into_inner(self.entries)
150 .ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
151 .into_inner()
152 .into_iter()
153 .enumerate()
154 .filter_map(|(id, entry)| entry.map(|entry| Ok((id, entry.witness?))))
155 .collect::<Result<Vec<_>, Error>>()?;
156 result.update_multilin_poly(entries)?;
157 Ok(result)
158 }
159}
160
/// Read-only, copyable view of a witness column whose scalars are `FS` elements.
///
/// Produced by `Builder::get` (and by `repacked`); can be stored back with `Builder::set`.
#[derive(Debug, Clone, Copy)]
pub struct WitnessEntry<'arena, FS: TowerField>
where
    // NOTE(review): no field needs this bound; presumably it restricts `FS` to fields that
    // `U` can pack — confirm before removing.
    U: PackScalar<FS>,
{
    // Arena-allocated underlier buffer holding the packed column.
    data: &'arena [U],
    // Number of variables of the column, i.e. log2 of its number of `FS` rows.
    log_rows: usize,
    // Records the scalar field type; no `FS` value is stored.
    _marker: PhantomData<FS>,
}
170
171impl<'arena, FS: TowerField> WitnessEntry<'arena, FS>
172where
173 U: PackScalar<FS>,
174{
175 #[inline]
176 pub fn packed(&self) -> &'arena [PackedType<U, FS>] {
177 WithUnderlier::from_underliers_ref(self.data)
178 }
179
180 #[inline]
181 pub const fn as_slice<T: Pod>(&self) -> &'arena [T] {
182 must_cast_slice(self.data)
183 }
184
185 pub const fn repacked<FE>(&self) -> WitnessEntry<'arena, FE>
186 where
187 FE: TowerField + ExtensionField<FS>,
188 U: PackScalar<FE>,
189 {
190 WitnessEntry {
191 data: self.data,
192 log_rows: self.log_rows - <FE as ExtensionField<FS>>::LOG_DEGREE,
193 _marker: PhantomData,
194 }
195 }
196
197 pub const fn low_rows(&self) -> usize {
198 self.log_rows
199 }
200}
201
/// Write handle for one column, returned by `Builder::new_column*`.
///
/// The buffer is filled through `packed` / `as_mut_slice`; when the handle is dropped, the
/// column is stored in the builder's entry table under `id`.
// The `Drop` impl for this type relies on these bounds, and Rust rejects a `Drop` impl whose
// bounds are not implied by the type definition (E0367), so they must stay on the struct.
pub struct EntryBuilder<'arena, FS>
where
    FS: TowerField,
    U: PackScalar<FS>,
    F: ExtensionField<FS>,
{
    // Marks the scalar field the column is written in; no `FS` value is stored.
    _marker: PhantomData<FS>,
    // Shared entry table of the owning builder; same type as `Builder::entries`, hence the
    // clippy allowance.
    #[allow(clippy::type_complexity)]
    entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
    // Oracle whose slot is written on drop.
    id: OracleId,
    // Number of variables of the column, as reported by the oracle set at creation.
    log_rows: usize,
    // The column buffer; `Some` until `Drop::drop` moves it into the entry table.
    data: Option<&'arena mut [U]>,
}
215
216impl<FS> EntryBuilder<'_, FS>
217where
218 FS: TowerField,
219 U: PackScalar<FS>,
220 F: ExtensionField<FS>,
221{
222 #[inline]
223 pub fn packed(&mut self) -> &mut [PackedType<U, FS>] {
224 PackedType::<U, FS>::from_underliers_ref_mut(self.underliers())
225 }
226
227 #[inline]
228 pub fn as_mut_slice<T: Pod>(&mut self) -> &mut [T] {
229 must_cast_slice_mut(self.underliers())
230 }
231
232 #[inline]
233 const fn underliers(&mut self) -> &mut [U] {
234 self.data
235 .as_mut()
236 .expect("Should only be None after Drop::drop has run")
237 }
238}
239
240impl<FS> Drop for EntryBuilder<'_, FS>
241where
242 FS: TowerField,
243 U: PackScalar<FS>,
244 F: ExtensionField<FS>,
245{
246 fn drop(&mut self) {
247 let data = Option::take(&mut self.data).expect("data is always Some until this point");
248 let mut entries = self.entries.borrow_mut();
249 let id = self.id;
250 if id >= entries.len() {
251 entries.resize_with(id + 1, || None);
252 }
253 entries[id] = Some(WitnessBuilderEntry {
254 data,
255 tower_level: FS::TOWER_LEVEL,
256 witness: MultilinearExtension::new(
257 self.log_rows,
258 PackedType::<U, FS>::from_underliers_ref(data),
259 )
260 .map(|x| x.specialize_arc_dyn()),
261 })
262 }
263}