// binius_math/multilinear/fold.rs

// Copyright 2024-2025 Irreducible Inc.

use std::ops::{Deref, DerefMut};

use binius_field::{Field, PackedField};
use binius_utils::{random_access_sequence::RandomAccessSequence, rayon::prelude::*};

use crate::FieldBuffer;
9
10/// Computes the partial evaluation of a multilinear on its highest variable, inplace.
11///
12/// Each scalar of the result requires one multiplication to compute. Multilinear evaluations
13/// occupy a prefix of the field buffer; scalars after the truncated length are zeroed out.
14///
15/// ## Preconditions
16///
17/// * `values.log_len() >= 1` (buffer must have at least 2 elements)
18pub fn fold_highest_var_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
19	values: &mut FieldBuffer<P, Data>,
20	scalar: P::Scalar,
21) {
22	let broadcast_scalar = P::broadcast(scalar);
23	{
24		let mut split = values.split_half_mut();
25		let (mut lo, mut hi) = split.halves();
26		(lo.as_mut(), hi.as_mut())
27			.into_par_iter()
28			.for_each(|(zero, one)| {
29				*zero += broadcast_scalar * (*one - *zero);
30			});
31	}
32
33	values.truncate(values.log_len() - 1);
34}
35
36/// Computes the fold high of a binary multilinear with a fold tensor.
37///
38/// Binary multilinear is represented transparently by a boolean sequence.
39/// Fold high meaning: for every hypercube vertex of the result, we specialize lower
40/// indexed variables of the binary multilinear to the vertex coordinates and take an
41/// inner product of the remaining multilinear and the tensor.
42///
43/// This method is single threaded.
44///
45/// ## Preconditions
46///
47/// * `bits.len()` must be a power of two
48/// * `bits.len()` must equal `values.len() * tensor.len()`
49pub fn binary_fold_high<P, DataOut, DataIn>(
50	values: &mut FieldBuffer<P, DataOut>,
51	tensor: &FieldBuffer<P, DataIn>,
52	bits: impl RandomAccessSequence<bool> + Sync,
53) where
54	P: PackedField,
55	DataOut: DerefMut<Target = [P]>,
56	DataIn: Deref<Target = [P]> + Sync,
57{
58	assert!(bits.len().is_power_of_two(), "precondition: bits length must be a power of two");
59
60	let values_log_len = values.log_len();
61	let width = P::WIDTH.min(values.len());
62
63	assert_eq!(
64		1 << (values_log_len + tensor.log_len()),
65		bits.len(),
66		"precondition: bits length must equal values length times tensor length"
67	);
68
69	values
70		.as_mut()
71		.iter_mut()
72		.enumerate()
73		.for_each(|(i, packed)| {
74			*packed = P::from_scalars((0..width).map(|j| {
75				let scalar_index = i << P::LOG_WIDTH | j;
76				let mut acc = P::Scalar::ZERO;
77
78				for (k, tensor_packed) in tensor.as_ref().iter().enumerate() {
79					for (l, tensor_scalar) in tensor_packed.iter().take(tensor.len()).enumerate() {
80						let tensor_scalar_index = k << P::LOG_WIDTH | l;
81						if bits.get(tensor_scalar_index << values_log_len | scalar_index) {
82							acc += tensor_scalar;
83						}
84					}
85				}
86
87				acc
88			}));
89		});
90}
91
92#[cfg(test)]
93mod tests {
94	use std::iter::repeat_with;
95
96	use rand::prelude::*;
97
98	use super::*;
99	use crate::{
100		multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
101		test_utils::{B128, Packed128b, random_field_buffer, random_scalars},
102	};
103
104	type P = Packed128b;
105	type F = B128;
106
107	#[test]
108	fn test_fold_highest_var_inplace() {
109		let mut rng = StdRng::seed_from_u64(0);
110
111		let n_vars = 10;
112
113		let point = random_scalars::<F>(&mut rng, n_vars);
114		let mut multilinear = random_field_buffer::<P>(&mut rng, n_vars);
115
116		let eval = evaluate(&multilinear, &point);
117
118		for &scalar in point.iter().rev() {
119			fold_highest_var_inplace(&mut multilinear, scalar);
120		}
121
122		assert_eq!(multilinear.get(0), eval);
123	}
124
125	fn test_binary_fold_high_conforms_to_regular_fold_high_helper(
126		n_vars: usize,
127		tensor_n_vars: usize,
128	) {
129		let mut rng = StdRng::seed_from_u64(0);
130
131		let point = random_scalars::<F>(&mut rng, tensor_n_vars);
132
133		let tensor = eq_ind_partial_eval::<P>(&point);
134
135		let bits = repeat_with(|| rng.random())
136			.take(1 << n_vars)
137			.collect::<Vec<bool>>();
138
139		let bits_scalars = bits
140			.iter()
141			.map(|&b| if b { F::ONE } else { F::ZERO })
142			.collect::<Vec<F>>();
143
144		let mut bits_buffer = FieldBuffer::<P>::from_values(&bits_scalars);
145
146		let mut binary_fold_result = FieldBuffer::<P>::zeros(n_vars - tensor_n_vars);
147		binary_fold_high(&mut binary_fold_result, &tensor, bits.as_slice());
148
149		for &scalar in point.iter().rev() {
150			fold_highest_var_inplace(&mut bits_buffer, scalar);
151		}
152
153		assert_eq!(bits_buffer, binary_fold_result);
154	}
155
156	#[test]
157	fn test_binary_fold_high_conforms_to_regular_fold_high() {
158		for (n_vars, tensor_n_vars) in [(2, 0), (2, 1), (4, 4), (10, 3)] {
159			test_binary_fold_high_conforms_to_regular_fold_high_helper(n_vars, tensor_n_vars)
160		}
161	}
162}