binius_core/piop/
commit.rs1use binius_field::{
4 as_packed_field::{PackScalar, PackedType},
5 underlier::UnderlierType,
6 TowerField,
7};
8use binius_utils::sparse_index::SparseIndex;
9
10use super::{error::Error, util::ResizeableIndex, verify::CommitMeta};
11use crate::{
12 oracle::{MultilinearOracleSet, MultilinearPolyOracle, MultilinearPolyVariant},
13 witness::{MultilinearExtensionIndex, MultilinearWitness},
14};
15
16pub fn make_oracle_commit_meta<F: TowerField>(
21 oracles: &MultilinearOracleSet<F>,
22) -> Result<(CommitMeta, SparseIndex<usize>), Error> {
23 #[derive(Clone)]
36 struct CommitIDFirstPass {
37 n_packed_vars: usize,
38 idx_in_bucket: usize,
39 }
40
41 let mut first_pass_index = SparseIndex::with_capacity(oracles.size());
43 let mut n_multilins_by_vars = ResizeableIndex::<usize>::new();
44 for oracle in oracles.iter() {
45 if matches!(oracle.variant, MultilinearPolyVariant::Committed) {
46 let n_packed_vars = n_packed_vars_for_committed_oracle(&oracle);
47 let n_multilins_for_vars = n_multilins_by_vars.get_mut(n_packed_vars);
48
49 first_pass_index.set(
50 oracle.id(),
51 CommitIDFirstPass {
52 n_packed_vars,
53 idx_in_bucket: *n_multilins_for_vars,
54 },
55 );
56 *n_multilins_for_vars += 1;
57 }
58 }
59
60 let commit_meta = CommitMeta::new(n_multilins_by_vars.into_vec());
61
62 let mut index = SparseIndex::with_capacity(oracles.size());
64 for id in 0..oracles.size() {
65 if let Some(CommitIDFirstPass {
66 n_packed_vars,
67 idx_in_bucket,
68 }) = first_pass_index.get(id)
69 {
70 let offset = commit_meta.range_by_vars(*n_packed_vars).start;
71 index.set(id, offset + *idx_in_bucket);
72 }
73 }
74
75 Ok((commit_meta, index))
76}
77
78pub fn collect_committed_witnesses<'a, U, F>(
89 commit_meta: &CommitMeta,
90 oracle_to_commit_index: &SparseIndex<usize>,
91 oracles: &MultilinearOracleSet<F>,
92 witness_index: &MultilinearExtensionIndex<'a, PackedType<U, F>>,
93) -> Result<Vec<MultilinearWitness<'a, PackedType<U, F>>>, Error>
94where
95 U: UnderlierType + PackScalar<F>,
96 F: TowerField,
97{
98 let mut witnesses = vec![None; commit_meta.total_multilins()];
99 for oracle_id in 0..oracles.size() {
100 if let Some(commit_idx) = oracle_to_commit_index.get(oracle_id) {
101 witnesses[*commit_idx] = Some(witness_index.get_multilin_poly(oracle_id)?);
102 }
103 }
104 Ok(witnesses
105 .into_iter()
106 .map(|witness| witness.expect("pre-condition: oracle_to_commit index is surjective"))
107 .collect())
108}
109
110fn n_packed_vars_for_committed_oracle<F: TowerField>(oracle: &MultilinearPolyOracle<F>) -> usize {
111 let n_vars = oracle.n_vars();
112 let tower_level = oracle.binary_tower_level();
113 (n_vars + tower_level).saturating_sub(F::TOWER_LEVEL)
114}
115
#[cfg(test)]
mod tests {
    use binius_field::BinaryField128b;

    use super::*;

    /// Checks bucket sizes and per-oracle commit indices for a mix of committed batches.
    ///
    /// Derived (non-committed) oracles must receive no commit index.
    #[test]
    fn test_make_oracle_commit_meta() {
        let mut oracles = MultilinearOracleSet::<BinaryField128b>::new();

        // Batches at tower level 0, assuming the arguments are (n_vars, tower_level).
        let batch_0_0_ids = oracles.add_committed_multiple::<2>(8, 0);
        let batch_0_1_ids = oracles.add_committed_multiple::<2>(10, 0);
        let batch_0_2_ids = oracles.add_committed_multiple::<2>(12, 0);

        // A derived oracle; it must not be assigned a commit index.
        let repeat = oracles.add_repeating(batch_0_2_ids[0], 5).unwrap();

        // Batches at tower level 2.
        let batch_2_0_ids = oracles.add_committed_multiple::<2>(8, 2);
        let batch_2_1_ids = oracles.add_committed_multiple::<2>(10, 2);
        let batch_2_2_ids = oracles.add_committed_multiple::<2>(12, 2);

        let (commit_meta, index) = make_oracle_commit_meta(&oracles).unwrap();
        assert_eq!(commit_meta.n_multilins_by_vars(), &[0, 2, 0, 4, 0, 4, 0, 2]);

        // (oracle id, expected flattened commit index), checked in this order.
        let expected = [
            (batch_0_0_ids[0], Some(0)),
            (batch_0_0_ids[1], Some(1)),
            (batch_0_1_ids[0], Some(2)),
            (batch_0_1_ids[1], Some(3)),
            (batch_0_2_ids[0], Some(6)),
            (batch_0_2_ids[1], Some(7)),
            (batch_2_0_ids[0], Some(4)),
            (batch_2_0_ids[1], Some(5)),
            (batch_2_1_ids[0], Some(8)),
            (batch_2_1_ids[1], Some(9)),
            (batch_2_2_ids[0], Some(10)),
            (batch_2_2_ids[1], Some(11)),
            (repeat, None),
        ];
        for &(oracle_id, expected_idx) in &expected {
            assert_eq!(index.get(oracle_id).copied(), expected_idx);
        }
    }
}