1use std::collections::{HashMap, HashSet};
3
4use binius_core::word::Word;
5use cranelift_entity::{PrimaryMap, SecondaryMap, entity_impl};
6
7use crate::compiler::{
8 gate::opcode::{Opcode, OpcodeShape},
9 hints::{HintId, HintRegistry},
10 pathspec::{PathSpec, PathSpecTree},
11};
12
13#[derive(Default)]
14pub struct ConstPool {
15 pub pool: HashMap<Word, Wire>,
16}
17
18impl ConstPool {
19 pub fn new() -> Self {
20 ConstPool::default()
21 }
22
23 pub fn get(&self, value: Word) -> Option<Wire> {
24 self.pool.get(&value).cloned()
25 }
26
27 pub fn insert(&mut self, word: Word, wire: Wire) {
28 let prev = self.pool.insert(word, wire);
29 assert!(prev.is_none());
30 }
31}
32
33#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
38pub struct Wire(u32);
39entity_impl!(Wire);
40
41#[derive(Copy, Clone, Debug)]
42pub enum WireKind {
43 Constant(Word),
44 Inout,
45 Witness,
46 Internal,
48 Scratch,
50}
51impl WireKind {
52 pub fn is_const(&self) -> bool {
54 matches!(self, WireKind::Constant(_))
55 }
56}
57
58#[derive(Copy, Clone)]
59pub struct WireData {
60 pub kind: WireKind,
61}
62
63#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
65pub struct Gate(u32);
66
67entity_impl!(Gate);
68
69pub struct GateParam<'a> {
71 pub constants: &'a [Wire],
72 pub inputs: &'a [Wire],
73 pub outputs: &'a [Wire],
74 pub aux: &'a [Wire],
75 pub scratch: &'a [Wire],
76 pub imm: &'a [u32],
77}
78
79pub struct GateData {
82 pub opcode: Opcode,
84
85 pub wires: Vec<Wire>,
97
98 pub immediates: Vec<u32>,
105
106 pub dimensions: Vec<usize>,
112}
113
114impl GateData {
115 pub fn gate_param(&self) -> GateParam<'_> {
121 self.gate_param_for_shape(self.opcode.shape(&self.dimensions))
122 }
123
124 pub fn gate_param_with_registry(&self, registry: &HintRegistry) -> GateParam<'_> {
127 self.gate_param_for_shape(self.shape(registry))
128 }
129
130 fn gate_param_for_shape(&self, shape: OpcodeShape) -> GateParam<'_> {
131 let start_const = 0;
132 let end_const = shape.const_in.len();
133 let start_input = end_const;
134 let end_input = start_input + shape.n_in;
135 let start_output = end_input;
136 let end_output = start_output + shape.n_out;
137 let start_aux = end_output;
138 let end_aux = start_aux + shape.n_aux;
139 let start_scratch = end_aux;
140 let end_scratch = start_scratch + shape.n_scratch;
141 GateParam {
142 constants: &self.wires[start_const..end_const],
143 inputs: &self.wires[start_input..end_input],
144 outputs: &self.wires[start_output..end_output],
145 aux: &self.wires[start_aux..end_aux],
146 scratch: &self.wires[start_scratch..end_scratch],
147 imm: &self.immediates,
148 }
149 }
150
151 pub fn shape(&self, registry: &HintRegistry) -> OpcodeShape {
156 match self.opcode {
157 Opcode::Hint => {
158 let hint_id = self.immediates[0];
159 let (n_in, n_out) = registry.shape(hint_id, &self.dimensions);
160 OpcodeShape {
161 const_in: &[],
162 n_in,
163 n_out,
164 n_aux: 0,
165 n_scratch: 0,
166 n_imm: 1,
167 }
168 }
169 _ => self.opcode.shape(&self.dimensions),
170 }
171 }
172
173 pub fn validate_shape(&self, registry: &HintRegistry) {
175 let shape = self.shape(registry);
176 let expected_wires =
177 shape.const_in.len() + shape.n_in + shape.n_out + shape.n_aux + shape.n_scratch;
178 assert_eq!(self.wires.len(), expected_wires);
179 assert_eq!(self.immediates.len(), shape.n_imm);
180 }
181}
182
183pub struct GateGraph {
185 pub gates: PrimaryMap<Gate, GateData>,
187 pub wires: PrimaryMap<Wire, WireData>,
188
189 pub path_spec_tree: PathSpecTree,
190 pub gate_origin: SecondaryMap<Gate, PathSpec>,
191 pub assertion_names: SecondaryMap<Gate, PathSpec>,
192
193 pub const_pool: ConstPool,
194 pub n_witness: usize,
195 pub n_inout: usize,
196
197 pub wire_def: SecondaryMap<Wire, Option<Gate>>,
200 wire_uses: SecondaryMap<Wire, HashSet<Gate>>,
202}
203
204impl GateGraph {
205 pub fn new() -> Self {
206 let path_spec_tree = PathSpecTree::new();
207 let root = path_spec_tree.root();
208 Self {
209 gates: PrimaryMap::new(),
210 wires: PrimaryMap::new(),
211 path_spec_tree,
212 gate_origin: SecondaryMap::with_default(root),
213 assertion_names: SecondaryMap::with_default(root),
214 const_pool: ConstPool::new(),
215 n_witness: 0,
216 n_inout: 0,
217 wire_def: SecondaryMap::new(),
218 wire_uses: SecondaryMap::new(),
219 }
220 }
221
222 pub fn validate(&self, hint_registry: &HintRegistry) {
224 for gate in self.gates.values() {
226 gate.validate_shape(hint_registry);
227 }
228 }
229
230 pub fn add_inout(&mut self) -> Wire {
231 self.n_inout += 1;
232 self.wires.push(WireData {
233 kind: WireKind::Inout,
234 })
235 }
236
237 pub fn add_witness(&mut self) -> Wire {
238 self.n_witness += 1;
239 self.wires.push(WireData {
240 kind: WireKind::Witness,
241 })
242 }
243
244 pub fn add_internal(&mut self) -> Wire {
245 self.n_witness += 1;
247 self.wires.push(WireData {
248 kind: WireKind::Internal,
249 })
250 }
251
252 pub fn add_scratch(&mut self) -> Wire {
253 self.wires.push(WireData {
255 kind: WireKind::Scratch,
256 })
257 }
258
259 pub fn add_constant(&mut self, word: Word) -> Wire {
260 if let Some(wire) = self.const_pool.get(word) {
261 return wire;
262 }
263 let wire = self.wires.push(WireData {
264 kind: WireKind::Constant(word),
265 });
266 self.const_pool.insert(word, wire);
267 wire
268 }
269
270 pub fn emit_gate(
272 &mut self,
273 gate_origin: PathSpec,
274 opcode: Opcode,
275 inputs: impl IntoIterator<Item = Wire>,
276 outputs: impl IntoIterator<Item = Wire>,
277 ) -> Gate {
278 self.emit_gate_generic(gate_origin, opcode, inputs, outputs, &[], &[])
279 }
280
281 pub fn emit_gate_imm(
283 &mut self,
284 gate_origin: PathSpec,
285 opcode: Opcode,
286 inputs: impl IntoIterator<Item = Wire>,
287 outputs: impl IntoIterator<Item = Wire>,
288 imm32: u32,
289 ) -> Gate {
290 self.emit_gate_generic(gate_origin, opcode, inputs, outputs, &[], &[imm32])
291 }
292
293 pub fn emit_gate_generic(
298 &mut self,
299 gate_origin: PathSpec,
300 opcode: Opcode,
301 inputs: impl IntoIterator<Item = Wire>,
302 outputs: impl IntoIterator<Item = Wire>,
303 dimensions: &[usize],
304 immediates: &[u32],
305 ) -> Gate {
306 assert!(
309 opcode != Opcode::Hint,
310 "emit_gate_generic does not handle Opcode::Hint; use emit_hint_gate"
311 );
312 let shape = opcode.shape(dimensions);
313 let mut wires: Vec<Wire> = Vec::with_capacity(
314 shape.const_in.len() + shape.n_in + shape.n_out + shape.n_aux + shape.n_scratch,
315 );
316 for c in shape.const_in {
317 wires.push(self.add_constant(*c));
318 }
319 wires.extend(inputs);
320 wires.extend(outputs);
321 for _ in 0..shape.n_aux {
322 wires.push(self.add_internal());
324 }
325 for _ in 0..shape.n_scratch {
326 wires.push(self.add_scratch());
327 }
328 let data = GateData {
329 opcode,
330 wires,
331 dimensions: dimensions.to_vec(),
332 immediates: immediates.to_vec(),
333 };
334 let expected_wires =
336 shape.const_in.len() + shape.n_in + shape.n_out + shape.n_aux + shape.n_scratch;
337 assert_eq!(data.wires.len(), expected_wires);
338 assert_eq!(data.immediates.len(), shape.n_imm);
339
340 let gate = self.gates.push(data);
341
342 self.gate_origin[gate] = gate_origin;
343
344 gate
345 }
346
347 pub fn emit_hint_gate(
351 &mut self,
352 gate_origin: PathSpec,
353 hint_id: HintId,
354 dimensions: &[usize],
355 inputs: impl IntoIterator<Item = Wire>,
356 outputs: impl IntoIterator<Item = Wire>,
357 ) -> Gate {
358 let mut wires: Vec<Wire> = Vec::new();
359 wires.extend(inputs);
360 wires.extend(outputs);
361 let data = GateData {
362 opcode: Opcode::Hint,
363 wires,
364 dimensions: dimensions.to_vec(),
365 immediates: vec![hint_id],
366 };
367 let gate = self.gates.push(data);
368 self.gate_origin[gate] = gate_origin;
369 gate
370 }
371
372 fn update_use_def_for_gate(&mut self, gate: Gate, hint_registry: &HintRegistry) {
374 let gate_data = &self.gates[gate];
375 let gate_param = gate_data.gate_param_with_registry(hint_registry);
376
377 for &output_wire in gate_param.outputs {
379 self.wire_def[output_wire] = Some(gate);
380 }
381
382 for &aux_wire in gate_param.aux {
384 self.wire_def[aux_wire] = Some(gate);
385 }
386
387 for &input_wire in gate_param.inputs {
389 self.wire_uses[input_wire].insert(gate);
390 }
391
392 for &const_wire in gate_param.constants {
394 self.wire_uses[const_wire].insert(gate);
395 }
396 }
397
398 pub fn rebuild_use_def_chains(&mut self, hint_registry: &HintRegistry) {
402 self.wire_def.clear();
404 self.wire_uses.clear();
405
406 for gate in self.gates.keys() {
408 self.update_use_def_for_gate(gate, hint_registry);
409 }
410 }
411
412 pub fn get_wire_uses(&self, wire: Wire) -> &HashSet<Gate> {
414 &self.wire_uses[wire]
415 }
416
417 pub fn iter_const_wires(&self) -> impl Iterator<Item = (Wire, &WireData)> {
419 self.wires
420 .iter()
421 .filter(|(wire, _)| self.wire_data(*wire).kind.is_const())
422 }
423
424 pub fn wire_data(&self, wire: Wire) -> &WireData {
426 &self.wires[wire]
427 }
428
429 pub fn gate_data(&self, gate: Gate) -> &GateData {
431 &self.gates[gate]
432 }
433
434 pub fn replace_gate_wire(&mut self, gate: Gate, old_wire: Wire, new_wire: Wire) {
436 let gate_data = &mut self.gates[gate];
437 for wire in &mut gate_data.wires {
438 if *wire == old_wire {
439 *wire = new_wire;
440 }
441 }
442 }
443
444 pub fn update_wire_use(&mut self, old_wire: Wire, new_wire: Wire, gate: Gate) {
446 self.wire_uses[old_wire].remove(&gate);
447 self.wire_uses[new_wire].insert(gate);
448 }
449
450 pub fn replace_wire_with_constant(
456 &mut self,
457 old_wire: Wire,
458 value: Word,
459 hint_registry: &HintRegistry,
460 ) -> (Wire, usize, Vec<Gate>) {
461 let const_wire = self.add_constant(value);
462
463 if const_wire == old_wire {
464 return (const_wire, 0, Vec::new());
465 }
466
467 let users: Vec<Gate> = self.get_wire_uses(old_wire).iter().copied().collect();
469 let mut total_replacements = 0;
470
471 for user_gate in &users {
473 let gate_data = self.gate_data(*user_gate);
475 let gate_param = gate_data.gate_param_with_registry(hint_registry);
476 let replacements_in_gate = gate_param.inputs.iter().filter(|&&w| w == old_wire).count()
477 + gate_param
478 .outputs
479 .iter()
480 .filter(|&&w| w == old_wire)
481 .count();
482 total_replacements += replacements_in_gate;
483
484 self.replace_gate_wire(*user_gate, old_wire, const_wire);
485 self.update_wire_use(old_wire, const_wire, *user_gate);
486 }
487
488 (const_wire, total_replacements, users)
489 }
490}
491
492impl Default for GateGraph {
493 fn default() -> Self {
494 Self::new()
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::gate::opcode::Opcode;

    /// Gate recorded as the definer of `wire`, if any.
    fn get_wire_def(graph: &GateGraph, wire: Wire) -> Option<Gate> {
        graph.wire_def[wire]
    }

    /// Number of distinct gates recorded as reading `wire`.
    fn wire_use_count(graph: &GateGraph, wire: Wire) -> usize {
        graph.wire_uses[wire].len()
    }

    /// True when exactly one gate reads `wire`.
    fn is_wire_single_use(graph: &GateGraph, wire: Wire) -> bool {
        wire_use_count(graph, wire) == 1
    }

    /// The sole reader of `wire`, or `None` if it has zero or several readers.
    fn get_wire_single_use(graph: &GateGraph, wire: Wire) -> Option<Gate> {
        let uses = &graph.wire_uses[wire];
        match uses.len() {
            1 => uses.iter().next().copied(),
            _ => None,
        }
    }

    /// Constant wires followed by input wires of `gate`.
    fn get_gate_inputs(graph: &GateGraph, gate: Gate) -> Vec<Wire> {
        let param = graph.gates[gate].gate_param();
        param.constants.iter().chain(param.inputs).copied().collect()
    }

    /// Output wires of `gate`.
    fn get_gate_outputs(graph: &GateGraph, gate: Gate) -> Vec<Wire> {
        graph.gates[gate].gate_param().outputs.to_vec()
    }

    #[test]
    fn test_use_def_analysis() {
        let mut graph = GateGraph::new();
        let root = graph.path_spec_tree.root();

        let in1 = graph.add_inout();
        let in2 = graph.add_inout();
        let out1 = graph.add_witness();
        let out2 = graph.add_witness();

        let gate1 = graph.emit_gate(root, Opcode::Bxor, vec![in1, in2], vec![out1]);
        let gate2 = graph.emit_gate(root, Opcode::Band, vec![out1, in1], vec![out2]);

        graph.rebuild_use_def_chains(&HintRegistry::new());

        // Definitions.
        assert_eq!(get_wire_def(&graph, out1), Some(gate1));
        assert_eq!(get_wire_def(&graph, out2), Some(gate2));

        // Uses.
        assert!(graph.get_wire_uses(in1).contains(&gate1));
        assert!(graph.get_wire_uses(in2).contains(&gate1));
        assert!(graph.get_wire_uses(out1).contains(&gate2));

        // Use counts.
        assert_eq!(wire_use_count(&graph, in1), 2);
        assert_eq!(wire_use_count(&graph, in2), 1);
        assert_eq!(wire_use_count(&graph, out1), 1);
        assert_eq!(wire_use_count(&graph, out2), 0);

        // Single-use queries.
        assert!(!is_wire_single_use(&graph, in1));
        assert!(is_wire_single_use(&graph, in2));
        assert!(is_wire_single_use(&graph, out1));
        assert!(!is_wire_single_use(&graph, out2));
        assert_eq!(get_wire_single_use(&graph, in1), None);
        assert_eq!(get_wire_single_use(&graph, out1), Some(gate2));
        assert_eq!(get_wire_single_use(&graph, out2), None);
    }

    #[test]
    fn test_constant_use_def() {
        let mut graph = GateGraph::new();
        let root = graph.path_spec_tree.root();

        let const_wire = graph.add_constant(Word(42u64));
        let in_wire = graph.add_inout();
        let out = graph.add_witness();

        let gate = graph.emit_gate(root, Opcode::Bxor, vec![const_wire, in_wire], vec![out]);

        graph.rebuild_use_def_chains(&HintRegistry::new());

        // Constants are never defined by a gate, only read.
        assert_eq!(get_wire_def(&graph, const_wire), None);
        assert!(graph.get_wire_uses(const_wire).contains(&gate));
        assert_eq!(wire_use_count(&graph, const_wire), 1);
    }

    #[test]
    fn test_rebuild_use_def_chains() {
        let mut graph = GateGraph::new();
        let root = graph.path_spec_tree.root();

        let in1 = graph.add_inout();
        let in2 = graph.add_inout();
        let out = graph.add_witness();

        graph.emit_gate(root, Opcode::Bxor, vec![in1, in2], vec![out]);

        // Wipe the chains and confirm they are empty.
        graph.wire_def.clear();
        graph.wire_uses.clear();

        assert_eq!(get_wire_def(&graph, out), None);
        assert!(graph.get_wire_uses(in1).is_empty());

        // Rebuilding restores both definitions and uses.
        graph.rebuild_use_def_chains(&HintRegistry::new());

        assert!(get_wire_def(&graph, out).is_some());
        assert!(!graph.get_wire_uses(in1).is_empty());
        assert!(!graph.get_wire_uses(in2).is_empty());
    }

    #[test]
    fn test_gate_inputs_outputs() {
        let mut graph = GateGraph::new();
        let root = graph.path_spec_tree.root();

        let in1 = graph.add_inout();
        let in2 = graph.add_inout();
        let out = graph.add_witness();

        let gate = graph.emit_gate(root, Opcode::Bxor, vec![in1, in2], vec![out]);

        // Bxor carries one built-in constant ahead of its two inputs.
        let inputs = get_gate_inputs(&graph, gate);
        assert_eq!(inputs.len(), 3);
        assert!(inputs.contains(&in1));
        assert!(inputs.contains(&in2));

        let const_wire = inputs[0];
        let WireKind::Constant(word) = graph.wires[const_wire].kind else {
            panic!("Expected constant wire");
        };
        assert_eq!(word, Word::ALL_ONE);

        let outputs = get_gate_outputs(&graph, gate);
        assert_eq!(outputs.len(), 1);
        assert!(outputs.contains(&out));
    }
}