// binius_hash/vision_4/parallel_compression.rs

use std::{array, fmt::Debug, mem::MaybeUninit};

use binius_field::{BinaryField128bGhash as Ghash, Field};
use binius_utils::{
	DeserializeBytes, SerializeBytes,
	rayon::{
		iter::{IndexedParallelIterator, ParallelIterator},
		slice::{ParallelSlice, ParallelSliceMut},
	},
};
use digest::Output;

use super::{
	compression::VisionCompression, constants::M, digest::VisionHasherDigest,
	parallel_permutation::batch_permutation,
};
use crate::parallel_compression::ParallelPseudoCompression;
22const N: usize = 128;
27const MN: usize = N * M;
28
29#[derive(Clone, Debug, Default)]
34pub struct VisionParallelCompression {
35 compression: VisionCompression,
36}
37
38impl VisionParallelCompression {
39 pub fn new() -> Self {
40 Self::default()
41 }
42}
43
44impl ParallelPseudoCompression<Output<VisionHasherDigest>, 2> for VisionParallelCompression {
45 type Compression = VisionCompression;
46
47 fn compression(&self) -> &Self::Compression {
48 &self.compression
49 }
50
51 #[tracing::instrument(
55 "VisionParallelCompression::parallel_compress",
56 skip_all,
57 level = "debug"
58 )]
59 fn parallel_compress(
60 &self,
61 inputs: &[Output<VisionHasherDigest>],
62 out: &mut [MaybeUninit<Output<VisionHasherDigest>>],
63 ) {
64 assert_eq!(inputs.len(), 2 * out.len(), "Input length must be 2 * output length");
65
66 inputs
67 .par_chunks_exact(N * 2)
68 .zip(out.par_chunks_exact_mut(N))
69 .for_each(|(input_chunk, output_chunk)| {
70 self.compress_batch_parallel(input_chunk, output_chunk);
71 });
72
73 let remainder_inputs = inputs.chunks_exact(N * 2).remainder();
75 let remainder_outputs = out.chunks_exact_mut(N).into_remainder();
76
77 if !remainder_outputs.is_empty() {
78 let mut padded_inputs = [Output::<VisionHasherDigest>::default(); N * 2];
80 let mut padded_outputs = [MaybeUninit::uninit(); N];
81
82 padded_inputs[..remainder_inputs.len()].copy_from_slice(remainder_inputs);
84
85 self.compress_batch_parallel(&padded_inputs, &mut padded_outputs);
87
88 for (output, padded) in remainder_outputs.iter_mut().zip(padded_outputs) {
90 output.write(unsafe { padded.assume_init() });
92 }
93 }
94 }
95}
96
97impl VisionParallelCompression {
98 #[tracing::instrument(
100 "VisionParallelCompression::compress_batch_parallel",
101 skip_all,
102 level = "debug"
103 )]
104 #[inline]
105 fn compress_batch_parallel(
106 &self,
107 inputs: &[Output<VisionHasherDigest>],
108 out: &mut [MaybeUninit<Output<VisionHasherDigest>>],
109 ) {
110 assert_eq!(out.len(), N, "Must process exactly {N} pairs");
111 assert_eq!(inputs.len(), 2 * N, "Must have 2*N inputs");
112
113 let mut states = [Ghash::ZERO; MN];
115 for i in 0..N {
116 let input0 = &inputs[i * 2];
117 let input1 = &inputs[i * 2 + 1];
118
119 states[i * M] = Ghash::deserialize(&input0[0..16]).expect("16 bytes fits in Ghash");
121 states[i * M + 1] =
122 Ghash::deserialize(&input0[16..32]).expect("16 bytes fits in Ghash");
123 states[i * M + 2] = Ghash::deserialize(&input1[0..16]).expect("16 bytes fits in Ghash");
124 states[i * M + 3] =
125 Ghash::deserialize(&input1[16..32]).expect("16 bytes fits in Ghash");
126 }
127
128 let originals: [_; N] = array::from_fn(|i| (states[i * M], states[i * M + 1]));
130
131 batch_permutation::<N, MN>(&mut states);
133
134 for i in 0..N {
136 states[i * M] += originals[i].0;
137 states[i * M + 1] += originals[i].1;
138
139 let mut output = Output::<VisionHasherDigest>::default();
140 let (left, right) = output.as_mut_slice().split_at_mut(16);
141 states[i * M].serialize(left).expect("fits in 16 bytes");
142 states[i * M + 1]
143 .serialize(right)
144 .expect("fits in 16 bytes");
145 out[i].write(output);
146 }
147 }
148}
149
#[cfg(test)]
mod tests {
	use std::array;

	use digest::Digest;

	use super::*;
	use crate::PseudoCompressionFunction;

	/// Four inputs (below one full batch) must match two sequential
	/// `compress` calls on the underlying compression.
	#[test]
	fn test_parallel_vs_sequential_simple() {
		let parallel = VisionParallelCompression::default();
		// FIX: was mojibake `¶llel.compression` (garbled `&parallel`).
		let sequential = &parallel.compression;

		let inputs = [
			VisionHasherDigest::new().finalize(),
			{
				let mut hasher = VisionHasherDigest::new();
				hasher.update(b"first");
				hasher.finalize()
			},
			{
				let mut hasher = VisionHasherDigest::new();
				hasher.update(b"second");
				hasher.finalize()
			},
			{
				let mut hasher = VisionHasherDigest::new();
				hasher.update(b"third");
				hasher.finalize()
			},
		];

		let sequential_results = [
			sequential.compress([inputs[0], inputs[1]]),
			sequential.compress([inputs[2], inputs[3]]),
		];

		let mut parallel_outputs = [MaybeUninit::uninit(); 2];
		parallel.parallel_compress(&inputs, &mut parallel_outputs);
		// SAFETY: parallel_compress initialized both output slots.
		let parallel_results: [_; 2] =
			array::from_fn(|i| unsafe { parallel_outputs[i].assume_init() });

		assert_eq!(sequential_results, parallel_results);
	}

	/// 300 pairs = two full batches of N=128 plus a 44-pair remainder,
	/// exercising both the parallel path and the padded tail path.
	#[test]
	fn test_parallel_compress_large_batch() {
		use rand::{Rng, SeedableRng, rngs::StdRng};

		let parallel = VisionParallelCompression::default();
		let mut rng = StdRng::seed_from_u64(0);

		const NUM_PAIRS: usize = 300;

		let inputs: Vec<_> = (0..NUM_PAIRS * 2)
			.map(|i| {
				let mut hasher = VisionHasherDigest::new();
				hasher.update(i.to_le_bytes());
				hasher.update(rng.random::<[u8; 32]>());
				hasher.finalize()
			})
			.collect();

		let sequential_results: Vec<_> = (0..NUM_PAIRS)
			.map(|i| {
				parallel
					.compression
					.compress([inputs[i * 2], inputs[i * 2 + 1]])
			})
			.collect();

		let mut parallel_outputs = vec![MaybeUninit::uninit(); NUM_PAIRS];
		parallel.parallel_compress(&inputs, &mut parallel_outputs);
		// SAFETY: parallel_compress initialized all NUM_PAIRS output slots.
		let parallel_results: Vec<_> = parallel_outputs
			.into_iter()
			.map(|out| unsafe { out.assume_init() })
			.collect();

		assert_eq!(sequential_results, parallel_results);
	}
}
239}