// binius_core/protocols/prodcheck/prove.rs

// Copyright 2025 Irreducible Inc.

3use binius_compute::{
4	ComputeLayer, ComputeLayerExecutor, ComputeMemory, FSlice, SizedSlice, SlicesBatch,
5	alloc::{BumpAllocator, ComputeAllocator, HostBumpAllocator},
6};
7use binius_field::TowerField;
8use binius_math::CompositionPoly;
9use binius_utils::{bail, checked_arithmetics::checked_log_2};
10use getset::CopyGetters;
11
12use super::common::Error;
13use crate::composition::BivariateProduct;
14
15/// The computed layer evaluations of a product tree circuit.
16#[derive(CopyGetters)]
17pub struct ProductCircuitLayers<'a, F: TowerField, DevMem: ComputeMemory<F>> {
18	layers: Vec<DevMem::FSlice<'a>>,
19	/// The product of the evaluations.
20	#[get_copy = "pub"]
21	product: F,
22}
23
24impl<'a, F, DevMem> ProductCircuitLayers<'a, F, DevMem>
25where
26	F: TowerField,
27	DevMem: ComputeMemory<F>,
28{
29	/// Computes the full sequence of GKR layers of the binary product circuit.
30	///
31	/// ## Throws
32	///
33	/// - [`Error::ExpectInputSlicePowerOfTwoLength`] unless `evals` has power-of-two length
34	pub fn compute<'dev_mem, 'host_mem, Hal>(
35		evals: FSlice<'a, F, Hal>,
36		hal: &'a Hal,
37		dev_alloc: &'a BumpAllocator<'dev_mem, F, Hal::DevMem>,
38		host_alloc: &'a HostBumpAllocator<'host_mem, F>,
39	) -> Result<Self, Error>
40	where
41		Hal: ComputeLayer<F, DevMem = DevMem>,
42	{
43		if !evals.len().is_power_of_two() {
44			bail!(Error::ExpectInputSlicePowerOfTwoLength);
45		}
46		let log_n = checked_log_2(evals.len());
47		let prod_expr =
48			hal.compile_expr(&CompositionPoly::<F>::expression(&BivariateProduct::default()))?;
49
50		let mut last_layer = evals;
51		let mut layers = Vec::with_capacity(log_n);
52		let _ = hal.execute(|exec| {
53			for i in (0..log_n).rev() {
54				let row_len = 1 << i;
55				let (lo_half, hi_half) = Hal::DevMem::split_half(last_layer);
56				let mut new_layer = dev_alloc.alloc(row_len)?;
57				exec.compute_composite(
58					&SlicesBatch::new(vec![lo_half, hi_half], row_len),
59					&mut new_layer,
60					&prod_expr,
61				)?;
62
63				layers.push(last_layer);
64				last_layer = DevMem::to_const(new_layer);
65			}
66			Ok(Vec::new())
67		})?;
68
69		let product_dst = host_alloc.alloc(1)?;
70		hal.copy_d2h(last_layer, product_dst)?;
71
72		layers.reverse();
73		let product = product_dst[0];
74		Ok(Self { layers, product })
75	}
76
77	/// Returns the layer evaluations of the product tree circuit.
78	///
79	/// The $i$'th entry has $2^{i+1}$ values.
80	pub fn layers(&self) -> &[DevMem::FSlice<'a>] {
81		&self.layers
82	}
83}