binius_circuits/builder/
witness.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cell::RefCell, marker::PhantomData, rc::Rc};
4
5use anyhow::{anyhow, Error};
6use binius_core::{
7	oracle::{MultilinearOracleSet, OracleId},
8	witness::{MultilinearExtensionIndex, MultilinearWitness},
9};
10use binius_field::{
11	as_packed_field::{PackScalar, PackedType},
12	underlier::WithUnderlier,
13	ExtensionField, PackedField, TowerField,
14};
15use binius_math::MultilinearExtension;
16use binius_utils::bail;
17use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod};
18
19use super::types::{F, U};
20
21pub struct Builder<'arena> {
22	bump: &'arena bumpalo::Bump,
23
24	oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
25
26	#[allow(clippy::type_complexity)]
27	entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
28}
29
30struct WitnessBuilderEntry<'arena> {
31	witness: Result<MultilinearWitness<'arena, PackedType<U, F>>, binius_math::Error>,
32	tower_level: usize,
33	nonzero_scalars_prefix: usize,
34	data: &'arena [U],
35}
36
37impl<'arena> Builder<'arena> {
38	pub fn new(
39		allocator: &'arena bumpalo::Bump,
40		oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
41	) -> Self {
42		Self {
43			bump: allocator,
44			oracles,
45			entries: Rc::new(RefCell::new(Vec::new())),
46		}
47	}
48
49	pub fn new_column<FS: TowerField>(&self, id: OracleId) -> EntryBuilder<'arena, FS>
50	where
51		U: PackScalar<FS>,
52		F: ExtensionField<FS>,
53	{
54		let nonzero_scalars_prefix = 1 << self.oracles.borrow().n_vars(id);
55		self.new_column_with_nonzero_scalars_prefix(id, nonzero_scalars_prefix)
56	}
57
58	pub fn new_column_with_nonzero_scalars_prefix<FS: TowerField>(
59		&self,
60		id: OracleId,
61		nonzero_scalars_prefix: usize,
62	) -> EntryBuilder<'arena, FS>
63	where
64		U: PackScalar<FS>,
65		F: ExtensionField<FS>,
66	{
67		let oracles = self.oracles.borrow();
68		let log_rows = oracles.n_vars(id);
69		// TODO: validate nonzero_scalars_prefix
70		let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
71		let data = bumpalo::vec![in self.bump; U::default(); len].into_bump_slice_mut();
72		EntryBuilder {
73			_marker: PhantomData,
74			entries: self.entries.clone(),
75			id,
76			log_rows,
77			nonzero_scalars_prefix,
78			data: Some(data),
79		}
80	}
81
82	pub fn new_column_with_default<FS: TowerField>(
83		&self,
84		id: OracleId,
85		default: FS,
86	) -> EntryBuilder<'arena, FS>
87	where
88		U: PackScalar<FS>,
89		F: ExtensionField<FS>,
90	{
91		let oracles = self.oracles.borrow();
92		let log_rows = oracles.n_vars(id);
93		let nonzero_scalars_prefix = 1 << log_rows;
94		let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
95		let default = WithUnderlier::to_underlier(PackedType::<U, FS>::broadcast(default));
96		let data = bumpalo::vec![in self.bump; default; len].into_bump_slice_mut();
97		EntryBuilder {
98			_marker: PhantomData,
99			entries: self.entries.clone(),
100			id,
101			log_rows,
102			nonzero_scalars_prefix,
103			data: Some(data),
104		}
105	}
106
107	pub fn get<FS>(&self, id: OracleId) -> Result<WitnessEntry<'arena, FS>, Error>
108	where
109		FS: TowerField,
110		U: PackScalar<FS>,
111		F: ExtensionField<FS>,
112	{
113		let entries = self.entries.borrow();
114		let oracles = self.oracles.borrow();
115		if !oracles.is_valid_oracle_id(id) {
116			bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
117		}
118		let entry = entries
119			.get(id.index())
120			.and_then(|entry| entry.as_ref())
121			.ok_or_else(|| anyhow!("Witness for {} is missing", oracles.label(id)))?;
122
123		if entry.tower_level != FS::TOWER_LEVEL {
124			bail!(anyhow!(
125				"Provided tower level ({}) for {} does not match stored tower level {}.",
126				FS::TOWER_LEVEL,
127				oracles.label(id),
128				entry.tower_level
129			));
130		}
131
132		Ok(WitnessEntry {
133			data: entry.data,
134			log_rows: oracles.n_vars(id),
135			nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
136			_marker: PhantomData,
137		})
138	}
139
140	pub fn set<FS: TowerField>(
141		&self,
142		oracle_id: OracleId,
143		entry: WitnessEntry<'arena, FS>,
144	) -> Result<(), Error>
145	where
146		U: PackScalar<FS>,
147		F: ExtensionField<FS>,
148	{
149		let oracles = self.oracles.borrow();
150		if !oracles.is_valid_oracle_id(oracle_id) {
151			bail!(anyhow!("OracleId {oracle_id} does not exist in MultilinearOracleSet"));
152		}
153		let mut entries = self.entries.borrow_mut();
154		let oracle_index = oracle_id.index();
155		if oracle_index >= entries.len() {
156			entries.resize_with(oracle_index + 1, || None);
157		}
158		entries[oracle_index] = Some(WitnessBuilderEntry {
159			data: entry.data,
160			nonzero_scalars_prefix: entry.nonzero_scalars_prefix,
161			tower_level: FS::TOWER_LEVEL,
162			witness: MultilinearExtension::new(entry.log_rows, entry.packed())
163				.map(|x| x.specialize_arc_dyn()),
164		});
165		Ok(())
166	}
167
168	pub fn build(self) -> Result<MultilinearExtensionIndex<'arena, PackedType<U, F>>, Error> {
169		let mut result = MultilinearExtensionIndex::new();
170		let entries = Rc::into_inner(self.entries)
171			.ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
172			.into_inner()
173			.into_iter()
174			.enumerate()
175			.filter_map(|(index, entry)| entry.map(|entry| Ok((OracleId::from_index(index), entry.witness?, entry.nonzero_scalars_prefix))))
176			.collect::<Result<Vec<_>, Error>>()?;
177		result.update_multilin_poly_with_nonzero_scalars_prefixes(entries)?;
178		Ok(result)
179	}
180}
181
182#[derive(Debug, Clone, Copy)]
183pub struct WitnessEntry<'arena, FS: TowerField>
184where
185	U: PackScalar<FS>,
186{
187	data: &'arena [U],
188	log_rows: usize,
189	nonzero_scalars_prefix: usize,
190	_marker: PhantomData<FS>,
191}
192
193impl<'arena, FS: TowerField> WitnessEntry<'arena, FS>
194where
195	U: PackScalar<FS>,
196{
197	#[inline]
198	pub fn packed(&self) -> &'arena [PackedType<U, FS>] {
199		WithUnderlier::from_underliers_ref(self.data)
200	}
201
202	#[inline]
203	pub const fn as_slice<T: Pod>(&self) -> &'arena [T] {
204		must_cast_slice(self.data)
205	}
206
207	pub const fn repacked<FE>(&self) -> WitnessEntry<'arena, FE>
208	where
209		FE: TowerField + ExtensionField<FS>,
210		U: PackScalar<FE>,
211	{
212		let log_extension_degree = <FE as ExtensionField<FS>>::LOG_DEGREE;
213		WitnessEntry {
214			data: self.data,
215			log_rows: self.log_rows - log_extension_degree,
216			nonzero_scalars_prefix: self
217				.nonzero_scalars_prefix
218				.div_ceil(1 << log_extension_degree),
219			_marker: PhantomData,
220		}
221	}
222
223	pub const fn low_rows(&self) -> usize {
224		self.log_rows
225	}
226}
227
228pub struct EntryBuilder<'arena, FS>
229where
230	FS: TowerField,
231	U: PackScalar<FS>,
232	F: ExtensionField<FS>,
233{
234	_marker: PhantomData<FS>,
235	#[allow(clippy::type_complexity)]
236	entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
237	id: OracleId,
238	log_rows: usize,
239	nonzero_scalars_prefix: usize,
240	data: Option<&'arena mut [U]>,
241}
242
243impl<FS> EntryBuilder<'_, FS>
244where
245	FS: TowerField,
246	U: PackScalar<FS>,
247	F: ExtensionField<FS>,
248{
249	#[inline]
250	pub fn packed(&mut self) -> &mut [PackedType<U, FS>] {
251		PackedType::<U, FS>::from_underliers_ref_mut(self.underliers())
252	}
253
254	#[inline]
255	pub fn as_mut_slice<T: Pod>(&mut self) -> &mut [T] {
256		must_cast_slice_mut(self.underliers())
257	}
258
259	#[inline]
260	const fn underliers(&mut self) -> &mut [U] {
261		self.data
262			.as_mut()
263			.expect("Should only be None after Drop::drop has run")
264	}
265}
266
267impl<FS> Drop for EntryBuilder<'_, FS>
268where
269	FS: TowerField,
270	U: PackScalar<FS>,
271	F: ExtensionField<FS>,
272{
273	fn drop(&mut self) {
274		let data = Option::take(&mut self.data).expect("data is always Some until this point");
275		let mut entries = self.entries.borrow_mut();
276		let oracle_index = self.id.index();
277		let nonzero_scalars_prefix = self.nonzero_scalars_prefix;
278		if oracle_index >= entries.len() {
279			entries.resize_with(oracle_index + 1, || None);
280		}
281		entries[oracle_index] = Some(WitnessBuilderEntry {
282			data,
283			nonzero_scalars_prefix,
284			tower_level: FS::TOWER_LEVEL,
285			witness: MultilinearExtension::new(
286				self.log_rows,
287				PackedType::<U, FS>::from_underliers_ref(data),
288			)
289			.map(|x| x.specialize_arc_dyn()),
290		})
291	}
292}