Skip to main content

binius_frontend/compiler/
gate_graph.rs

1// Copyright 2025 Irreducible Inc.
2use std::collections::{HashMap, HashSet};
3
4use binius_core::word::Word;
5use cranelift_entity::{PrimaryMap, SecondaryMap, entity_impl};
6
7use crate::compiler::{
8	gate::opcode::{Opcode, OpcodeShape},
9	hints::{HintId, HintRegistry},
10	pathspec::{PathSpec, PathSpecTree},
11};
12
13#[derive(Default)]
14pub struct ConstPool {
15	pub pool: HashMap<Word, Wire>,
16}
17
18impl ConstPool {
19	pub fn new() -> Self {
20		ConstPool::default()
21	}
22
23	pub fn get(&self, value: Word) -> Option<Wire> {
24		self.pool.get(&value).cloned()
25	}
26
27	pub fn insert(&mut self, word: Word, wire: Wire) {
28		let prev = self.pool.insert(word, wire);
29		assert!(prev.is_none());
30	}
31}
32
33/// A wire through which a value flows in and out of gates.
34///
35/// The difference from `ValueIndex` is that a wire is abstract. Some wires could be moved during
36/// compilation and some wires might be pruned altogether.
37#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
38pub struct Wire(u32);
39entity_impl!(Wire);
40
41#[derive(Copy, Clone, Debug)]
42pub enum WireKind {
43	Constant(Word),
44	Inout,
45	Witness,
46	/// An internal wire is a wire created inside a gate.
47	Internal,
48	/// A scratch wire is a temporary wire used only during evaluation.
49	Scratch,
50}
51impl WireKind {
52	/// Returns `true` if this is a constant wire.
53	pub fn is_const(&self) -> bool {
54		matches!(self, WireKind::Constant(_))
55	}
56}
57
58#[derive(Copy, Clone)]
59pub struct WireData {
60	pub kind: WireKind,
61}
62
63/// Gate ID - identifies a gate in the graph
64#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
65pub struct Gate(u32);
66
67entity_impl!(Gate);
68
69/// A handy struct that allows a more type safe destructure.
70pub struct GateParam<'a> {
71	pub constants: &'a [Wire],
72	pub inputs: &'a [Wire],
73	pub outputs: &'a [Wire],
74	pub aux: &'a [Wire],
75	pub scratch: &'a [Wire],
76	pub imm: &'a [u32],
77}
78
79/// Describes a particular gate in the gate graph, it's type, input and output wires and
80/// immediate parameters.
81pub struct GateData {
82	/// The code of operation of this gate.
83	pub opcode: Opcode,
84
85	/// The input and output wires of this gate.
86	///
87	/// They are laid out in the following order:
88	///
89	/// - Constants
90	/// - Inputs
91	/// - Outputs
92	/// - Aux
93	/// - Scratch
94	///
95	/// The number of input and output wires is specified by the opcode's shape.
96	pub wires: Vec<Wire>,
97
98	/// The immediate parameters of this gate.
99	///
100	/// The immediates contain compile-time parameters of the circuits, such as shift amounts,
101	/// byte indices, etc.
102	///
103	/// The length of the immediates is specified by the opcode's shape.
104	pub immediates: Vec<u32>,
105
106	/// The dimensions of this gate.
107	///
108	/// This is empty for gates of constant shape. When the shape is variable, the number of
109	/// input, output and internal wires is a function of non-empty `dimensions`. This function is
110	/// typically linear.
111	pub dimensions: Vec<usize>,
112}
113
114impl GateData {
115	/// Slice this gate's wire vector into its semantic portions.
116	///
117	/// Panics for [`Opcode::Hint`] gates — those carry their shape in the [`HintRegistry`]
118	/// and must use [`gate_param_with_registry`](Self::gate_param_with_registry) instead.
119	/// The ~25 per-gate-module callers never see `Opcode::Hint`, so they use this method.
120	pub fn gate_param(&self) -> GateParam<'_> {
121		self.gate_param_for_shape(self.opcode.shape(&self.dimensions))
122	}
123
124	/// Like [`gate_param`](Self::gate_param) but works for [`Opcode::Hint`] gates by looking
125	/// up the shape in the provided registry.
126	pub fn gate_param_with_registry(&self, registry: &HintRegistry) -> GateParam<'_> {
127		self.gate_param_for_shape(self.shape(registry))
128	}
129
130	fn gate_param_for_shape(&self, shape: OpcodeShape) -> GateParam<'_> {
131		let start_const = 0;
132		let end_const = shape.const_in.len();
133		let start_input = end_const;
134		let end_input = start_input + shape.n_in;
135		let start_output = end_input;
136		let end_output = start_output + shape.n_out;
137		let start_aux = end_output;
138		let end_aux = start_aux + shape.n_aux;
139		let start_scratch = end_aux;
140		let end_scratch = start_scratch + shape.n_scratch;
141		GateParam {
142			constants: &self.wires[start_const..end_const],
143			inputs: &self.wires[start_input..end_input],
144			outputs: &self.wires[start_output..end_output],
145			aux: &self.wires[start_aux..end_aux],
146			scratch: &self.wires[start_scratch..end_scratch],
147			imm: &self.immediates,
148		}
149	}
150
151	/// The gate shape (takes dimensions into account).
152	///
153	/// For [`Opcode::Hint`] the shape is looked up via `registry`; the hint id lives in
154	/// `immediates[0]` and the user dimensions are `&self.dimensions`.
155	pub fn shape(&self, registry: &HintRegistry) -> OpcodeShape {
156		match self.opcode {
157			Opcode::Hint => {
158				let hint_id = self.immediates[0];
159				let (n_in, n_out) = registry.shape(hint_id, &self.dimensions);
160				OpcodeShape {
161					const_in: &[],
162					n_in,
163					n_out,
164					n_aux: 0,
165					n_scratch: 0,
166					n_imm: 1,
167				}
168			}
169			_ => self.opcode.shape(&self.dimensions),
170		}
171	}
172
173	/// Ensures the gate has the right shape.
174	pub fn validate_shape(&self, registry: &HintRegistry) {
175		let shape = self.shape(registry);
176		let expected_wires =
177			shape.const_in.len() + shape.n_in + shape.n_out + shape.n_aux + shape.n_scratch;
178		assert_eq!(self.wires.len(), expected_wires);
179		assert_eq!(self.immediates.len(), shape.n_imm);
180	}
181}
182
183/// Gate graph replaces the current Shared struct
184pub struct GateGraph {
185	// Primary maps
186	pub gates: PrimaryMap<Gate, GateData>,
187	pub wires: PrimaryMap<Wire, WireData>,
188
189	pub path_spec_tree: PathSpecTree,
190	pub gate_origin: SecondaryMap<Gate, PathSpec>,
191	pub assertion_names: SecondaryMap<Gate, PathSpec>,
192
193	pub const_pool: ConstPool,
194	pub n_witness: usize,
195	pub n_inout: usize,
196
197	// Use-def analysis
198	/// Maps each wire to the gate that defines it (if any)
199	pub wire_def: SecondaryMap<Wire, Option<Gate>>,
200	/// Maps each wire to the set of gates that use it
201	wire_uses: SecondaryMap<Wire, HashSet<Gate>>,
202}
203
204impl GateGraph {
205	pub fn new() -> Self {
206		let path_spec_tree = PathSpecTree::new();
207		let root = path_spec_tree.root();
208		Self {
209			gates: PrimaryMap::new(),
210			wires: PrimaryMap::new(),
211			path_spec_tree,
212			gate_origin: SecondaryMap::with_default(root),
213			assertion_names: SecondaryMap::with_default(root),
214			const_pool: ConstPool::new(),
215			n_witness: 0,
216			n_inout: 0,
217			wire_def: SecondaryMap::new(),
218			wire_uses: SecondaryMap::new(),
219		}
220	}
221
222	/// Runs a validation pass ensuring all the invariants hold.
223	pub fn validate(&self, hint_registry: &HintRegistry) {
224		// Every gate holds shape.
225		for gate in self.gates.values() {
226			gate.validate_shape(hint_registry);
227		}
228	}
229
230	pub fn add_inout(&mut self) -> Wire {
231		self.n_inout += 1;
232		self.wires.push(WireData {
233			kind: WireKind::Inout,
234		})
235	}
236
237	pub fn add_witness(&mut self) -> Wire {
238		self.n_witness += 1;
239		self.wires.push(WireData {
240			kind: WireKind::Witness,
241		})
242	}
243
244	pub fn add_internal(&mut self) -> Wire {
245		// Internal wires are treated as witnesses for allocation purposes
246		self.n_witness += 1;
247		self.wires.push(WireData {
248			kind: WireKind::Internal,
249		})
250	}
251
252	pub fn add_scratch(&mut self) -> Wire {
253		// Scratch wires are temporary storage, not part of witness
254		self.wires.push(WireData {
255			kind: WireKind::Scratch,
256		})
257	}
258
259	pub fn add_constant(&mut self, word: Word) -> Wire {
260		if let Some(wire) = self.const_pool.get(word) {
261			return wire;
262		}
263		let wire = self.wires.push(WireData {
264			kind: WireKind::Constant(word),
265		});
266		self.const_pool.insert(word, wire);
267		wire
268	}
269
270	/// Emits a gate with the given opcode, inputs and outputs.
271	pub fn emit_gate(
272		&mut self,
273		gate_origin: PathSpec,
274		opcode: Opcode,
275		inputs: impl IntoIterator<Item = Wire>,
276		outputs: impl IntoIterator<Item = Wire>,
277	) -> Gate {
278		self.emit_gate_generic(gate_origin, opcode, inputs, outputs, &[], &[])
279	}
280
281	/// Emits a gate with the given opcode, inputs, outputs and a single immediate argument.
282	pub fn emit_gate_imm(
283		&mut self,
284		gate_origin: PathSpec,
285		opcode: Opcode,
286		inputs: impl IntoIterator<Item = Wire>,
287		outputs: impl IntoIterator<Item = Wire>,
288		imm32: u32,
289	) -> Gate {
290		self.emit_gate_generic(gate_origin, opcode, inputs, outputs, &[], &[imm32])
291	}
292
293	/// Creates a gate inline with the given opcode's shape parametrized with the inputs, outputs
294	/// and immediates.
295	///
296	/// Panics if the resulting opcode shape is not valid.
297	pub fn emit_gate_generic(
298		&mut self,
299		gate_origin: PathSpec,
300		opcode: Opcode,
301		inputs: impl IntoIterator<Item = Wire>,
302		outputs: impl IntoIterator<Item = Wire>,
303		dimensions: &[usize],
304		immediates: &[u32],
305	) -> Gate {
306		// Hint gates go through `emit_hint_gate`, which knows the hint's shape from the
307		// `Hint` impl directly without needing the registry here.
308		assert!(
309			opcode != Opcode::Hint,
310			"emit_gate_generic does not handle Opcode::Hint; use emit_hint_gate"
311		);
312		let shape = opcode.shape(dimensions);
313		let mut wires: Vec<Wire> = Vec::with_capacity(
314			shape.const_in.len() + shape.n_in + shape.n_out + shape.n_aux + shape.n_scratch,
315		);
316		for c in shape.const_in {
317			wires.push(self.add_constant(*c));
318		}
319		wires.extend(inputs);
320		wires.extend(outputs);
321		for _ in 0..shape.n_aux {
322			// We create internal wires as auxiliary.
323			wires.push(self.add_internal());
324		}
325		for _ in 0..shape.n_scratch {
326			wires.push(self.add_scratch());
327		}
328		let data = GateData {
329			opcode,
330			wires,
331			dimensions: dimensions.to_vec(),
332			immediates: immediates.to_vec(),
333		};
334		// Inline validate_shape: non-hint shape doesn't need a registry.
335		let expected_wires =
336			shape.const_in.len() + shape.n_in + shape.n_out + shape.n_aux + shape.n_scratch;
337		assert_eq!(data.wires.len(), expected_wires);
338		assert_eq!(data.immediates.len(), shape.n_imm);
339
340		let gate = self.gates.push(data);
341
342		self.gate_origin[gate] = gate_origin;
343
344		gate
345	}
346
347	/// Emit a generic [`Opcode::Hint`] gate. Caller has already validated input arity
348	/// against the hint's [`Hint::shape`](crate::compiler::hints::Hint::shape) and allocated
349	/// `n_out` output wires.
350	pub fn emit_hint_gate(
351		&mut self,
352		gate_origin: PathSpec,
353		hint_id: HintId,
354		dimensions: &[usize],
355		inputs: impl IntoIterator<Item = Wire>,
356		outputs: impl IntoIterator<Item = Wire>,
357	) -> Gate {
358		let mut wires: Vec<Wire> = Vec::new();
359		wires.extend(inputs);
360		wires.extend(outputs);
361		let data = GateData {
362			opcode: Opcode::Hint,
363			wires,
364			dimensions: dimensions.to_vec(),
365			immediates: vec![hint_id],
366		};
367		let gate = self.gates.push(data);
368		self.gate_origin[gate] = gate_origin;
369		gate
370	}
371
372	/// Updates use-def information for a newly added gate
373	fn update_use_def_for_gate(&mut self, gate: Gate, hint_registry: &HintRegistry) {
374		let gate_data = &self.gates[gate];
375		let gate_param = gate_data.gate_param_with_registry(hint_registry);
376
377		// Record this gate as defining its outputs
378		for &output_wire in gate_param.outputs {
379			self.wire_def[output_wire] = Some(gate);
380		}
381
382		// Record this gate as defining its internal wires
383		for &aux_wire in gate_param.aux {
384			self.wire_def[aux_wire] = Some(gate);
385		}
386
387		// Record this gate as using its inputs
388		for &input_wire in gate_param.inputs {
389			self.wire_uses[input_wire].insert(gate);
390		}
391
392		// Record this gate as using its constants
393		for &const_wire in gate_param.constants {
394			self.wire_uses[const_wire].insert(gate);
395		}
396	}
397
398	/// Rebuilds the use-def chains from scratch by analyzing all gates.
399	///
400	/// `hint_registry` must contain all [`Opcode::Hint`] gates' hints.
401	pub fn rebuild_use_def_chains(&mut self, hint_registry: &HintRegistry) {
402		// Clear existing use-def information
403		self.wire_def.clear();
404		self.wire_uses.clear();
405
406		// Rebuild from all gates
407		for gate in self.gates.keys() {
408			self.update_use_def_for_gate(gate, hint_registry);
409		}
410	}
411
412	/// Returns all gates that use the given wire
413	pub fn get_wire_uses(&self, wire: Wire) -> &HashSet<Gate> {
414		&self.wire_uses[wire]
415	}
416
417	/// Returns an iterator over all constant wires and their data
418	pub fn iter_const_wires(&self) -> impl Iterator<Item = (Wire, &WireData)> {
419		self.wires
420			.iter()
421			.filter(|(wire, _)| self.wire_data(*wire).kind.is_const())
422	}
423
424	/// Gets wire data by reference
425	pub fn wire_data(&self, wire: Wire) -> &WireData {
426		&self.wires[wire]
427	}
428
429	/// Gets gate data by reference
430	pub fn gate_data(&self, gate: Gate) -> &GateData {
431		&self.gates[gate]
432	}
433
434	/// Replaces all occurrences of a wire in a gate with another wire
435	pub fn replace_gate_wire(&mut self, gate: Gate, old_wire: Wire, new_wire: Wire) {
436		let gate_data = &mut self.gates[gate];
437		for wire in &mut gate_data.wires {
438			if *wire == old_wire {
439				*wire = new_wire;
440			}
441		}
442	}
443
444	/// Updates use-def chains when replacing a wire use
445	pub fn update_wire_use(&mut self, old_wire: Wire, new_wire: Wire, gate: Gate) {
446		self.wire_uses[old_wire].remove(&gate);
447		self.wire_uses[new_wire].insert(gate);
448	}
449
450	/// Replaces all uses of old_wire with a constant wire containing the given value.
451	///
452	/// Returns the constant wire that was used, the number of individual wire replacements,
453	/// and the list of gates that were actually affected by this replacement.
454	/// This encapsulates both wire replacement and use-def chain updates.
455	pub fn replace_wire_with_constant(
456		&mut self,
457		old_wire: Wire,
458		value: Word,
459		hint_registry: &HintRegistry,
460	) -> (Wire, usize, Vec<Gate>) {
461		let const_wire = self.add_constant(value);
462
463		if const_wire == old_wire {
464			return (const_wire, 0, Vec::new());
465		}
466
467		// Get all users of the old wire (clone to avoid borrow conflicts)
468		let users: Vec<Gate> = self.get_wire_uses(old_wire).iter().copied().collect();
469		let mut total_replacements = 0;
470
471		// Replace wire references in all user gates
472		for user_gate in &users {
473			// Count how many times this wire appears in this gate before replacing
474			let gate_data = self.gate_data(*user_gate);
475			let gate_param = gate_data.gate_param_with_registry(hint_registry);
476			let replacements_in_gate = gate_param.inputs.iter().filter(|&&w| w == old_wire).count()
477				+ gate_param
478					.outputs
479					.iter()
480					.filter(|&&w| w == old_wire)
481					.count();
482			total_replacements += replacements_in_gate;
483
484			self.replace_gate_wire(*user_gate, old_wire, const_wire);
485			self.update_wire_use(old_wire, const_wire, *user_gate);
486		}
487
488		(const_wire, total_replacements, users)
489	}
490}
491
492impl Default for GateGraph {
493	fn default() -> Self {
494		Self::new()
495	}
496}
497
#[cfg(test)]
mod tests {
	use super::*;
	use crate::compiler::gate::opcode::Opcode;

