binius_m3/gadgets/barrel_shifter.rs

1// Copyright 2025 Irreducible Inc.
2
3use std::cell::RefMut;
4
5use binius_core::oracle::ShiftVariant;
6use binius_field::{Field, PackedExtension, PackedFieldIndexable, packed::set_packed_slice};
7
8use crate::builder::{B1, B32, B128, Col, Expr, TableBuilder, TableWitnessSegment, upcast_col};
9
/// Number of low-order bits of the shift amount that the gadget uses.
///
/// The effective shift amount therefore satisfies
/// `0 <= shift_amount <= (1 << MAX_SHIFT_BITS) - 1 = 31`, where
/// `dst_val = src_val << shift_amount`, `dst_val = src_val >> shift_amount`, or
/// `dst_val = src_val.rotate_left(shift_amount)` depending on the shift variant.
const MAX_SHIFT_BITS: usize = 5;
14
15/// A gadget for performing a barrel shift circuit (<https://en.wikipedia.org/wiki/Barrel_shifter>).
16///
17/// The `BarrelShifter` gadget allows for left shifts, right shifts, and
18/// rotations on 32-bit inputs, with a configurable shift amount and direction.
19pub struct BarrelShifter {
20	/// The input column representing the 32-bit value to be shifted.
21	input: Col<B1, 32>,
22
23	/// The shift amount column representing the 5 bits of positions to shift,
24	/// ignoring the remaining 11.
25	shift_amount: Col<B1, 16>,
26
27	/// Virtual columns containing the binary decomposition of the shifted amount.
28	shift_amount_bits: [Col<B1>; MAX_SHIFT_BITS],
29
30	// TODO: Try to replace the Vec with an array.
31	/// Partial shift virtual columns containing the partial_shift[i - 1]
32	/// shifted by 2^i.
33	shifted: Vec<Col<B1, 32>>, // Virtual
34
35	/// Partial shift virtual columns containing either shifted[i] or partial_shit[i-1],
36	/// depending on the value of `shift_amount_bits`.
37	partial_shift: [Col<B1, 32>; MAX_SHIFT_BITS],
38
39	/// The output column representing the result of the shift operation. This column is
40	/// virtual or committed, depending on the flags
41	pub output: Col<B1, 32>,
42
43	/// The variant of the shift operation: logical left, logical right or
44	/// circular left.
45	pub variant: ShiftVariant,
46}
47
48impl BarrelShifter {
49	/// Creates a new instance of the `BarrelShifter` gadget.
50	///
51	/// # Arguments
52	///
53	/// * `table` - A mutable reference to the `TableBuilder` used to define the gadget.
54	/// * `input` - The input column of type `Col<B1, 32>`.
55	/// * `shift_amount` - The shift amount column of type `Col<B1, 16>`. The 11 most significant
56	///   bits are ignored.
57	/// * `variant` - Indicates whether the circuits performs a logical left, logical right, or
58	///   circular left shift.
59	///
60	/// # Returns
61	///
62	/// A new instance of the `BarrelShifter` gadget.
63	pub fn new(
64		table: &mut TableBuilder,
65		input: Col<B1, 32>,
66		shift_amount: Col<B1, 16>,
67		variant: ShiftVariant,
68	) -> Self {
69		let partial_shift =
70			core::array::from_fn(|i| table.add_committed(format!("partial_shift_{i}")));
71		let shift_amount_bits: [_; MAX_SHIFT_BITS] = core::array::from_fn(|i| {
72			table.add_selected(format!("shift_amount_bits_{i}"), shift_amount, i)
73		});
74		let mut shifted = Vec::with_capacity(MAX_SHIFT_BITS);
75		let mut current_shift = input;
76		for i in 0..MAX_SHIFT_BITS {
77			shifted.push(table.add_shifted("shifted", current_shift, 5, 1 << i, variant));
78			let partial_shift_packed: Col<B32> =
79				table.add_packed(format!("partial_shift_packed_{i}"), partial_shift[i]);
80			let shifted_packed: Expr<B32, 1> = table
81				.add_packed(format!("shifted_packed_{i}"), shifted[i])
82				.into();
83			let current_shift_packed: Col<B32> =
84				table.add_packed(format!("current_shift_packed_{i}"), current_shift);
85			table.assert_zero(
86				format!("correct_partial_shift_{i}"),
87				partial_shift_packed
88					- (shifted_packed * upcast_col(shift_amount_bits[i])
89						+ current_shift_packed * (upcast_col(shift_amount_bits[i]) + B32::ONE)),
90			);
91			current_shift = partial_shift[i];
92		}
93
94		Self {
95			input,
96			shift_amount,
97			shift_amount_bits,
98			shifted,
99			partial_shift,
100			output: current_shift,
101			variant,
102		}
103	}
104
105	/// Populates the table with witness values for the barrel shifter.
106	///
107	/// # Arguments
108	///
109	/// * `index` - A mutable reference to the `TableWitness` used to populate the table.
110	///
111	/// # Returns
112	///
113	/// A `Result` indicating success or failure.
114	pub fn populate<P>(&self, index: &mut TableWitnessSegment<P>) -> Result<(), anyhow::Error>
115	where
116		P: PackedFieldIndexable<Scalar = B128> + PackedExtension<B1>,
117	{
118		let input: RefMut<'_, [u32]> = index.get_mut_as(self.input).unwrap();
119		let shift_amount: RefMut<'_, [u16]> = index.get_mut_as(self.shift_amount).unwrap();
120		let mut partial_shift: [_; MAX_SHIFT_BITS] =
121			array_util::try_from_fn(|i| index.get_mut_as(self.partial_shift[i]))?;
122		let mut shifted: [_; MAX_SHIFT_BITS] =
123			array_util::try_from_fn(|i| index.get_mut_as(self.shifted[i]))?;
124		let mut shift_amount_bits: [_; MAX_SHIFT_BITS] =
125			array_util::try_from_fn(|i| index.get_mut(self.shift_amount_bits[i]))?;
126
127		for i in 0..index.size() {
128			let mut current_shift = input[i];
129			for j in 0..MAX_SHIFT_BITS {
130				let bit = ((shift_amount[i] >> j) & 1) == 1;
131				set_packed_slice(&mut shift_amount_bits[j], i, B1::from(bit));
132				shifted[j][i] = match self.variant {
133					ShiftVariant::LogicalLeft => current_shift << (1 << j),
134					ShiftVariant::LogicalRight => current_shift >> (1 << j),
135					ShiftVariant::CircularLeft => {
136						(current_shift << (1 << j)) | (current_shift >> (32 - (1 << j)))
137					}
138				};
139				if bit {
140					current_shift = shifted[j][i];
141				}
142				partial_shift[j][i] = current_shift;
143			}
144		}
145		Ok(())
146	}
147}
148
#[cfg(test)]
mod tests {
	use binius_compute::cpu::alloc::CpuComputeAllocator;
	use binius_field::{arch::OptimalUnderlier128b, as_packed_field::PackedType};
	use rand::{Rng, SeedableRng, rngs::StdRng};

