// binius_hash/vision_4/parallel_compression.rs

1// Copyright 2025 Irreducible Inc.
2// Copyright 2026 The Binius Developers
3
4use std::{array, fmt::Debug, mem::MaybeUninit};
5
6use binius_field::{BinaryField128bGhash as Ghash, Field};
7use binius_utils::{
8	DeserializeBytes, SerializeBytes,
9	rayon::{
10		iter::{IndexedParallelIterator, ParallelIterator},
11		slice::{ParallelSlice, ParallelSliceMut},
12	},
13};
14use digest::Output;
15
16use super::{
17	compression::VisionCompression, constants::M, digest::VisionHasherDigest,
18	parallel_permutation::batch_permutation,
19};
20use crate::parallel_compression::ParallelPseudoCompression;
21
// The number of parallel compressions N must be a power of 2.
// The amortization of batch inversion grows with the batch size
// and thus with N. Heuristically 128 is the largest N before
// performance degrades.
const N: usize = 128;
// Flattened length of one batch: N permutation states of M Ghash elements each.
const MN: usize = N * M;
28
/// Parallel Vision compression with N parallel compressions using rayon.
///
/// Processes N compression pairs simultaneously using parallel Vision permutation
/// and multithreading for optimal performance.
#[derive(Clone, Debug, Default)]
pub struct VisionParallelCompression {
	// Underlying single-pair compression, exposed through
	// `ParallelPseudoCompression::compression`.
	compression: VisionCompression,
}
37
38impl VisionParallelCompression {
39	pub fn new() -> Self {
40		Self::default()
41	}
42}
43
44impl ParallelPseudoCompression<Output<VisionHasherDigest>, 2> for VisionParallelCompression {
45	type Compression = VisionCompression;
46
47	fn compression(&self) -> &Self::Compression {
48		&self.compression
49	}
50
51	// If we add another implementation of `ParallelPseudoCompression`, it makes sense to add the
52	// compression-equivalent of `MultiDigest` and `ParallelMultidigestImpl` to avoid
53	// duplicating the logic below of breaking into chunks of N and handling remainders.
54	#[tracing::instrument(
55		"VisionParallelCompression::parallel_compress",
56		skip_all,
57		level = "debug"
58	)]
59	fn parallel_compress(
60		&self,
61		inputs: &[Output<VisionHasherDigest>],
62		out: &mut [MaybeUninit<Output<VisionHasherDigest>>],
63	) {
64		assert_eq!(inputs.len(), 2 * out.len(), "Input length must be 2 * output length");
65
66		inputs
67			.par_chunks_exact(N * 2)
68			.zip(out.par_chunks_exact_mut(N))
69			.for_each(|(input_chunk, output_chunk)| {
70				self.compress_batch_parallel(input_chunk, output_chunk);
71			});
72
73		// Handle remaining pairs using batched processing
74		let remainder_inputs = inputs.chunks_exact(N * 2).remainder();
75		let remainder_outputs = out.chunks_exact_mut(N).into_remainder();
76
77		if !remainder_outputs.is_empty() {
78			// Use stack-allocated arrays for remainder handling
79			let mut padded_inputs = [Output::<VisionHasherDigest>::default(); N * 2];
80			let mut padded_outputs = [MaybeUninit::uninit(); N];
81
82			// Copy remainder inputs
83			padded_inputs[..remainder_inputs.len()].copy_from_slice(remainder_inputs);
84
85			// Process full batch (including padding)
86			self.compress_batch_parallel(&padded_inputs, &mut padded_outputs);
87
88			// Copy only the actual results back
89			for (output, padded) in remainder_outputs.iter_mut().zip(padded_outputs) {
90				// Safety: `compress_batch_parallel` guarantees to initialize `padded_outputs`
91				output.write(unsafe { padded.assume_init() });
92			}
93		}
94	}
95}
96
97impl VisionParallelCompression {
98	/// Compress exactly N pairs using parallel permutation.
99	#[tracing::instrument(
100		"VisionParallelCompression::compress_batch_parallel",
101		skip_all,
102		level = "debug"
103	)]
104	#[inline]
105	fn compress_batch_parallel(
106		&self,
107		inputs: &[Output<VisionHasherDigest>],
108		out: &mut [MaybeUninit<Output<VisionHasherDigest>>],
109	) {
110		assert_eq!(out.len(), N, "Must process exactly {N} pairs");
111		assert_eq!(inputs.len(), 2 * N, "Must have 2*N inputs");
112
113		// Step 1: Deserialize inputs into flattened state array
114		let mut states = [Ghash::ZERO; MN];
115		for i in 0..N {
116			let input0 = &inputs[i * 2];
117			let input1 = &inputs[i * 2 + 1];
118
119			// Deserialize each 32-byte input into 2 Ghash elements
120			states[i * M] = Ghash::deserialize(&input0[0..16]).expect("16 bytes fits in Ghash");
121			states[i * M + 1] =
122				Ghash::deserialize(&input0[16..32]).expect("16 bytes fits in Ghash");
123			states[i * M + 2] = Ghash::deserialize(&input1[0..16]).expect("16 bytes fits in Ghash");
124			states[i * M + 3] =
125				Ghash::deserialize(&input1[16..32]).expect("16 bytes fits in Ghash");
126		}
127
128		// Step 2: Copy original first 2 elements for each state
129		let originals: [_; N] = array::from_fn(|i| (states[i * M], states[i * M + 1]));
130
131		// Step 3: Apply parallel permutation to all states
132		batch_permutation::<N, MN>(&mut states);
133
134		// Step 4: Add original elements back and serialize outputs
135		for i in 0..N {
136			states[i * M] += originals[i].0;
137			states[i * M + 1] += originals[i].1;
138
139			let mut output = Output::<VisionHasherDigest>::default();
140			let (left, right) = output.as_mut_slice().split_at_mut(16);
141			states[i * M].serialize(left).expect("fits in 16 bytes");
142			states[i * M + 1]
143				.serialize(right)
144				.expect("fits in 16 bytes");
145			out[i].write(output);
146		}
147	}
148}
149
#[cfg(test)]
mod tests {
	use std::array;

