binius_core/merkle_tree/
scheme.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::{array, fmt::Debug, marker::PhantomData};
4
5use binius_field::TowerField;
6use binius_hash::{HashBuffer, PseudoCompressionFunction};
7use binius_utils::{
8	bail,
9	checked_arithmetics::{log2_ceil_usize, log2_strict_usize},
10	SerializationMode, SerializeBytes,
11};
12use bytes::Buf;
13use digest::{core_api::BlockSizeUser, Digest, Output};
14use getset::Getters;
15
16use super::{
17	errors::{Error, VerificationError},
18	merkle_tree_vcs::MerkleTreeScheme,
19};
20use crate::transcript::TranscriptReader;
21
22#[derive(Debug, Getters)]
23pub struct BinaryMerkleTreeScheme<T, H, C> {
24	#[getset(get = "pub")]
25	compression: C,
26	// This makes it so that `BinaryMerkleTreeScheme` remains Send + Sync
27	// See https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns
28	_phantom: PhantomData<fn() -> (T, H)>,
29}
30
31impl<T, H, C> BinaryMerkleTreeScheme<T, H, C> {
32	pub fn new(compression: C) -> Self {
33		Self {
34			compression,
35			_phantom: PhantomData,
36		}
37	}
38}
39
40impl<F, H, C> MerkleTreeScheme<F> for BinaryMerkleTreeScheme<F, H, C>
41where
42	F: TowerField,
43	H: Digest + BlockSizeUser,
44	C: PseudoCompressionFunction<Output<H>, 2> + Sync,
45{
46	type Digest = Output<H>;
47
48	/// This layer allows minimizing the proof size.
49	fn optimal_verify_layer(&self, n_queries: usize, tree_depth: usize) -> usize {
50		log2_ceil_usize(n_queries).min(tree_depth)
51	}
52
53	fn proof_size(&self, len: usize, n_queries: usize, layer_depth: usize) -> Result<usize, Error> {
54		if !len.is_power_of_two() {
55			bail!(Error::PowerOfTwoLengthRequired)
56		}
57
58		let log_len = log2_strict_usize(len);
59
60		if layer_depth > log_len {
61			bail!(Error::IncorrectLayerDepth)
62		}
63
64		Ok(((log_len - layer_depth - 1) * n_queries + (1 << layer_depth))
65			* <H as Digest>::output_size())
66	}
67
68	fn verify_vector(
69		&self,
70		root: &Self::Digest,
71		data: &[F],
72		batch_size: usize,
73	) -> Result<(), Error> {
74		if data.len() % batch_size != 0 {
75			bail!(Error::IncorrectBatchSize);
76		}
77
78		let mut digests = data
79			.chunks(batch_size)
80			.map(|chunk| hash_field_elems::<_, H>(chunk))
81			.collect::<Vec<_>>();
82
83		fold_digests_vector_inplace(&self.compression, &mut digests)?;
84		if digests[0] != *root {
85			bail!(VerificationError::InvalidProof)
86		}
87		Ok(())
88	}
89
90	fn verify_layer(
91		&self,
92		root: &Self::Digest,
93		layer_depth: usize,
94		layer_digests: &[Self::Digest],
95	) -> Result<(), Error> {
96		if 1 << layer_depth != layer_digests.len() {
97			bail!(VerificationError::IncorrectVectorLength)
98		}
99
100		let mut digests = layer_digests.to_owned();
101
102		fold_digests_vector_inplace(&self.compression, &mut digests)?;
103
104		if digests[0] != *root {
105			bail!(VerificationError::InvalidProof)
106		}
107		Ok(())
108	}
109
110	fn verify_opening<B: Buf>(
111		&self,
112		mut index: usize,
113		values: &[F],
114		layer_depth: usize,
115		tree_depth: usize,
116		layer_digests: &[Self::Digest],
117		proof: &mut TranscriptReader<B>,
118	) -> Result<(), Error> {
119		if (1 << layer_depth) != layer_digests.len() {
120			bail!(VerificationError::IncorrectVectorLength);
121		}
122
123		if index >= (1 << tree_depth) {
124			bail!(Error::IndexOutOfRange {
125				max: (1 << tree_depth) - 1
126			});
127		}
128
129		let mut leaf_digest = hash_field_elems::<_, H>(values);
130		for branch_node in proof.read_vec(tree_depth - layer_depth)? {
131			leaf_digest = self.compression.compress(if index & 1 == 0 {
132				[leaf_digest, branch_node]
133			} else {
134				[branch_node, leaf_digest]
135			});
136			index >>= 1;
137		}
138
139		(leaf_digest == layer_digests[index])
140			.then_some(())
141			.ok_or_else(|| VerificationError::InvalidProof.into())
142	}
143}
144
145// Merkle-tree-like folding
146fn fold_digests_vector_inplace<C, D>(compression: &C, digests: &mut [D]) -> Result<(), Error>
147where
148	C: PseudoCompressionFunction<D, 2> + Sync,
149	D: Clone + Default + Send + Sync + Debug,
150{
151	if !digests.len().is_power_of_two() {
152		bail!(Error::PowerOfTwoLengthRequired);
153	}
154
155	let mut len = digests.len() / 2;
156
157	while len != 0 {
158		for i in 0..len {
159			digests[i] = compression.compress(array::from_fn(|j| digests[2 * i + j].clone()));
160		}
161		len /= 2;
162	}
163
164	Ok(())
165}
166
167/// Hashes a slice of tower field elements.
168fn hash_field_elems<F, H>(elems: &[F]) -> Output<H>
169where
170	F: TowerField,
171	H: Digest + BlockSizeUser,
172{
173	let mut hasher = H::new();
174	{
175		let mut buffer = HashBuffer::new(&mut hasher);
176		for elem in elems {
177			let mode = SerializationMode::CanonicalTower;
178			SerializeBytes::serialize(elem, &mut buffer, mode)
179				.expect("HashBuffer has infinite capacity");
180		}
181	}
182	hasher.finalize()
183}