binius_core/piop/
commit.rs1use binius_field::{
4 as_packed_field::{PackScalar, PackedType},
5 underlier::UnderlierType,
6 TowerField,
7};
8use binius_utils::sparse_index::SparseIndex;
9
10use super::{error::Error, util::ResizeableIndex, verify::CommitMeta};
11use crate::{
12 oracle::{MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant},
13 witness::{MultilinearExtensionIndex, MultilinearWitness},
14};
15
16pub fn make_oracle_commit_meta<F: TowerField>(
21 oracles: &MultilinearOracleSet<F>,
22) -> Result<(CommitMeta, SparseIndex<usize>), Error> {
23 #[derive(Clone)]
36 struct CommitIDFirstPass {
37 n_packed_vars: usize,
38 idx_in_bucket: usize,
39 }
40
41 let mut first_pass_index = SparseIndex::with_capacity(oracles.size());
43 let mut n_multilins_by_vars = ResizeableIndex::<usize>::new();
44 for oracle in oracles.iter() {
45 if matches!(oracle.variant, MultilinearPolyVariant::Committed) {
46 let n_packed_vars = n_packed_vars_for_committed_oracle(&oracle)?;
47 let n_multilins_for_vars = n_multilins_by_vars.get_mut(n_packed_vars);
48
49 first_pass_index.set(
50 oracle.id(),
51 CommitIDFirstPass {
52 n_packed_vars,
53 idx_in_bucket: *n_multilins_for_vars,
54 },
55 );
56 *n_multilins_for_vars += 1;
57 }
58 }
59
60 let commit_meta = CommitMeta::new(n_multilins_by_vars.into_vec());
61
62 let mut index = SparseIndex::with_capacity(oracles.size());
64 for id in 0..oracles.size() {
65 if let Some(CommitIDFirstPass {
66 n_packed_vars,
67 idx_in_bucket,
68 }) = first_pass_index.get(id)
69 {
70 let offset = commit_meta.range_by_vars(*n_packed_vars).start;
71 index.set(id, offset + *idx_in_bucket);
72 }
73 }
74
75 Ok((commit_meta, index))
76}
77
78pub fn collect_committed_witnesses<'a, U, F>(
89 commit_meta: &CommitMeta,
90 oracle_to_commit_index: &SparseIndex<usize>,
91 oracles: &MultilinearOracleSet<F>,
92 witness_index: &MultilinearExtensionIndex<'a, U, F>,
93) -> Result<Vec<MultilinearWitness<'a, PackedType<U, F>>>, Error>
94where
95 U: UnderlierType + PackScalar<F>,
96 F: TowerField,
97{
98 let mut witnesses = vec![None; commit_meta.total_multilins()];
99 for oracle_id in 0..oracles.size() {
100 if let Some(commit_idx) = oracle_to_commit_index.get(oracle_id) {
101 witnesses[*commit_idx] = Some(witness_index.get_multilin_poly(oracle_id)?);
102 }
103 }
104 Ok(witnesses
105 .into_iter()
106 .map(|witness| witness.expect("pre-condition: oracle_to_commit index is surjective"))
107 .collect())
108}
109
110fn n_packed_vars_for_committed_oracle<F: TowerField>(
111 oracle: &MultilinearPolyOracle<F>,
112) -> Result<usize, Error> {
113 let n_vars = oracle.n_vars();
114 let tower_level = oracle.binary_tower_level();
115 n_vars
116 .checked_sub(F::TOWER_LEVEL - tower_level)
117 .ok_or_else(|| Error::OracleTooSmall {
118 id: oracle.id(),
119 min_vars: F::TOWER_LEVEL - tower_level,
120 })
121}
122
#[cfg(test)]
mod tests {
    use binius_field::BinaryField128b;

    use super::*;

    /// Checks the bucket sizes and the flat commit indices produced for committed oracles
    /// of several sizes and tower levels, and that a non-committed oracle gets no index.
    #[test]
    fn test_make_oracle_commit_meta() {
        let mut oracles = MultilinearOracleSet::<BinaryField128b>::new();

        // Committed oracles at tower level 0 with 8, 10 and 12 variables.
        let batch_0_0_ids = oracles.add_committed_multiple::<2>(8, 0);
        let batch_0_1_ids = oracles.add_committed_multiple::<2>(10, 0);
        let batch_0_2_ids = oracles.add_committed_multiple::<2>(12, 0);

        // A derived (non-committed) oracle, which must be left out of the commitment.
        let repeat = oracles.add_repeating(batch_0_2_ids[0], 5).unwrap();

        // Committed oracles at tower level 2 with 8, 10 and 12 variables.
        let batch_2_0_ids = oracles.add_committed_multiple::<2>(8, 2);
        let batch_2_1_ids = oracles.add_committed_multiple::<2>(10, 2);
        let batch_2_2_ids = oracles.add_committed_multiple::<2>(12, 2);

        let (commit_meta, index) = make_oracle_commit_meta(&oracles).unwrap();
        assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 2, 0, 4, 0, 4, 0, 2]);

        // Pairs of (oracle ID, expected commit index). Oracles from different batches that
        // share a packed variable count land in the same bucket, in insertion order.
        let expected_indices = [
            (batch_0_0_ids[0], 0),
            (batch_0_0_ids[1], 1),
            (batch_0_1_ids[0], 2),
            (batch_0_1_ids[1], 3),
            (batch_0_2_ids[0], 6),
            (batch_0_2_ids[1], 7),
            (batch_2_0_ids[0], 4),
            (batch_2_0_ids[1], 5),
            (batch_2_1_ids[0], 8),
            (batch_2_1_ids[1], 9),
            (batch_2_2_ids[0], 10),
            (batch_2_2_ids[1], 11),
        ];
        for &(oracle_id, commit_idx) in expected_indices.iter() {
            assert_eq!(index.get(oracle_id).copied(), Some(commit_idx));
        }

        assert_eq!(index.get(repeat).copied(), None);
    }
}