binius_math/multilinear/fold.rs

// Copyright 2024-2025 Irreducible Inc.

use std::ops::{Deref, DerefMut};

use binius_field::{Field, PackedField};
use binius_utils::{random_access_sequence::RandomAccessSequence, rayon::prelude::*};

use crate::{Error, FieldBuffer};

/// Computes the partial evaluation of a multilinear on its highest variable, in place.
///
/// Each scalar of the result requires one multiplication to compute. Multilinear evaluations
/// occupy a prefix of the field buffer; scalars after the truncated length are zeroed out.
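///
/// Concretely, if `lo` and `hi` are the buffer halves where the highest variable is 0 and 1
/// respectively, each entry is updated as `lo[i] += scalar * (hi[i] - lo[i])` and the buffer
/// is truncated to the low half.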
pub fn fold_highest_var_inplace<P: PackedField, Data: DerefMut<Target = [P]>>(
	values: &mut FieldBuffer<P, Data>,
	scalar: P::Scalar,
) -> Result<(), Error> {
	let broadcast_scalar = P::broadcast(scalar);
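	// `lo` holds evaluations with the highest variable set to 0, `hi` with it set to 1;
	// interpolate between the two halves in parallel.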
	values.split_half_mut(|lo, hi| {
		(lo.as_mut(), hi.as_mut())
			.into_par_iter()
			.for_each(|(zero, one)| {
				*zero += broadcast_scalar * (*one - *zero);
			});
	})?;

	values.truncate(values.log_len() - 1);
	Ok(())
}

/// Computes the fold high of a binary multilinear with a fold tensor.
///
/// The binary multilinear is represented transparently by a boolean sequence.
/// "Fold high" means that for every hypercube vertex of the result, the lower-indexed
/// variables of the binary multilinear are specialized to the vertex coordinates and the
/// inner product of the remaining (higher-variable) multilinear with the tensor is taken.
///
/// This method is single-threaded.
///
/// # Errors
///
/// * `PowerOfTwoLengthRequired` if the bool sequence is not of power-of-two length.
/// * `FoldLengthMismatch` if the tensor and result log-lengths do not sum to the log-length
///   of the bool sequence.
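///
/// The bool sequence is indexed with the result index in the low `values.log_len()` bits and
/// the tensor index in the high bits. For example, with `values.log_len() = 1` and
/// `tensor.log_len() = 1`, the sequence has length 4 and `result[j]` is the sum over `k` of
/// `tensor[k] * bits[k << 1 | j]` for `j, k` in `{0, 1}`.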
pub fn binary_fold_high<P, DataOut, DataIn>(
	values: &mut FieldBuffer<P, DataOut>,
	tensor: &FieldBuffer<P, DataIn>,
	bits: impl RandomAccessSequence<bool> + Sync,
) -> Result<(), Error>
where
	P: PackedField,
	DataOut: DerefMut<Target = [P]>,
	DataIn: Deref<Target = [P]> + Sync,
{
	if !bits.len().is_power_of_two() {
		return Err(Error::PowerOfTwoLengthRequired);
	}

	let values_log_len = values.log_len();
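	// Number of scalars to fill per packed element; capped when the whole buffer is shorter
	// than a single packed element.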
	let width = P::WIDTH.min(values.len());

	if 1 << (values_log_len + tensor.log_len()) != bits.len() {
		return Err(Error::FoldLengthMismatch);
	}

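	// For each output scalar, sum the tensor scalars whose bit (tensor index in the high
	// bits, output index in the low bits) is set in the binary multilinear.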
	values
		.as_mut()
		.iter_mut()
		.enumerate()
		.for_each(|(i, packed)| {
			*packed = P::from_scalars((0..width).map(|j| {
				let scalar_index = i << P::LOG_WIDTH | j;
				let mut acc = P::Scalar::ZERO;

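				// Walk the tensor scalar by scalar: packed index `k` and lane `l` recover
				// the tensor scalar index.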
				for (k, tensor_packed) in tensor.as_ref().iter().enumerate() {
					for (l, tensor_scalar) in tensor_packed.iter().take(tensor.len()).enumerate() {
						let tensor_scalar_index = k << P::LOG_WIDTH | l;
						if bits.get(tensor_scalar_index << values_log_len | scalar_index) {
							acc += tensor_scalar;
						}
					}
				}

				acc
			}));
		});

	Ok(())
}

#[cfg(test)]
mod tests {
	use std::iter::repeat_with;

	use rand::prelude::*;

	use super::*;
	use crate::{
		multilinear::{eq::eq_ind_partial_eval, evaluate::evaluate},
		test_utils::{B128, Packed128b, random_field_buffer, random_scalars},
	};

	type P = Packed128b;
	type F = B128;

	#[test]
	fn test_fold_highest_var_inplace() {
		let mut rng = StdRng::seed_from_u64(0);

		let n_vars = 10;

		let point = random_scalars::<F>(&mut rng, n_vars);
		let mut multilinear = random_field_buffer::<P>(&mut rng, n_vars);

		let eval = evaluate(&multilinear, &point).unwrap();

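		// Each fold specializes the current highest variable, so the point coordinates are
		// consumed in reverse order.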
		for &scalar in point.iter().rev() {
			fold_highest_var_inplace(&mut multilinear, scalar).unwrap();
		}

		assert_eq!(multilinear.get(0).unwrap(), eval);
	}

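	// Checks that `binary_fold_high` on the raw bits agrees with folding the equivalent
	// 0/1-valued field buffer with `fold_highest_var_inplace`.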
	fn test_binary_fold_high_conforms_to_regular_fold_high_helper(
		n_vars: usize,
		tensor_n_vars: usize,
	) {
		let mut rng = StdRng::seed_from_u64(0);

		let point = random_scalars::<F>(&mut rng, tensor_n_vars);

		let tensor = eq_ind_partial_eval::<P>(&point);

		let bits = repeat_with(|| rng.random())
			.take(1 << n_vars)
			.collect::<Vec<bool>>();

		let bits_scalars = bits
			.iter()
			.map(|&b| if b { F::ONE } else { F::ZERO })
			.collect::<Vec<F>>();

		let mut bits_buffer = FieldBuffer::<P>::from_values(&bits_scalars).unwrap();

		let mut binary_fold_result = FieldBuffer::<P>::zeros(n_vars - tensor_n_vars);
		binary_fold_high(&mut binary_fold_result, &tensor, bits.as_slice()).unwrap();

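		// Reference computation: fold the 0/1-valued buffer on its highest variables at the
		// same point.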
		for &scalar in point.iter().rev() {
			fold_highest_var_inplace(&mut bits_buffer, scalar).unwrap();
		}

		assert_eq!(bits_buffer, binary_fold_result);
	}

	#[test]
	fn test_binary_fold_high_conforms_to_regular_fold_high() {
		for (n_vars, tensor_n_vars) in [(2, 0), (2, 1), (4, 4), (10, 3)] {
			test_binary_fold_high_conforms_to_regular_fold_high_helper(n_vars, tensor_n_vars)
		}
	}
}