binius_core/piop/
commit.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
// Copyright 2024-2025 Irreducible Inc.

use binius_field::{
	as_packed_field::{PackScalar, PackedType},
	underlier::UnderlierType,
	TowerField,
};
use binius_utils::sparse_index::SparseIndex;

use super::{error::Error, util::ResizeableIndex, verify::CommitMeta};
use crate::{
	oracle::{MultilinearOracleSet, MultilinearPolyOracle},
	witness::{MultilinearExtensionIndex, MultilinearWitness},
};

/// Builds the commitment layout for the committed oracles of a [`MultilinearOracleSet`].
///
/// Returns a pair consisting of:
///
/// 1. a [`CommitMeta`] constructed from the number of committed multilinears per number of
///    packed variables
/// 2. a sparse index mapping the oracle ID of each committed oracle to its ID in the commit
///    metadata; oracles that are not committed get no entry
///
/// ## Errors
///
/// Returns [`Error::OracleTooSmall`] if a committed oracle has fewer variables than
/// `F::TOWER_LEVEL` minus its binary tower level.
pub fn make_oracle_commit_meta<F: TowerField>(
	oracles: &MultilinearOracleSet<F>,
) -> Result<(CommitMeta, SparseIndex<usize>), Error> {
	// The committed ID of an oracle is the start of its bucket (the group of oracles sharing the
	// same number of packed variables) plus its position inside that bucket. Bucket starts are
	// only available once every bucket size is known and the commit metadata has been built, so
	// the work is split into two passes:
	//
	// * pass 1 records, per committed oracle, its bucket and its position within the bucket,
	//   while tallying the bucket sizes;
	// * pass 2 turns those bucket-relative positions into absolute ones using the offsets
	//   reported by the finished commit metadata.

	/// Bucket-relative position of a committed oracle, as known after the first pass.
	#[derive(Clone)]
	struct BucketSlot {
		n_packed_vars: usize,
		idx_in_bucket: usize,
	}

	// Pass 1: tally bucket sizes and assign each committed oracle the next slot in its bucket.
	let mut bucket_slots = SparseIndex::new(oracles.size());
	let mut bucket_sizes = ResizeableIndex::<usize>::new();
	for oracle in oracles.iter() {
		let MultilinearPolyOracle::Committed { oracle_id, .. } = &oracle else {
			continue;
		};

		let n_packed_vars = n_packed_vars_for_committed_oracle(&oracle)?;

		// The current bucket size is the slot of this oracle; claim it and grow the bucket.
		let bucket_size = bucket_sizes.get_mut(n_packed_vars);
		let idx_in_bucket = *bucket_size;
		*bucket_size += 1;

		bucket_slots.set(
			*oracle_id,
			BucketSlot {
				n_packed_vars,
				idx_in_bucket,
			},
		);
	}

	let commit_meta = CommitMeta::new(bucket_sizes.into_vec());

	// Pass 2: use the commit_meta ranges to finalize the indices by adding the bucket offsets.
	let mut index = SparseIndex::new(oracles.size());
	for id in 0..oracles.size() {
		let Some(slot) = bucket_slots.get(id) else {
			continue;
		};

		let bucket_start = commit_meta.range_by_vars(slot.n_packed_vars).start;
		index.set(id, bucket_start + slot.idx_in_bucket);
	}

	Ok((commit_meta, index))
}