	use super::*;
	use crate::builder::{ConstraintSystem, WitnessIndex};

	/// Number of rows populated in the test table.
	const N_ROWS: usize = 1 << 8;

	/// Builds a single-table constraint system around a `BarrelShifter`. It fills the table
	/// with random inputs and row-indexed shift amounts, then compares the gadget's output
	/// with a plain-Rust reference. Finally it validates the whole witness.
	fn test_barrel_shifter(variant: ShiftVariant) {
		let mut cs = ConstraintSystem::new();
		let mut table = cs.add_table("BarrelShifterTable");
		let table_id = table.id();

		let input_col = table.add_committed::<B1, 32>("input");
		let amount_col = table.add_committed::<B1, 16>("shift_amount");
		let shifter = BarrelShifter::new(&mut table, input_col, amount_col, variant);

		let mut allocator = CpuComputeAllocator::new(1 << 12);
		let allocator = allocator.into_bump_allocator();

		let mut witness =
			WitnessIndex::<PackedType<OptimalUnderlier128b, B128>>::new(&cs, &allocator);
		let table_witness = witness.init_table(table_id, N_ROWS).unwrap();
		let mut segment = table_witness.full_segment();

		let mut rng = StdRng::seed_from_u64(0x1234);
		let values: Vec<u32> = (0..N_ROWS).map(|_| rng.random()).collect();

		// Scope the column borrows so they are released before `populate` runs.
		{
			let mut inputs: RefMut<'_, [u32]> = segment.get_mut_as(input_col).unwrap();
			let mut amounts: RefMut<'_, [u16]> = segment.get_mut_as(amount_col).unwrap();
			for (row, ((slot_in, slot_amount), &value)) in inputs
				.iter_mut()
				.zip(amounts.iter_mut())
				.zip(&values)
				.enumerate()
			{
				*slot_in = value;
				// The gadget only looks at the lowest five bits of this value.
				*slot_amount = row as u16;
			}
		}

		shifter.populate(&mut segment).unwrap();

		{
			let outputs = segment.get_as::<u32, B1, 32>(shifter.output).unwrap();
			for (row, (&actual, &value)) in outputs.iter().zip(&values).enumerate() {
				let amount = (row % 32) as u32;
				let expected = match variant {
					ShiftVariant::LogicalLeft => value << amount,
					ShiftVariant::LogicalRight => value >> amount,
					ShiftVariant::CircularLeft => value.rotate_left(amount),
				};
				assert_eq!(actual, expected);
			}
		}

		let ccs = cs.compile().unwrap();
		let table_sizes = witness.table_sizes();
		let witness = witness.into_multilinear_extension_index();

		binius_core::constraint_system::validate::validate_witness(
			&ccs,
			&[],
			&table_sizes,
			&witness,
		)
		.unwrap();
	}

	#[test]
	fn test_barrel_shifter_logical_left() {
		test_barrel_shifter(ShiftVariant::LogicalLeft);
	}

	#[test]
	fn test_barrel_shifter_logical_right() {
		test_barrel_shifter(ShiftVariant::LogicalRight);
	}

	#[test]
	fn test_barrel_shifter_circular_left() {
		test_barrel_shifter(ShiftVariant::CircularLeft);
	}
}