	// Small query helpers over the graph's use-def information.

	/// The gate recorded as defining `wire`, if any.
	fn get_wire_def(graph: &GateGraph, wire: Wire) -> Option<Gate> {
		graph.wire_def[wire]
	}

	/// How many distinct gates are recorded as using `wire`.
	fn wire_use_count(graph: &GateGraph, wire: Wire) -> usize {
		graph.get_wire_uses(wire).len()
	}

	/// Whether exactly one gate uses `wire`.
	fn is_wire_single_use(graph: &GateGraph, wire: Wire) -> bool {
		wire_use_count(graph, wire) == 1
	}

	/// The sole user of `wire`, or `None` when it has zero or several users.
	fn get_wire_single_use(graph: &GateGraph, wire: Wire) -> Option<Gate> {
		let users = graph.get_wire_uses(wire);
		if users.len() != 1 {
			return None;
		}
		users.iter().next().copied()
	}

	/// A gate's constant operands followed by its regular inputs.
	fn get_gate_inputs(graph: &GateGraph, gate: Gate) -> Vec<Wire> {
		let param = graph.gate_data(gate).gate_param();
		param.constants.iter().chain(param.inputs).copied().collect()
	}

	/// A gate's output wires.
	fn get_gate_outputs(graph: &GateGraph, gate: Gate) -> Vec<Wire> {
		graph.gate_data(gate).gate_param().outputs.to_vec()
	}