/// Gathers the witnesses of all committed multilinears, arranged in commitment order.
///
/// During the commitment phase of the protocol, the trace polynomials are committed in a specific
/// order recorded by the commit metadata. For every oracle that has an entry in
/// `oracle_to_commit_index`, this looks up its witness in `witness_index` and places it at the
/// position given by that entry.
///
/// ## Preconditions
///
/// * `oracle_to_commit_index` must be correctly constructed. Specifically, it must be surjective,
///   mapping exactly one oracle to every index up to the number of committed multilinears.
///
/// ## Errors
///
/// Propagates the error returned by the witness index when the witness of a committed oracle
/// cannot be retrieved.
///
/// ## Panics
///
/// Panics if the precondition is violated: either some commit index is at or beyond
/// `commit_meta.total_multilins()`, or some commit index below that bound has no oracle mapped
/// to it.
pub fn collect_committed_witnesses<'a, U, F>(
	commit_meta: &CommitMeta,
	oracle_to_commit_index: &SparseIndex<usize>,
	oracles: &MultilinearOracleSet<F>,
	witness_index: &MultilinearExtensionIndex<'a, U, F>,
) -> Result<Vec<MultilinearWitness<'a, PackedType<U, F>>>, Error>
where
	U: UnderlierType + PackScalar<F>,
	F: TowerField,
{
	// One slot per committed multilinear; each is filled in as its oracle is encountered.
	let mut slots = vec![None; commit_meta.total_multilins()];

	for oracle_id in 0..oracles.size() {
		// Oracles without an entry are not committed and contribute nothing.
		let Some(&commit_idx) = oracle_to_commit_index.get(oracle_id) else {
			continue;
		};

		let witness = witness_index.get_multilin_poly(oracle_id)?;
		slots[commit_idx] = Some(witness);
	}

	// A slot left empty means no oracle mapped to that commit index.
	let ordered: Vec<_> = slots
		.into_iter()
		.map(|slot| slot.expect("pre-condition: oracle_to_commit index is surjective"))
		.collect();
	Ok(ordered)
}

fn n_packed_vars_for_committed_oracle<F: TowerField>(
	oracle: &MultilinearPolyOracle<F>,
) -> Result<usize, Error> {
	let n_vars = oracle.n_vars();
	let tower_level = oracle.binary_tower_level();
	n_vars
		.checked_sub(F::TOWER_LEVEL - tower_level)
		.ok_or_else(|| Error::OracleTooSmall {
			id: oracle.id(),
			min_vars: F::TOWER_LEVEL - tower_level,
		})
}

#[cfg(test)]
mod tests {
	use binius_field::BinaryField128b;

	use super::*;

	/// Checks the bucket counts and the oracle-to-commit mapping for committed oracles of several
	/// sizes and tower levels, with a non-committed repeating oracle inserted in between.
	#[test]
	fn test_make_oracle_commit_meta() {
		let mut oracles = MultilinearOracleSet::<BinaryField128b>::new();

		// Committed oracles at tower level 0 with 8, 10 and 12 variables.
		let level_0_vars_8 = oracles.add_committed_multiple::<2>(8, 0);
		let level_0_vars_10 = oracles.add_committed_multiple::<2>(10, 0);
		let level_0_vars_12 = oracles.add_committed_multiple::<2>(12, 0);

		// A repeating oracle is not committed, so it must not receive a commit index.
		let repeat = oracles.add_repeating(level_0_vars_12[0], 5).unwrap();

		// Committed oracles at tower level 2 with 8, 10 and 12 variables.
		let level_2_vars_8 = oracles.add_committed_multiple::<2>(8, 2);
		let level_2_vars_10 = oracles.add_committed_multiple::<2>(10, 2);
		let level_2_vars_12 = oracles.add_committed_multiple::<2>(12, 2);

		let (commit_meta, index) = make_oracle_commit_meta(&oracles).unwrap();
		assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 2, 0, 4, 0, 4, 0, 2]);

		// Expected commit index for every oracle added above.
		let expected: [(_, Option<usize>); 13] = [
			(level_0_vars_8[0], Some(0)),
			(level_0_vars_8[1], Some(1)),
			(level_0_vars_10[0], Some(2)),
			(level_0_vars_10[1], Some(3)),
			(level_0_vars_12[0], Some(6)),
			(level_0_vars_12[1], Some(7)),
			(level_2_vars_8[0], Some(4)),
			(level_2_vars_8[1], Some(5)),
			(level_2_vars_10[0], Some(8)),
			(level_2_vars_10[1], Some(9)),
			(level_2_vars_12[0], Some(10)),
			(level_2_vars_12[1], Some(11)),
			(repeat, None),
		];
		for (oracle_id, commit_idx) in expected {
			assert_eq!(
				index.get(oracle_id).cloned(),
				commit_idx,
				"unexpected commit index for oracle {:?}",
				oracle_id
			);
		}
	}
}