binius_core/merkle_tree/scheme.rs

// Copyright 2024-2025 Irreducible Inc.

use std::{array, fmt::Debug, marker::PhantomData};

use binius_field::TowerField;
use binius_hash::{HashBuffer, PseudoCompressionFunction};
use binius_utils::{
	SerializationMode, SerializeBytes, bail,
	checked_arithmetics::{log2_ceil_usize, log2_strict_usize},
};
use bytes::Buf;
use digest::{Digest, Output, core_api::BlockSizeUser};
use getset::Getters;

use super::{
	errors::{Error, VerificationError},
	merkle_tree_vcs::MerkleTreeScheme,
};
use crate::transcript::TranscriptReader;

/// A binary Merkle tree scheme that hashes leaf entries with the digest `H` and combines
/// sibling nodes with the 2-to-1 pseudo-compression function `C`.
#[derive(Debug, Getters)]
pub struct BinaryMerkleTreeScheme<T, H, C> {
	#[getset(get = "pub")]
	compression: C,
	// This makes it so that `BinaryMerkleTreeScheme` remains Send + Sync
	// See https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns
	_phantom: PhantomData<fn() -> (T, H)>,
}

impl<T, H, C> BinaryMerkleTreeScheme<T, H, C> {
	pub fn new(compression: C) -> Self {
		Self {
			compression,
			_phantom: PhantomData,
		}
	}
}
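
// A construction sketch (illustrative only; `MyDigest` and `MyCompression` are hypothetical
// placeholders, not types exported by this crate). Any `H: Digest + BlockSizeUser` together
// with a 2-to-1 `PseudoCompressionFunction` over `Output<H>` can be plugged in:
//
//     let scheme = BinaryMerkleTreeScheme::<F, MyDigest, MyCompression>::new(MyCompression::default());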

impl<F, H, C> MerkleTreeScheme<F> for BinaryMerkleTreeScheme<F, H, C>
where
	F: TowerField,
	H: Digest + BlockSizeUser,
	C: PseudoCompressionFunction<Output<H>, 2> + Sync,
{
	type Digest = Output<H>;

	/// Returns the depth of the commitment layer that approximately minimizes total proof size
	/// for the given number of queries, capped at the tree depth.
	fn optimal_verify_layer(&self, n_queries: usize, tree_depth: usize) -> usize {
		log2_ceil_usize(n_queries).min(tree_depth)
	}

	fn proof_size(&self, len: usize, n_queries: usize, layer_depth: usize) -> Result<usize, Error> {
		if !len.is_power_of_two() {
			bail!(Error::PowerOfTwoLengthRequired)
		}

		let log_len = log2_strict_usize(len);

		if layer_depth > log_len {
			bail!(Error::IncorrectLayerDepth)
		}

		// Each query opening supplies one branch digest per level between the leaves and the
		// committed layer, and the committed layer itself contributes `2^layer_depth` digests.
		Ok(((log_len - layer_depth) * n_queries + (1 << layer_depth))
			* <H as Digest>::output_size())
	}

	fn verify_vector(
		&self,
		root: &Self::Digest,
		data: &[F],
		batch_size: usize,
	) -> Result<(), Error> {
		if data.len() % batch_size != 0 {
			bail!(Error::IncorrectBatchSize);
		}

		let mut digests = data
			.chunks(batch_size)
			.map(|chunk| hash_field_elems::<_, H>(chunk))
			.collect::<Vec<_>>();

		fold_digests_vector_inplace(&self.compression, &mut digests)?;
		if digests[0] != *root {
			bail!(VerificationError::InvalidProof)
		}
		Ok(())
	}

	fn verify_layer(
		&self,
		root: &Self::Digest,
		layer_depth: usize,
		layer_digests: &[Self::Digest],
	) -> Result<(), Error> {
		if 1 << layer_depth != layer_digests.len() {
			bail!(VerificationError::IncorrectVectorLength)
		}

		let mut digests = layer_digests.to_owned();

		fold_digests_vector_inplace(&self.compression, &mut digests)?;

		if digests[0] != *root {
			bail!(VerificationError::InvalidProof)
		}
		Ok(())
	}

	fn verify_opening<B: Buf>(
		&self,
		mut index: usize,
		values: &[F],
		layer_depth: usize,
		tree_depth: usize,
		layer_digests: &[Self::Digest],
		proof: &mut TranscriptReader<B>,
	) -> Result<(), Error> {
		if (1 << layer_depth) != layer_digests.len() {
			bail!(VerificationError::IncorrectVectorLength);
		}

		if index >= (1 << tree_depth) {
			bail!(Error::IndexOutOfRange {
				max: (1 << tree_depth) - 1
			});
		}

		// Walk up the tree from the leaf digest to the committed layer: at each level the
		// sibling digest comes from the proof, and the parity of `index` determines whether
		// the current digest is the left or right child.
		let mut leaf_digest = hash_field_elems::<_, H>(values);
		for branch_node in proof.read_vec(tree_depth - layer_depth)? {
			leaf_digest = self.compression.compress(if index & 1 == 0 {
				[leaf_digest, branch_node]
			} else {
				[branch_node, leaf_digest]
			});
			index >>= 1;
		}

		(leaf_digest == layer_digests[index])
			.then_some(())
			.ok_or_else(|| VerificationError::InvalidProof.into())
	}
}

// Merkle-tree-like folding: repeatedly compresses adjacent pairs in place until the root is
// left in `digests[0]`. For example, `[d0, d1, d2, d3]` becomes `c(c(d0, d1), c(d2, d3))` in
// slot 0 after two passes; entries past the live prefix are left with stale values.
fn fold_digests_vector_inplace<C, D>(compression: &C, digests: &mut [D]) -> Result<(), Error>
where
	C: PseudoCompressionFunction<D, 2> + Sync,
	D: Clone + Default + Send + Sync + Debug,
{
	if !digests.len().is_power_of_two() {
		bail!(Error::PowerOfTwoLengthRequired);
	}

	let mut len = digests.len() / 2;

	while len != 0 {
		for i in 0..len {
			digests[i] = compression.compress(array::from_fn(|j| digests[2 * i + j].clone()));
		}
		len /= 2;
	}

	Ok(())
}

/// Hashes a slice of tower field elements.
fn hash_field_elems<F, H>(elems: &[F]) -> Output<H>
where
	F: TowerField,
	H: Digest + BlockSizeUser,
{
	let mut hasher = H::new();
	{
		let mut buffer = HashBuffer::new(&mut hasher);
		for elem in elems {
			let mode = SerializationMode::CanonicalTower;
			SerializeBytes::serialize(elem, &mut buffer, mode)
				.expect("HashBuffer has infinite capacity");
		}
	}
	hasher.finalize()
}
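
#[cfg(test)]
mod tests {
	//! A minimal sketch (not part of the original module): exercises `fold_digests_vector_inplace`
	//! with a toy digest type and a toy XOR compression. The exact shape of
	//! `PseudoCompressionFunction` (a single `compress` method) is assumed from its use above.

	use super::*;

	#[derive(Clone)]
	struct XorCompression;

	impl PseudoCompressionFunction<u64, 2> for XorCompression {
		fn compress(&self, values: [u64; 2]) -> u64 {
			values[0] ^ values[1]
		}
	}

	#[test]
	fn fold_digests_xor_toy_tree() {
		// XOR is associative and commutative, so the root of the toy tree is the XOR of all
		// leaves regardless of how pairs are grouped.
		let mut digests = vec![1u64, 2, 4, 8];
		fold_digests_vector_inplace(&XorCompression, &mut digests).unwrap();
		assert_eq!(digests[0], 1 ^ 2 ^ 4 ^ 8);

		// Non-power-of-two lengths must be rejected.
		let mut bad = vec![1u64, 2, 3];
		assert!(fold_digests_vector_inplace(&XorCompression, &mut bad).is_err());
	}
}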