	use digest::Digest;

	use super::*;
	use crate::PseudoCompressionFunction;

	// Hash `data` with `VisionHasherDigest` and return the 32-byte digest.
	fn digest_of(data: &[u8]) -> Output<VisionHasherDigest> {
		let mut hasher = VisionHasherDigest::new();
		hasher.update(data);
		hasher.finalize()
	}

	#[test]
	fn test_parallel_vs_sequential_simple() {
		let parallel = VisionParallelCompression::default();

		// Four inputs form two compression pairs.
		let inputs = [
			VisionHasherDigest::new().finalize(), // pair 0, element 0 (empty digest)
			digest_of(b"first"),                  // pair 0, element 1
			digest_of(b"second"),                 // pair 1, element 0
			digest_of(b"third"),                  // pair 1, element 1
		];

		// Expected results from the sequential single-pair compression.
		let expected = [
			parallel.compression.compress([inputs[0], inputs[1]]),
			parallel.compression.compress([inputs[2], inputs[3]]),
		];

		// Actual results from the parallel path.
		let mut slots = [MaybeUninit::uninit(); 2];
		parallel.parallel_compress(&inputs, &mut slots);
		let actual: [_; 2] = array::from_fn(|i| unsafe { slots[i].assume_init() });

		assert_eq!(expected, actual);
	}

	#[test]
	fn test_parallel_compress_large_batch() {
		use rand::{Rng, SeedableRng, rngs::StdRng};

		let parallel = VisionParallelCompression::default();
		let mut rng = StdRng::seed_from_u64(0);

		// 300 pairs (600 inputs) exercises full batches (N = 128) plus a remainder.
		const NUM_PAIRS: usize = 300;

		// Deterministic pseudo-random inputs.
		let inputs: Vec<_> = (0..NUM_PAIRS * 2)
			.map(|i| {
				let mut hasher = VisionHasherDigest::new();
				hasher.update(i.to_le_bytes());
				hasher.update(rng.random::<[u8; 32]>());
				hasher.finalize()
			})
			.collect();

		// Expected results from the sequential single-pair compression.
		let expected: Vec<_> = inputs
			.chunks_exact(2)
			.map(|pair| parallel.compression.compress([pair[0], pair[1]]))
			.collect();

		// Actual results from the parallel path.
		let mut slots = vec![MaybeUninit::uninit(); NUM_PAIRS];
		parallel.parallel_compress(&inputs, &mut slots);
		let actual: Vec<_> = slots
			.into_iter()
			.map(|slot| unsafe { slot.assume_init() })
			.collect();

		assert_eq!(expected, actual);
	}
}
239}