binius_circuits/builder/
witness.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{cell::RefCell, marker::PhantomData, rc::Rc};
4
5use anyhow::{anyhow, Error};
6use binius_core::{
7	oracle::{MultilinearOracleSet, OracleId},
8	witness::{MultilinearExtensionIndex, MultilinearWitness},
9};
10use binius_field::{
11	as_packed_field::{PackScalar, PackedType},
12	underlier::WithUnderlier,
13	ExtensionField, PackedField, TowerField,
14};
15use binius_math::MultilinearExtension;
16use binius_utils::bail;
17use bytemuck::{must_cast_slice, must_cast_slice_mut, Pod};
18
19use super::types::{F, U};
20
21pub struct Builder<'arena> {
22	bump: &'arena bumpalo::Bump,
23
24	oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
25
26	#[allow(clippy::type_complexity)]
27	entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
28}
29
30struct WitnessBuilderEntry<'arena> {
31	witness: Result<MultilinearWitness<'arena, PackedType<U, F>>, binius_math::Error>,
32	tower_level: usize,
33	data: &'arena [U],
34}
35
36impl<'arena> Builder<'arena> {
37	pub fn new(
38		allocator: &'arena bumpalo::Bump,
39		oracles: Rc<RefCell<MultilinearOracleSet<F>>>,
40	) -> Self {
41		Self {
42			bump: allocator,
43			oracles,
44			entries: Rc::new(RefCell::new(Vec::new())),
45		}
46	}
47
48	pub fn new_column<FS: TowerField>(&self, id: OracleId) -> EntryBuilder<'arena, FS>
49	where
50		U: PackScalar<FS>,
51		F: ExtensionField<FS>,
52	{
53		let oracles = self.oracles.borrow();
54		let log_rows = oracles.n_vars(id);
55		let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
56		let data = bumpalo::vec![in self.bump; U::default(); len].into_bump_slice_mut();
57		EntryBuilder {
58			_marker: PhantomData,
59			log_rows,
60			id,
61			data: Some(data),
62			entries: self.entries.clone(),
63		}
64	}
65
66	pub fn new_column_with_default<FS: TowerField>(
67		&self,
68		id: OracleId,
69		default: FS,
70	) -> EntryBuilder<'arena, FS>
71	where
72		U: PackScalar<FS>,
73		F: ExtensionField<FS>,
74	{
75		let oracles = self.oracles.borrow();
76		let log_rows = oracles.n_vars(id);
77		let len = 1 << log_rows.saturating_sub(<PackedType<U, FS>>::LOG_WIDTH);
78		let default = WithUnderlier::to_underlier(PackedType::<U, FS>::broadcast(default));
79		let data = bumpalo::vec![in self.bump; default; len].into_bump_slice_mut();
80		EntryBuilder {
81			_marker: PhantomData,
82			log_rows,
83			id,
84			data: Some(data),
85			entries: self.entries.clone(),
86		}
87	}
88
89	pub fn get<FS>(&self, id: OracleId) -> Result<WitnessEntry<'arena, FS>, Error>
90	where
91		FS: TowerField,
92		U: PackScalar<FS>,
93		F: ExtensionField<FS>,
94	{
95		let entries = self.entries.borrow();
96		let oracles = self.oracles.borrow();
97		if !oracles.is_valid_oracle_id(id) {
98			bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
99		}
100		let entry = entries
101			.get(id)
102			.and_then(|entry| entry.as_ref())
103			.ok_or_else(|| anyhow!("Witness for {} is missing", oracles.label(id)))?;
104
105		if entry.tower_level != FS::TOWER_LEVEL {
106			bail!(anyhow!(
107				"Provided tower level ({}) for {} does not match stored tower level {}.",
108				FS::TOWER_LEVEL,
109				oracles.label(id),
110				entry.tower_level
111			));
112		}
113
114		Ok(WitnessEntry {
115			data: entry.data,
116			log_rows: oracles.n_vars(id),
117			_marker: PhantomData,
118		})
119	}
120
121	pub fn set<FS: TowerField>(
122		&self,
123		id: OracleId,
124		entry: WitnessEntry<'arena, FS>,
125	) -> Result<(), Error>
126	where
127		U: PackScalar<FS>,
128		F: ExtensionField<FS>,
129	{
130		let oracles = self.oracles.borrow();
131		if !oracles.is_valid_oracle_id(id) {
132			bail!(anyhow!("OracleId {id} does not exist in MultilinearOracleSet"));
133		}
134		let mut entries = self.entries.borrow_mut();
135		if id >= entries.len() {
136			entries.resize_with(id + 1, || None);
137		}
138		entries[id] = Some(WitnessBuilderEntry {
139			data: entry.data,
140			tower_level: FS::TOWER_LEVEL,
141			witness: MultilinearExtension::new(entry.log_rows, entry.packed())
142				.map(|x| x.specialize_arc_dyn()),
143		});
144		Ok(())
145	}
146
147	pub fn build(self) -> Result<MultilinearExtensionIndex<'arena, U, F>, Error> {
148		let mut result = MultilinearExtensionIndex::new();
149		let entries = Rc::into_inner(self.entries)
150			.ok_or_else(|| anyhow!("Failed to build. There are still entries refs. Make sure there are no pending column insertions."))?
151			.into_inner()
152			.into_iter()
153			.enumerate()
154			.filter_map(|(id, entry)| entry.map(|entry| Ok((id, entry.witness?))))
155			.collect::<Result<Vec<_>, Error>>()?;
156		result.update_multilin_poly(entries)?;
157		Ok(result)
158	}
159}
160
161#[derive(Debug, Clone, Copy)]
162pub struct WitnessEntry<'arena, FS: TowerField>
163where
164	U: PackScalar<FS>,
165{
166	data: &'arena [U],
167	log_rows: usize,
168	_marker: PhantomData<FS>,
169}
170
171impl<'arena, FS: TowerField> WitnessEntry<'arena, FS>
172where
173	U: PackScalar<FS>,
174{
175	#[inline]
176	pub fn packed(&self) -> &'arena [PackedType<U, FS>] {
177		WithUnderlier::from_underliers_ref(self.data)
178	}
179
180	#[inline]
181	pub const fn as_slice<T: Pod>(&self) -> &'arena [T] {
182		must_cast_slice(self.data)
183	}
184
185	pub const fn repacked<FE>(&self) -> WitnessEntry<'arena, FE>
186	where
187		FE: TowerField + ExtensionField<FS>,
188		U: PackScalar<FE>,
189	{
190		WitnessEntry {
191			data: self.data,
192			log_rows: self.log_rows - <FE as ExtensionField<FS>>::LOG_DEGREE,
193			_marker: PhantomData,
194		}
195	}
196
197	pub const fn low_rows(&self) -> usize {
198		self.log_rows
199	}
200}
201
202pub struct EntryBuilder<'arena, FS>
203where
204	FS: TowerField,
205	U: PackScalar<FS>,
206	F: ExtensionField<FS>,
207{
208	_marker: PhantomData<FS>,
209	#[allow(clippy::type_complexity)]
210	entries: Rc<RefCell<Vec<Option<WitnessBuilderEntry<'arena>>>>>,
211	id: OracleId,
212	log_rows: usize,
213	data: Option<&'arena mut [U]>,
214}
215
216impl<FS> EntryBuilder<'_, FS>
217where
218	FS: TowerField,
219	U: PackScalar<FS>,
220	F: ExtensionField<FS>,
221{
222	#[inline]
223	pub fn packed(&mut self) -> &mut [PackedType<U, FS>] {
224		PackedType::<U, FS>::from_underliers_ref_mut(self.underliers())
225	}
226
227	#[inline]
228	pub fn as_mut_slice<T: Pod>(&mut self) -> &mut [T] {
229		must_cast_slice_mut(self.underliers())
230	}
231
232	#[inline]
233	const fn underliers(&mut self) -> &mut [U] {
234		self.data
235			.as_mut()
236			.expect("Should only be None after Drop::drop has run")
237	}
238}
239
240impl<FS> Drop for EntryBuilder<'_, FS>
241where
242	FS: TowerField,
243	U: PackScalar<FS>,
244	F: ExtensionField<FS>,
245{
246	fn drop(&mut self) {
247		let data = Option::take(&mut self.data).expect("data is always Some until this point");
248		let mut entries = self.entries.borrow_mut();
249		let id = self.id;
250		if id >= entries.len() {
251			entries.resize_with(id + 1, || None);
252		}
253		entries[id] = Some(WitnessBuilderEntry {
254			data,
255			tower_level: FS::TOWER_LEVEL,
256			witness: MultilinearExtension::new(
257				self.log_rows,
258				PackedType::<U, FS>::from_underliers_ref(data),
259			)
260			.map(|x| x.specialize_arc_dyn()),
261		})
262	}
263}