binius_math/multilinear/
fold.rs

1// Copyright 2024-2025 Irreducible Inc.
2
3use std::ops::{Deref, DerefMut};
4
5use binius_field::{Field, PackedField};
6use binius_utils::{random_access_sequence::RandomAccessSequence, rayon::prelude::*};
7
8use crate::{Error, FieldBuffer};
9
10/// Computes the partial evaluation of a multilinear on its highest variable, inplace.
11///
12/// Each scalar of the result requires one multiplication to compute. Multilinear evaluations
13/// occupy a prefix of the field buffer; scalars after the truncated length are zeroed out.
14pub fn fold_highest_var_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
15	values: &mut FieldBuffer<P, Data>,
16	scalar: P::Scalar,
17) -> Result<(), Error> {
18	let broadcast_scalar = P::broadcast(scalar);
19	{
20		let mut split = values.split_half_mut()?;
21		let (mut lo, mut hi) = split.halves();
22		(lo.as_mut(), hi.as_mut())
23			.into_par_iter()
24			.for_each(|(zero, one)| {
25				*zero += broadcast_scalar * (*one - *zero);
26			});
27	}
28
29	values.truncate(values.log_len() - 1);
30	Ok(())
31}
32
33/// Computes the fold high of a binary multilinear with a fold tensor.
34///
35/// Binary multilinear is represented transparently by a boolean sequence.
36/// Fold high meaning: for every hypercube vertex of the result, we specialize lower
37/// indexed variables of the binary multilinear to the vertex coordinates and take an
38/// inner product of the remaining multilinear and the tensor.
39///
40/// This method is single threaded.
41///
42/// # Throws
43///
44/// * `PowerOfTwoLengthRequired` if the bool sequence is not of power of two length.
45/// * `FoldLengthMismatch` if the tensor, result and binary multilinear lengths do not add up.
46pub fn binary_fold_high<P, DataOut, DataIn>(
47	values: &mut FieldBuffer<P, DataOut>,
48	tensor: &FieldBuffer<P, DataIn>,
49	bits: impl RandomAccessSequence<bool> + Sync,
50) -> Result<(), Error>
51where
52	P: PackedField,
53	DataOut: DerefMut<Target = [P]>,
54	DataIn: Deref<Target = [P]> + Sync,
55{
56	if !bits.len().is_power_of_two() {
57		return Err(Error::PowerOfTwoLengthRequired);
58	}
59
60	let values_log_len = values.log_len();
61	let width = P::WIDTH.min(values.len());
62
63	if 1 << (values_log_len + tensor.log_len()) != bits.len() {
64		return Err(Error::FoldLengthMismatch);
65	}
66
67	values
68		.as_mut()
69		.iter_mut()
70		.enumerate()
71		.for_each(|(i, packed)| {
72			*packed = P::from_scalars((0..width).map(|j| {
73				let scalar_index = i << P::LOG_WIDTH | j;
74				let mut acc = P::Scalar::ZERO;
75
76				for (k, tensor_packed) in tensor.as_ref().iter().enumerate() {
77					for (l, tensor_scalar) in tensor_packed.iter().take(tensor.len()).enumerate() {
78						let tensor_scalar_index = k << P::LOG_WIDTH | l;
79						if bits.get(tensor_scalar_index << values_log_len | scalar_index) {
80							acc += tensor_scalar;
81						}
82					}
83				}
84
85				acc
86			}));
87		});
88
89	Ok(())
90}
91
#[cfg(test)]
mod tests {
	use rand::prelude::*;

	use super::*;
	use crate::{
		multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
		test_utils::{B128, Packed128b, random_field_buffer, random_scalars},
	};

	type P = Packed128b;
	type F = B128;

	/// Folding every variable, highest first, must leave the multilinear's evaluation at the
	/// chosen point in the first scalar of the buffer.
	#[test]
	fn test_fold_highest_var_inplace() {
		const N_VARS: usize = 10;

		let mut rng = StdRng::seed_from_u64(0);

		let point = random_scalars::<F>(&mut rng, N_VARS);
		let mut multilinear = random_field_buffer::<P>(&mut rng, N_VARS);

		let expected = evaluate(&multilinear, &point).unwrap();

		// The highest variable corresponds to the last coordinate of the point.
		point
			.iter()
			.rev()
			.for_each(|&coordinate| fold_highest_var_inplace(&mut multilinear, coordinate).unwrap());

		assert_eq!(multilinear.get_checked(0).unwrap(), expected);
	}

	/// Checks that `binary_fold_high` with an equality-indicator tensor agrees with folding
	/// the same bits, lifted to field elements, one high variable at a time.
	fn check_binary_fold_high_against_regular_fold(n_vars: usize, tensor_n_vars: usize) {
		let mut rng = StdRng::seed_from_u64(0);

		let point = random_scalars::<F>(&mut rng, tensor_n_vars);
		let tensor = eq_ind_partial_eval::<P>(&point);

		let bits: Vec<bool> = (0..1usize << n_vars).map(|_| rng.random()).collect();

		// Lift the bits into the field so the regular fold can consume them.
		let lifted: Vec<F> = bits
			.iter()
			.map(|&bit| match bit {
				true => F::ONE,
				false => F::ZERO,
			})
			.collect();
		let mut regular_fold = FieldBuffer::<P>::from_values(&lifted).unwrap();

		let mut binary_fold = FieldBuffer::<P>::zeros(n_vars - tensor_n_vars);
		binary_fold_high(&mut binary_fold, &tensor, bits.as_slice()).unwrap();

		for &coordinate in point.iter().rev() {
			fold_highest_var_inplace(&mut regular_fold, coordinate).unwrap();
		}

		assert_eq!(regular_fold, binary_fold);
	}

	#[test]
	fn test_binary_fold_high_conforms_to_regular_fold_high() {
		let cases = [(2, 0), (2, 1), (4, 4), (10, 3)];
		for (n_vars, tensor_n_vars) in cases {
			check_binary_fold_high_against_regular_fold(n_vars, tensor_n_vars);
		}
	}
}