// binius_hash/parallel_compression.rs

use std::{array, fmt::Debug, mem::MaybeUninit};

use binius_utils::rayon::prelude::*;

use crate::PseudoCompressionFunction;

/// A batched counterpart to [`PseudoCompressionFunction`] that compresses many
/// `N`-element chunks in a single call, allowing implementations to parallelize the work.
pub trait ParallelPseudoCompression<T, const N: usize> {
    /// The underlying single-chunk compression function.
    type Compression: PseudoCompressionFunction<T, N>;

    /// Returns a reference to the underlying compression function.
    fn compression(&self) -> &Self::Compression;

    /// Compresses each consecutive chunk of `N` elements from `inputs`, writing one
    /// result per chunk into `out`.
    ///
    /// `inputs.len()` must equal `N * out.len()`; every element of `out` is
    /// initialized when this returns.
    fn parallel_compress(&self, inputs: &[T], out: &mut [MaybeUninit<T>]);
}
/// Adapts a plain [`PseudoCompressionFunction`] into a [`ParallelPseudoCompression`]
/// by compressing the input chunks in parallel with rayon.
#[derive(Debug, Clone)]
pub struct ParallelCompressionAdaptor<C> {
    compression: C,
}

impl<C> ParallelCompressionAdaptor<C> {
    /// Wraps the given compression function.
    pub fn new(compression: C) -> Self {
        Self { compression }
    }

    /// Returns a reference to the wrapped compression function.
    pub fn compression(&self) -> &C {
        &self.compression
    }
}

impl<T, C, const ARITY: usize> ParallelPseudoCompression<T, ARITY> for ParallelCompressionAdaptor<C>
where
    T: Clone + Send + Sync,
    C: PseudoCompressionFunction<T, ARITY> + Sync,
{
    type Compression = C;

    fn compression(&self) -> &Self::Compression {
        &self.compression
    }

    fn parallel_compress(&self, inputs: &[T], out: &mut [MaybeUninit<T>]) {
        assert_eq!(inputs.len(), ARITY * out.len(), "Input length must be N * output length");

        inputs
            .par_chunks_exact(ARITY)
            .zip(out.par_iter_mut())
            .for_each(|(chunk, output)| {
                // Copy the chunk slice into a fixed-size array for the compression call.
                let chunk_array: [T; ARITY] = array::from_fn(|j| chunk[j].clone());
                let compressed = self.compression.compress(chunk_array);
                output.write(compressed);
            });
    }
}
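// Usage sketch (illustrative only; `hash` stands in for any value implementing
// `PseudoCompressionFunction<u64, 3>` and is not defined in this crate): the caller
// supplies an uninitialized output buffer with one slot per chunk of three inputs,
// and reads the results only after `parallel_compress` returns.
//
//     let adaptor = ParallelCompressionAdaptor::new(hash);
//     let inputs: Vec<u64> = (0u64..12).collect();
//     let mut out = [MaybeUninit::<u64>::uninit(); 4];
//     adaptor.parallel_compress(&inputs, &mut out);
//     let digests: Vec<u64> = out.into_iter().map(|d| unsafe { d.assume_init() }).collect();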
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;

    #[derive(Clone, Debug)]
    struct XorCompression;

    impl PseudoCompressionFunction<u64, 3> for XorCompression {
        fn compress(&self, input: [u64; 3]) -> u64 {
            input[0] ^ input[1] ^ input[2]
        }
    }

    #[test]
    fn test_parallel_compression_adaptor() {
        let mut rng = StdRng::seed_from_u64(0);
        let compression = XorCompression;
        let adaptor = ParallelCompressionAdaptor::new(compression.clone());

        const N: usize = 3;
        const NUM_CHUNKS: usize = 4;
        let inputs: Vec<u64> = (0..N * NUM_CHUNKS).map(|_| rng.random()).collect();

        let mut adaptor_output = [MaybeUninit::<u64>::uninit(); NUM_CHUNKS];
        adaptor.parallel_compress(&inputs, &mut adaptor_output);
        let adaptor_results: Vec<u64> = adaptor_output
            .into_iter()
            .map(|x| unsafe { x.assume_init() })
            .collect();

        let mut manual_results = Vec::new();
        for chunk_idx in 0..NUM_CHUNKS {
            let start = chunk_idx * N;
            let chunk = [inputs[start], inputs[start + 1], inputs[start + 2]];
            manual_results.push(compression.compress(chunk));
        }

        assert_eq!(adaptor_results, manual_results);
    }

    #[test]
    #[should_panic(expected = "Input length must be N * output length")]
    fn test_mismatched_input_length() {
        let compression = XorCompression;
        let adaptor = ParallelCompressionAdaptor::new(compression);

        let inputs = vec![1u64, 2, 3, 4];
        let mut output = [MaybeUninit::<u64>::uninit(); 2];
        adaptor.parallel_compress(&inputs, &mut output);
    }
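    // Sketch of an edge-case check: with zero chunks the length invariant
    // (0 == ARITY * 0) holds, so compressing an empty input into an empty output
    // is a valid no-op.
    #[test]
    fn test_empty_input() {
        let adaptor = ParallelCompressionAdaptor::new(XorCompression);
        let inputs: Vec<u64> = Vec::new();
        let mut output: [MaybeUninit<u64>; 0] = [];
        adaptor.parallel_compress(&inputs, &mut output);
    }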
}