binius_m3/gadgets/
barrel_shifter.rs1use std::cell::RefMut;
4
5use binius_core::oracle::ShiftVariant;
6use binius_field::{Field, PackedExtension, PackedFieldIndexable, packed::set_packed_slice};
7
8use crate::builder::{B1, B32, B128, Col, Expr, TableBuilder, TableWitnessSegment, upcast_col};
9
/// Number of barrel-shifter stages: stage `i` shifts by `2^i`, so `log2(32)` stages are enough
/// to express every shift amount `0..32` of a 32-bit word.
const MAX_SHIFT_BITS: usize = u32::BITS.ilog2() as usize;

15pub struct BarrelShifter {
20 input: Col<B1, 32>,
22
23 shift_amount: Col<B1, 16>,
26
27 shift_amount_bits: [Col<B1>; MAX_SHIFT_BITS],
29
30 shifted: Vec<Col<B1, 32>>, partial_shift: [Col<B1, 32>; MAX_SHIFT_BITS],
38
39 pub output: Col<B1, 32>,
42
43 pub variant: ShiftVariant,
46}
47
48impl BarrelShifter {
49 pub fn new(
64 table: &mut TableBuilder,
65 input: Col<B1, 32>,
66 shift_amount: Col<B1, 16>,
67 variant: ShiftVariant,
68 ) -> Self {
69 let partial_shift =
70 core::array::from_fn(|i| table.add_committed(format!("partial_shift_{i}")));
71 let shift_amount_bits: [_; MAX_SHIFT_BITS] = core::array::from_fn(|i| {
72 table.add_selected(format!("shift_amount_bits_{i}"), shift_amount, i)
73 });
74 let mut shifted = Vec::with_capacity(MAX_SHIFT_BITS);
75 let mut current_shift = input;
76 for i in 0..MAX_SHIFT_BITS {
77 shifted.push(table.add_shifted("shifted", current_shift, 5, 1 << i, variant));
78 let partial_shift_packed: Col<B32> =
79 table.add_packed(format!("partial_shift_packed_{i}"), partial_shift[i]);
80 let shifted_packed: Expr<B32, 1> = table
81 .add_packed(format!("shifted_packed_{i}"), shifted[i])
82 .into();
83 let current_shift_packed: Col<B32> =
84 table.add_packed(format!("current_shift_packed_{i}"), current_shift);
85 table.assert_zero(
86 format!("correct_partial_shift_{i}"),
87 partial_shift_packed
88 - (shifted_packed * upcast_col(shift_amount_bits[i])
89 + current_shift_packed * (upcast_col(shift_amount_bits[i]) + B32::ONE)),
90 );
91 current_shift = partial_shift[i];
92 }
93
94 Self {
95 input,
96 shift_amount,
97 shift_amount_bits,
98 shifted,
99 partial_shift,
100 output: current_shift,
101 variant,
102 }
103 }
104
105 pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
115 where
116 P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
117 {
118 let input: RefMut<'_, [u32]> = index.get_mut_as(self.input).unwrap();
119 let shift_amount: RefMut<'_, [u16]> = index.get_mut_as(self.shift_amount).unwrap();
120 let mut partial_shift: [_; MAX_SHIFT_BITS] =
121 array_util::try_from_fn(|i| index.get_mut_as(self.partial_shift[i]))?;
122 let mut shifted: [_; MAX_SHIFT_BITS] =
123 array_util::try_from_fn(|i| index.get_mut_as(self.shifted[i]))?;
124 let mut shift_amount_bits: [_; MAX_SHIFT_BITS] =
125 array_util::try_from_fn(|i| index.get_mut(self.shift_amount_bits[i]))?;
126
127 for i in 0..index.size() {
128 let mut current_shift = input[i];
129 for j in 0..MAX_SHIFT_BITS {
130 let bit = ((shift_amount[i] >> j) & 1) == 1;
131 set_packed_slice(&mut shift_amount_bits[j], i, B1::from(bit));
132 shifted[j][i] = match self.variant {
133 ShiftVariant::LogicalLeft => current_shift << (1 << j),
134 ShiftVariant::LogicalRight => current_shift >> (1 << j),
135 ShiftVariant::CircularLeft => {
136 (current_shift << (1 << j)) | (current_shift >> (32 - (1 << j)))
137 }
138 };
139 if bit {
140 current_shift = shifted[j][i];
141 }
142 partial_shift[j][i] = current_shift;
143 }
144 }
145 Ok(())
146 }
147}
148
#[cfg(test)]
mod tests {
    use std::iter::repeat_with;