	#[test]
	fn test_use_def_analysis() {
		let mut graph = GateGraph::new();
		let root = graph.path_spec_tree.root();

		let in1 = graph.add_inout();
		let in2 = graph.add_inout();
		let out1 = graph.add_witness();
		let out2 = graph.add_witness();

		// gate1 reads in1 and in2 and writes out1.
		let gate1 = graph.emit_gate(root, Opcode::Bxor, vec![in1, in2], vec![out1]);
		// gate2 reads out1 and in1 and writes out2.
		let gate2 = graph.emit_gate(root, Opcode::Band, vec![out1, in1], vec![out2]);

		graph.rebuild_use_def_chains(&HintRegistry::new());

		// Definitions.
		assert_eq!(get_wire_def(&graph, out1), Some(gate1));
		assert_eq!(get_wire_def(&graph, out2), Some(gate2));

		// Uses.
		assert!(graph.get_wire_uses(in1).contains(&gate1));
		assert!(graph.get_wire_uses(in2).contains(&gate1));
		assert!(graph.get_wire_uses(out1).contains(&gate2));

		// Use counts: in1 feeds both gates, out2 feeds none.
		assert_eq!(wire_use_count(&graph, in1), 2);
		assert_eq!(wire_use_count(&graph, in2), 1);
		assert_eq!(wire_use_count(&graph, out1), 1);
		assert_eq!(wire_use_count(&graph, out2), 0);

		// Single-use predicate.
		assert!(!is_wire_single_use(&graph, in1));
		assert!(is_wire_single_use(&graph, in2));
		assert!(is_wire_single_use(&graph, out1));
		assert!(!is_wire_single_use(&graph, out2));

		// Sole-user lookup.
		assert_eq!(get_wire_single_use(&graph, in1), None);
		assert_eq!(get_wire_single_use(&graph, out1), Some(gate2));
		assert_eq!(get_wire_single_use(&graph, out2), None);
	}

