binius_hash/parallel_compression.rs

// Copyright 2025 Irreducible Inc.
// Copyright 2026 The Binius Developers

use std::{array, fmt::Debug, mem::MaybeUninit};

use binius_utils::rayon::prelude::*;

use crate::PseudoCompressionFunction;

/// A trait for parallel application of N-to-1 compression functions.
///
/// This trait enables efficient batch compression operations where multiple N-element
/// chunks are compressed in parallel. It's particularly useful for constructing hash trees
/// and Merkle trees where many compression operations need to be performed simultaneously.
///
/// The trait is parameterized by:
/// - `T`: The type of values being compressed (typically hash digests)
/// - `N`: The arity of the compression function (number of inputs per compression)
pub trait ParallelPseudoCompression<T, const N: usize> {
	/// The underlying compression function that performs N-to-1 compression.
	type Compression: PseudoCompressionFunction<T, N>;

	/// Returns a reference to the underlying compression function.
	fn compression(&self) -> &Self::Compression;

	/// Compresses multiple N-element chunks in parallel.
	///
	/// # Arguments
	/// * `inputs` - A slice containing the values to compress. Must have length `N * out.len()`.
	/// * `out` - Output buffer where compressed values will be written.
	///
	/// # Behavior
	/// For each index `i` in `0..out.len()`, this method computes:
	/// ```text
	/// out[i] = Compression::compress([inputs[i*N], inputs[i*N+1], ..., inputs[i*N+N-1]])
	/// ```
	///
	/// All compressions are performed in parallel for efficiency.
	///
	/// # Post-conditions
	/// After this method returns, all elements in `out` will be initialized with the
	/// compressed values from their corresponding N-element chunks in `inputs`.
	///
	/// # Panics
	/// Panics if `inputs.len() != N * out.len()`.
	fn parallel_compress(&self, inputs: &[T], out: &mut [MaybeUninit<T>]);
}
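
/// A minimal usage sketch, not part of the crate's API: fold one level of a
/// Merkle tree by compressing every `N` child digests into one parent digest,
/// writing into a `Vec`'s spare capacity. The name `fold_level_sketch` is
/// hypothetical and exists only for illustration.
#[allow(dead_code)]
fn fold_level_sketch<T, C, const N: usize>(compressor: &C, children: &[T]) -> Vec<T>
where
	C: ParallelPseudoCompression<T, N>,
{
	assert_eq!(children.len() % N, 0, "child count must be a multiple of the arity");
	let num_parents = children.len() / N;
	let mut parents = Vec::with_capacity(num_parents);
	// Slice the spare capacity to exactly `num_parents`, since `with_capacity`
	// may over-allocate and `parallel_compress` checks the exact length.
	compressor.parallel_compress(children, &mut parents.spare_capacity_mut()[..num_parents]);
	// SAFETY: per the trait's post-condition, all `num_parents` slots are initialized.
	unsafe { parents.set_len(num_parents) };
	parents
}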

/// A simple adapter that wraps any `PseudoCompressionFunction` to implement
/// `ParallelPseudoCompression`.
///
/// This adapter provides a straightforward way to use existing compression functions
/// in parallel contexts by applying the wrapped function to each N-element chunk,
/// with the chunks processed in parallel via rayon.
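///
/// # Example
///
/// ```ignore
/// // A minimal sketch; `MyCompression` and `Digest` are placeholders for any
/// // `PseudoCompressionFunction<Digest, 2> + Sync` implementation and its
/// // (Copy) digest type. Four 2-element chunks compress into four outputs.
/// let adaptor = ParallelCompressionAdaptor::new(MyCompression);
/// let mut out = [MaybeUninit::<Digest>::uninit(); 4];
/// adaptor.parallel_compress(&inputs, &mut out); // requires inputs.len() == 8
/// ```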
#[derive(Debug, Clone)]
pub struct ParallelCompressionAdaptor<C> {
	compression: C,
}

impl<C> ParallelCompressionAdaptor<C> {
	/// Creates a new adapter wrapping the given compression function.
	pub fn new(compression: C) -> Self {
		Self { compression }
	}

	/// Returns a reference to the underlying compression function.
	pub fn compression(&self) -> &C {
		&self.compression
	}
}

impl<T, C, const ARITY: usize> ParallelPseudoCompression<T, ARITY> for ParallelCompressionAdaptor<C>
where
	T: Clone + Send + Sync,
	C: PseudoCompressionFunction<T, ARITY> + Sync,
{
	type Compression = C;

	fn compression(&self) -> &Self::Compression {
		&self.compression
	}

	fn parallel_compress(&self, inputs: &[T], out: &mut [MaybeUninit<T>]) {
		assert_eq!(inputs.len(), ARITY * out.len(), "Input length must be N * output length");

		inputs
			.par_chunks_exact(ARITY)
			.zip(out.par_iter_mut())
			.for_each(|(chunk, output)| {
				// Convert the slice chunk into a fixed-size array for the compression function
				let chunk_array: [T; ARITY] = array::from_fn(|j| chunk[j].clone());
				let compressed = self.compression.compress(chunk_array);
				output.write(compressed);
			});
	}
}

#[cfg(test)]
mod tests {
	use std::mem::MaybeUninit;

	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;

	// Simple test compression function that XORs all inputs
	#[derive(Clone, Debug)]
	struct XorCompression;

	impl PseudoCompressionFunction<u64, 3> for XorCompression {
		fn compress(&self, input: [u64; 3]) -> u64 {
			input[0] ^ input[1] ^ input[2]
		}
	}

	#[test]
	fn test_parallel_compression_adaptor() {
		let mut rng = StdRng::seed_from_u64(0);
		let compression = XorCompression;
		let adaptor = ParallelCompressionAdaptor::new(compression.clone());

		// Test with 4 chunks of 3 elements each
		const N: usize = 3;
		const NUM_CHUNKS: usize = 4;
		let inputs: Vec<u64> = (0..N * NUM_CHUNKS).map(|_| rng.random()).collect();

		// Use the adaptor
		let mut adaptor_output = [MaybeUninit::<u64>::uninit(); NUM_CHUNKS];
		adaptor.parallel_compress(&inputs, &mut adaptor_output);
		let adaptor_results: Vec<u64> = adaptor_output
			.into_iter()
			.map(|x| unsafe { x.assume_init() })
			.collect();

		// Manually compress each chunk
		let mut manual_results = Vec::new();
		for chunk_idx in 0..NUM_CHUNKS {
			let start = chunk_idx * N;
			let chunk = [inputs[start], inputs[start + 1], inputs[start + 2]];
			manual_results.push(compression.compress(chunk));
		}

		// Results should be identical
		assert_eq!(adaptor_results, manual_results);
	}
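
	// An additional sketch (not in the original tests) exercising the
	// Merkle-style use case from the trait docs: repeatedly fold 3-element
	// levels until a single root remains. With XOR as the compression, the
	// root must equal the XOR of all leaves.
	#[test]
	fn test_multi_level_fold() {
		let adaptor = ParallelCompressionAdaptor::new(XorCompression);

		// 9 leaves fold to 3 parents, then to 1 root, with arity 3.
		let mut level: Vec<u64> = (1..=9).collect();
		while level.len() > 1 {
			let mut next = vec![MaybeUninit::<u64>::uninit(); level.len() / 3];
			adaptor.parallel_compress(&level, &mut next);
			level = next.into_iter().map(|x| unsafe { x.assume_init() }).collect();
		}

		// XOR is associative and commutative, so the root is the XOR of all leaves.
		let expected = (1..=9u64).fold(0, |acc, x| acc ^ x);
		assert_eq!(level, vec![expected]);
	}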

	#[test]
	#[should_panic(expected = "Input length must be N * output length")]
	fn test_mismatched_input_length() {
		let compression = XorCompression;
		let adaptor = ParallelCompressionAdaptor::new(compression);

		let inputs = vec![1u64, 2, 3, 4]; // 4 elements
		let mut output = [MaybeUninit::<u64>::uninit(); 2]; // 2 outputs => adaptor expects 2 * 3 = 6 inputs

		adaptor.parallel_compress(&inputs, &mut output);
	}
}