    use binius_compute::cpu::alloc::CpuComputeAllocator;
    use binius_field::{arch::OptimalUnderlier128b, as_packed_field::PackedType};
    use rand::{Rng, SeedableRng, rngs::StdRng};

    use super::*;
    use crate::builder::{ConstraintSystem, WitnessIndex};

    /// Log2 of the number of rows in the test table.
    const LOG_N_ROWS: usize = 8;

    /// Plain-integer model of the gadget, used as the expected output.
    fn expected_shift(variant: ShiftVariant, value: u32, amount: usize) -> u32 {
        match variant {
            ShiftVariant::LogicalLeft => value << amount,
            ShiftVariant::LogicalRight => value >> amount,
            ShiftVariant::CircularLeft => value.rotate_left(amount as u32),
        }
    }

    /// Builds a table containing a barrel shifter for `variant` and fills it with random inputs.
    /// Row `r` uses shift amount `r`, which exercises both the low 5 bits and the ignored high
    /// bits. The outputs are checked against [`expected_shift`], and the witness is validated
    /// against the compiled constraint system.
    fn test_barrel_shifter(variant: ShiftVariant) {
        let n_rows = 1 << LOG_N_ROWS;

        let mut cs = ConstraintSystem::new();
        let mut table = cs.add_table("BarrelShifterTable");
        let table_id = table.id();
        let mut allocator = CpuComputeAllocator::new(1 << 12);
        let allocator = allocator.into_bump_allocator();

        let input = table.add_committed::<B1, 32>("input");
        let shift_amount = table.add_committed::<B1, 16>("shift_amount");
        let shifter = BarrelShifter::new(&mut table, input, shift_amount, variant);

        let mut witness =
            WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);
        let table_witness = witness.init_table(table_id, n_rows).unwrap();
        let mut segment = table_witness.full_segment();

        let mut rng = StdRng::seed_from_u64(0x1234);
        let test_inputs: Vec<u32> = repeat_with(|| rng.random()).take(n_rows).collect();

        // Scope the column borrows so they are released before `populate` runs.
        {
            let mut inputs = segment.get_mut_as(input).unwrap();
            let mut amounts = segment.get_mut_as(shift_amount).unwrap();
            for (row, (value, amount)) in inputs.iter_mut().zip(amounts.iter_mut()).enumerate() {
                *value = test_inputs[row];
                *amount = row as u16;
            }
        }

        shifter.populate(&mut segment).unwrap();

        {
            let outputs = segment.get_as::<u32, B1, 32>(shifter.output).unwrap();
            for (row, &actual) in outputs.iter().enumerate() {
                assert_eq!(actual, expected_shift(variant, test_inputs[row], row % 32));
            }
        }

        let ccs = cs.compile().unwrap();
        let table_sizes = witness.table_sizes();
        let witness = witness.into_multilinear_extension_index();

        binius_core::constraint_system::validate::validate_witness(
            &ccs,
            &[],
            &table_sizes,
            &witness,
        )
        .unwrap();
    }

    #[test]
    fn test_barrel_shifter_logical_left() {
        test_barrel_shifter(ShiftVariant::LogicalLeft);
    }

    #[test]
    fn test_barrel_shifter_logical_right() {
        test_barrel_shifter(ShiftVariant::LogicalRight);
    }

    #[test]
    fn test_barrel_shifter_circular_left() {
        test_barrel_shifter(ShiftVariant::CircularLeft);
    }
}