	#[test]
	fn test_constant_use_def() {
		let mut graph = GateGraph::new();
		let root = graph.path_spec_tree.root();

		let const_wire = graph.add_constant(Word(42u64));
		let in_wire = graph.add_inout();
		let out = graph.add_witness();

		let gate = graph.emit_gate(root, Opcode::Bxor, vec![const_wire, in_wire], vec![out]);

		graph.rebuild_use_def_chains(&HintRegistry::new());

		// No gate defines a constant, but its uses are still tracked.
		assert_eq!(get_wire_def(&graph, const_wire), None);
		assert!(graph.get_wire_uses(const_wire).contains(&gate));
		assert_eq!(wire_use_count(&graph, const_wire), 1);
	}

	#[test]
	fn test_rebuild_use_def_chains() {
		let mut graph = GateGraph::new();
		let root = graph.path_spec_tree.root();

		let in1 = graph.add_inout();
		let in2 = graph.add_inout();
		let out = graph.add_witness();

		graph.emit_gate(root, Opcode::Bxor, vec![in1, in2], vec![out]);

		// Simulate corrupted use-def info by wiping it by hand.
		graph.wire_def.clear();
		graph.wire_uses.clear();

		assert_eq!(get_wire_def(&graph, out), None);
		assert!(graph.get_wire_uses(in1).is_empty());

		graph.rebuild_use_def_chains(&HintRegistry::new());

		// Everything is back after the rebuild.
		assert!(get_wire_def(&graph, out).is_some());
		assert!(!graph.get_wire_uses(in1).is_empty());
		assert!(!graph.get_wire_uses(in2).is_empty());
	}

	#[test]
	fn test_gate_inputs_outputs() {
		let mut graph = GateGraph::new();
		let root = graph.path_spec_tree.root();

		let in1 = graph.add_inout();
		let in2 = graph.add_inout();
		let out = graph.add_witness();

		// Only the gate's wire layout is inspected, so use-def chains are not rebuilt.
		let gate = graph.emit_gate(root, Opcode::Bxor, vec![in1, in2], vec![out]);

		// Bxor carries one constant operand (ALL_ONE) ahead of its two regular inputs.
		let inputs = get_gate_inputs(&graph, gate);
		assert_eq!(inputs.len(), 3);
		assert!(inputs.contains(&in1));
		assert!(inputs.contains(&in2));

		let WireKind::Constant(word) = graph.wire_data(inputs[0]).kind else {
			panic!("Expected constant wire")
		};
		assert_eq!(word, Word::ALL_ONE);

		let outputs = get_gate_outputs(&graph, gate);
		assert_eq!(outputs.len(), 1);
		assert!(outputs.contains(&out));
	}
}