binius_circuits/builder/
witness.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cell::RefCell, marker::PhantomData, rc::Rc};
4
5use anyhow::{anyhow, Error};
6use binius_core::{
7	oracle::{MultilinearOracleSet, OracleId},
8	witness::{MultilinearExtensionIndex, MultilinearWitness},
9};
10use binius_field::{
11	as_packed_field::{PackScalar, PackedType},
12	underlier::WithUnderlier,
13	ExtensionField, PackedField, TowerField,
14};
15use binius_math::MultilinearExtension;
16use binius_utils::bail;
17use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod};
18
19use super::types::{F, U};
20
21pub struct Builder<'arena> {
22	bump: &'arena bumpalo::Bump,
23
24	oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
25
26	#[allow(clippy::type_complexity)]
27	entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
28}
29
30struct WitnessBuilderEntry<'arena> {
31	witness: Result<MultilinearWitness<'arena, PackedType<U, F>>, binius_math::Error>,
32	tower_level: usize,
33	nonzero_scalars_prefix: usize,
34	data: &'arena [U],
35}
36
37impl<'arena> Builder<'arena> {
38	pub fn new(
39		allocator: &'arena bumpalo::Bump,
40		oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
41	) -> Self {
42		Self {
43			bump: allocator,
44			oracles,
45			entries: Rc::new(RefCell::new(Vec::new())),
46		}
47	}
48
49	pub fn new_column<FS: TowerField>(&self, id: OracleId) -> EntryBuilder<'arena, FS>
50	where
51		U: PackScalar<FS>,
52		F: ExtensionField<FS>,
53	{
54		let nonzero_scalars_prefix = 1 << self.oracles.borrow().n_vars(id);
55		self.new_column_with_nonzero_scalars_prefix(id, nonzero_scalars_prefix)
56	}
57
58	pub fn new_column_with_nonzero_scalars_prefix<FS: TowerField>(
59		&self,
60		id: OracleId,
61		nonzero_scalars_prefix: usize,
62	) -> EntryBuilder<'arena, FS>
63	where
64		U: PackScalar<FS>,
65		F: ExtensionField<FS>,
66	{
67		let oracles = self.oracles.borrow();
68		let log_rows = oracles.n_vars(id);
69		// TODO: validate nonzero_scalars_prefix
70		let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
71		let data = bumpalo::vec![in self.bump; U::default(); len].into_bump_slice_mut();
72		EntryBuilder {
73			_marker: PhantomData,
74			entries: self.entries.clone(),
75			id,
76			log_rows,
77			nonzero_scalars_prefix,
78			data: Some(data),
79		}
80	}
81
82	pub fn new_column_with_default<FS: TowerField>(
83		&self,
84		id: OracleId,
85		default: FS,
86	) -> EntryBuilder<'arena, FS>
87	where
88		U: PackScalar<FS>,
89		F: ExtensionField<FS>,
90	{
91		let oracles = self.oracles.borrow();
92		let log_rows = oracles.n_vars(id);
93		let nonzero_scalars_prefix = 1 << log_rows;
94		let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
95		let default = WithUnderlier::to_underlier(PackedType::<U, FS>::broadcast(default));
96		let data = bumpalo::vec![in self.bump; default; len].into_bump_slice_mut();
97		EntryBuilder {
98			_marker: PhantomData,
99			entries: self.entries.clone(),
100			id,
101			log_rows,
102			nonzero_scalars_prefix,
103			data: Some(data),
104		}
105	}
106
107	pub fn get<FS>(&self, id: OracleId) -> Result<WitnessEntry<'arena, FS>, Error>
108	where
109		FS: TowerField,
110		U: PackScalar<FS>,
111		F: ExtensionField<FS>,
112	{
113		let entries = self.entries.borrow();
114		let oracles = self.oracles.borrow();
115		if !oracles.is_valid_oracle_id(id) {
116			bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
117		}
118		let entry = entries
119			.get(id)
120			.and_then(|entry| entry.as_ref())
121			.ok_or_else(|| anyhow!("Witness for {} is missing", oracles.label(id)))?;
122
123		if entry.tower_level != FS::TOWER_LEVEL {
124			bail!(anyhow!(
125				"Provided tower level ({}) for {} does not match stored tower level {}.",
126				FS::TOWER_LEVEL,
127				oracles.label(id),
128				entry.tower_level
129			));
130		}
131
132		Ok(WitnessEntry {
133			data: entry.data,
134			log_rows: oracles.n_vars(id),
135			nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
136			_marker: PhantomData,
137		})
138	}
139
140	pub fn set<FS: TowerField>(
141		&self,
142		id: OracleId,
143		entry: WitnessEntry<'arena, FS>,
144	) -> Result<(), Error>
145	where
146		U: PackScalar<FS>,
147		F: ExtensionField<FS>,
148	{
149		let oracles = self.oracles.borrow();
150		if !oracles.is_valid_oracle_id(id) {
151			bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
152		}
153		let mut entries = self.entries.borrow_mut();
154		if id >= entries.len() {
155			entries.resize_with(id + 1, || None);
156		}
157		entries[id] = Some(WitnessBuilderEntry {
158			data: entry.data,
159			nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
160			tower_level: FS::TOWER_LEVEL,
161			witness: MultilinearExtension::new(entry.log_rows, entry.packed())
162				.map(|x| x.specialize_arc_dyn()),
163		});
164		Ok(())
165	}
166
167	pub fn build(self) -> Result<MultilinearExtensionIndex<'arena, PackedType<U, F>>, Error> {
168		let mut result = MultilinearExtensionIndex::new();
169		let entries = Rc::into_inner(self.entries)
170			.ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
171			.into_inner()
172			.into_iter()
173			.enumerate()
174			.filter_map(|(id, entry)| entry.map(|entry| Ok((id, entry.witness?, entry.nonzero_scalars_prefix))))
175			.collect::<Result<Vec<_>, Error>>()?;
176		result.update_multilin_poly_with_nonzero_scalars_prefixes(entries)?;
177		Ok(result)
178	}
179}
180
181#[derive(Debug, Clone, Copy)]
182pub struct WitnessEntry<'arena, FS: TowerField>
183where
184	U: PackScalar<FS>,
185{
186	data: &'arena [U],
187	log_rows: usize,
188	nonzero_scalars_prefix: usize,
189	_marker: PhantomData<FS>,
190}
191
192impl<'arena, FS: TowerField> WitnessEntry<'arena, FS>
193where
194	U: PackScalar<FS>,
195{
196	#[inline]
197	pub fn packed(&self) -> &'arena [PackedType<U, FS>] {
198		WithUnderlier::from_underliers_ref(self.data)
199	}
200
201	#[inline]
202	pub const fn as_slice<T: Pod>(&self) -> &'arena [T] {
203		must_cast_slice(self.data)
204	}
205
206	pub const fn repacked<FE>(&self) -> WitnessEntry<'arena, FE>
207	where
208		FE: TowerField + ExtensionField<FS>,
209		U: PackScalar<FE>,
210	{
211		let log_extension_degree = <FE as ExtensionField<FS>>::LOG_DEGREE;
212		WitnessEntry {
213			data: self.data,
214			log_rows: self.log_rows - log_extension_degree,
215			nonzero_scalars_prefix: self
216				.nonzero_scalars_prefix
217				.div_ceil(1 << log_extension_degree),
218			_marker: PhantomData,
219		}
220	}
221
222	pub const fn low_rows(&self) -> usize {
223		self.log_rows
224	}
225}
226
227pub struct EntryBuilder<'arena, FS>
228where
229	FS: TowerField,
230	U: PackScalar<FS>,
231	F: ExtensionField<FS>,
232{
233	_marker: PhantomData<FS>,
234	#[allow(clippy::type_complexity)]
235	entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
236	id: OracleId,
237	log_rows: usize,
238	nonzero_scalars_prefix: usize,
239	data: Option<&'arena mut [U]>,
240}
241
242impl<FS> EntryBuilder<'_, FS>
243where
244	FS: TowerField,
245	U: PackScalar<FS>,
246	F: ExtensionField<FS>,
247{
248	#[inline]
249	pub fn packed(&mut self) -> &mut [PackedType<U, FS>] {
250		PackedType::<U, FS>::from_underliers_ref_mut(self.underliers())
251	}
252
253	#[inline]
254	pub fn as_mut_slice<T: Pod>(&mut self) -> &mut [T] {
255		must_cast_slice_mut(self.underliers())
256	}
257
258	#[inline]
259	const fn underliers(&mut self) -> &mut [U] {
260		self.data
261			.as_mut()
262			.expect("Should only be None after Drop::drop has run")
263	}
264}
265
266impl<FS> Drop for EntryBuilder<'_, FS>
267where
268	FS: TowerField,
269	U: PackScalar<FS>,
270	F: ExtensionField<FS>,
271{
272	fn drop(&mut self) {
273		let data = Option::take(&mut self.data).expect("data is always Some until this point");
274		let mut entries = self.entries.borrow_mut();
275		let id = self.id;
276		let nonzero_scalars_prefix = self.nonzero_scalars_prefix;
277		if id >= entries.len() {
278			entries.resize_with(id + 1, || None);
279		}
280		entries[id] = Some(WitnessBuilderEntry {
281			data,
282			nonzero_scalars_prefix,
283			tower_level: FS::TOWER_LEVEL,
284			witness: MultilinearExtension::new(
285				self.log_rows,
286				PackedType::<U, FS>::from_underliers_ref(data),
287			)
288			.map(|x| x.specialize_arc_dyn()),
289		})
290	}
291}