binius_core/piop/
commit.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use binius_field::{
4	as_packed_field::{PackScalar, PackedType},
5	underlier::UnderlierType,
6	TowerField,
7};
8use binius_utils::sparse_index::SparseIndex;
9
10use super::{error::Error, util::ResizeableIndex, verify::CommitMeta};
11use crate::{
12	oracle::{MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant},
13	witness::{MultilinearExtensionIndex, MultilinearWitness},
14};
15
16/// Indexes the committed oracles in a [`MultilinearOracleSet`] and returns:
17///
18/// 1. a [`CommitMeta`] struct that stores information about the committed polynomials
19/// 2. a sparse index mapping oracle IDs to committed IDs in the commit metadata
20pub fn make_oracle_commit_meta<F: TowerField>(
21	oracles: &MultilinearOracleSet<F>,
22) -> Result<(CommitMeta, SparseIndex<usize>), Error> {
23	// We need to construct two structures:
24	//
25	// 1) the commit metadata structure, which depends on the counts of the number of multilinears
26	//    per number of packed variables
27	// 2) a sparse index mapping oracle IDs to IDs in the commit metadata
28	//
29	// We will construct the two indices in two passes. On the first pass, we count the number of
30	// multilinears and assign for each oracle the index of the oracle in the bucket of oracles
31	// with the same number of packed variables. On the second pass, the commit metadata is
32	// finalized, so we can determine the absolute indices into the commit metadata structure by
33	// adding offsets.
34
35	#[derive(Clone)]
36	struct CommitIDFirstPass {
37		n_packed_vars: usize,
38		idx_in_bucket: usize,
39	}
40
41	// First pass: count the number of multilinears and index within buckets
42	let mut first_pass_index = SparseIndex::with_capacity(oracles.size());
43	let mut n_multilins_by_vars = ResizeableIndex::<usize>::new();
44	for oracle in oracles.iter() {
45		if matches!(oracle.variant, MultilinearPolyVariant::Committed) {
46			let n_packed_vars = n_packed_vars_for_committed_oracle(&oracle)?;
47			let n_multilins_for_vars = n_multilins_by_vars.get_mut(n_packed_vars);
48
49			first_pass_index.set(
50				oracle.id(),
51				CommitIDFirstPass {
52					n_packed_vars,
53					idx_in_bucket: *n_multilins_for_vars,
54				},
55			);
56			*n_multilins_for_vars += 1;
57		}
58	}
59
60	let commit_meta = CommitMeta::new(n_multilins_by_vars.into_vec());
61
62	// Second pass: use commit_meta counts to finalized indices with offsets
63	let mut index = SparseIndex::with_capacity(oracles.size());
64	for id in 0..oracles.size() {
65		if let Some(CommitIDFirstPass {
66			n_packed_vars,
67			idx_in_bucket,
68		}) = first_pass_index.get(id)
69		{
70			let offset = commit_meta.range_by_vars(*n_packed_vars).start;
71			index.set(id, offset + *idx_in_bucket);
72		}
73	}
74
75	Ok((commit_meta, index))
76}
77
78/// Collects the committed multilinear witnesses from the witness index and returns them in order.
79///
80/// During the commitment phase of the protocol, the trace polynomials are committed in a specific
81/// order recorded by the commit metadata. This collects the witnesses corresponding to committed
82/// multilinears and returns a vector of them in the commitment order.
83///
84/// ## Preconditions
85///
86/// * `oracle_to_commit_index` must be correctly constructed. Specifically, it must be surjective,
87///   mapping at exactly one oracle to every index up to the number of committed multilinears.
88pub fn collect_committed_witnesses<'a, U, F>(
89	commit_meta: &CommitMeta,
90	oracle_to_commit_index: &SparseIndex<usize>,
91	oracles: &MultilinearOracleSet<F>,
92	witness_index: &MultilinearExtensionIndex<'a, U, F>,
93) -> Result<Vec<MultilinearWitness<'a, PackedType<U, F>>>, Error>
94where
95	U: UnderlierType + PackScalar<F>,
96	F: TowerField,
97{
98	let mut witnesses = vec![None; commit_meta.total_multilins()];
99	for oracle_id in 0..oracles.size() {
100		if let Some(commit_idx) = oracle_to_commit_index.get(oracle_id) {
101			witnesses[*commit_idx] = Some(witness_index.get_multilin_poly(oracle_id)?);
102		}
103	}
104	Ok(witnesses
105		.into_iter()
106		.map(|witness| witness.expect("pre-condition: oracle_to_commit index is surjective"))
107		.collect())
108}
109
110fn n_packed_vars_for_committed_oracle<F: TowerField>(
111	oracle: &MultilinearPolyOracle<F>,
112) -> Result<usize, Error> {
113	let n_vars = oracle.n_vars();
114	let tower_level = oracle.binary_tower_level();
115	n_vars
116		.checked_sub(F::TOWER_LEVEL - tower_level)
117		.ok_or_else(|| Error::OracleTooSmall {
118			id: oracle.id(),
119			min_vars: F::TOWER_LEVEL - tower_level,
120		})
121}
122
#[cfg(test)]
mod tests {
	use binius_field::BinaryField128b;

	use super::*;

	#[test]
	fn test_make_oracle_commit_meta() {
		let mut oracles = MultilinearOracleSet::<BinaryField128b>::new();

		// Committed batches are named batch_{tower_level}_{k}. With this field, tower level 0
		// costs 7 packed variables and tower level 2 costs 5. So the batches land in the buckets
		// for 1, 3, 5 (level 0) and 3, 5, 7 (level 2) packed variables.
		let batch_0_0_ids = oracles.add_committed_multiple::<2>(8, 0);
		let batch_0_1_ids = oracles.add_committed_multiple::<2>(10, 0);
		let batch_0_2_ids = oracles.add_committed_multiple::<2>(12, 0);

		// A derived (non-committed) oracle, which must not receive a commit index.
		let repeat = oracles.add_repeating(batch_0_2_ids[0], 5).unwrap();

		let batch_2_0_ids = oracles.add_committed_multiple::<2>(8, 2);
		let batch_2_1_ids = oracles.add_committed_multiple::<2>(10, 2);
		let batch_2_2_ids = oracles.add_committed_multiple::<2>(12, 2);

		let (commit_meta, index) = make_oracle_commit_meta(&oracles).unwrap();
		assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 2, 0, 4, 0, 4, 0, 2]);

		// (oracle id, expected commit index)
		let expectations = [
			(batch_0_0_ids[0], Some(0)),
			(batch_0_0_ids[1], Some(1)),
			(batch_0_1_ids[0], Some(2)),
			(batch_0_1_ids[1], Some(3)),
			(batch_0_2_ids[0], Some(6)),
			(batch_0_2_ids[1], Some(7)),
			(batch_2_0_ids[0], Some(4)),
			(batch_2_0_ids[1], Some(5)),
			(batch_2_1_ids[0], Some(8)),
			(batch_2_1_ids[1], Some(9)),
			(batch_2_2_ids[0], Some(10)),
			(batch_2_2_ids[1], Some(11)),
			(repeat, None),
		];
		for (i, &(oracle_id, expected)) in expectations.iter().enumerate() {
			assert_eq!(index.get(oracle_id).copied(), expected, "expectation #{}", i);
		}
	}
}