1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// Copyright 2024 Ulvetanna Inc.

use crate::Error;
use binius_field::PackedField;
use binius_utils::bail;
use rayon::prelude::*;
use std::cmp::max;

/// Tensor Product expansion of values with partial eq indicator evaluated at extra_query_coordinates
///
/// Let $n$ be `log_n_values`, and let $p$ and $k$ be the lengths of `packed_values` and
/// `extra_query_coordinates` respectively.
/// Requires
///     * p = max(1, 2^{n+k} / P::WIDTH)
///
/// Let $v$ be a vector corresponding to the first $2^n$ scalar values of `packed_values`.
/// Let $r = (r_0, \ldots, r_{k-1})$ be the vector of `extra_query_coordinates`.
///
/// # Formal Definition
/// `packed_values` is updated in place to contain the result of:
/// $v \otimes (1 - r_0, r_0) \otimes \ldots \otimes (1 - r_{k-1}, r_{k-1})$
/// which is now a vector of length $2^{n+k}$. If 2^{n+k} < P::WIDTH, then
/// the result is packed into a single element of `packed_values` where only the first
/// 2^{n+k} elements have meaning.
///
/// # Interpretation
/// Let $f$ be an $n$ variate multilinear polynomial that has evaluations over
/// the $n$ dimensional hypercube corresponding to $v$.
/// Then `packed_values` is updated to contain the evaluations of $g$ over the $n+k$-dimensional
/// hypercube where
/// * $g(x_0, \ldots, x_{n+k-1}) = f(x_0, \ldots, x_{n-1}) * eq(x_n, \ldots, x_{n+k-1}, r)$
///
/// # Errors
/// Returns `Error::InvalidPackedValuesLength` when `packed_values.len()` differs from
/// max(1, 2^{n+k} / P::WIDTH), including the case where $n + k$ is so large that $2^{n+k}$
/// cannot be represented in a `usize` (no slice could have the required length).
pub fn tensor_prod_eq_ind<P: PackedField>(
	log_n_values: usize,
	packed_values: &mut [P],
	extra_query_coordinates: &[P::Scalar],
) -> Result<(), Error> {
	// Derive the required packed length with overflow-checked arithmetic. A plain
	// `1 << new_n_vars` would panic in debug builds and silently wrap the shift amount in
	// release builds once `new_n_vars` reaches the bit width of `usize`.
	let expected_len = log_n_values
		.checked_add(extra_query_coordinates.len())
		.filter(|&new_n_vars| new_n_vars < usize::BITS as usize)
		.map(|new_n_vars| max(1, (1usize << new_n_vars) / P::WIDTH));
	if expected_len != Some(packed_values.len()) {
		bail!(Error::InvalidPackedValuesLength);
	}

	// Each coordinate doubles the number of meaningful scalars: the existing prefix is scaled
	// by (1 - r_i) in place and the next block of equal size receives the prefix scaled by r_i.
	// The shift below cannot overflow because `log_n_values + i` is smaller than the validated
	// total number of variables.
	for (i, r_i) in extra_query_coordinates.iter().enumerate() {
		let prev_length = 1 << (log_n_values + i);
		if prev_length < P::WIDTH {
			// Everything still lives inside the first packed element, so work scalar by scalar.
			let q = &mut packed_values[0];
			for h in 0..prev_length {
				let x = q.get(h);
				let prod = x * r_i;
				q.set(h, x - prod);
				// `prev_length` is a power of two greater than `h`, so OR-ing offsets `h`
				// into the second half.
				q.set(prev_length | h, prod);
			}
		} else {
			let prev_packed_length = prev_length / P::WIDTH;
			let packed_r_i = P::broadcast(*r_i);
			let (xs, ys) = packed_values.split_at_mut(prev_packed_length);
			// Sanity check: the tail must be able to receive one output per prefix element.
			// This is expected to hold after the length validation above (assuming P::WIDTH
			// is a power of two). The zip below only touches the first `xs.len()` elements
			// of `ys`.
			assert!(xs.len() <= ys.len());

			// The minimum chunk length was chosen experimentally to have a reasonable
			// performance for the calls with small number of elements.
			xs.par_iter_mut()
				.zip(ys.par_iter_mut())
				.with_min_len(64)
				.for_each(|(x, y): (&mut P, &mut P)| {
					// x = x * (1 - packed_r_i) = x - x * packed_r_i
					// y = x * packed_r_i
					// Notice that we can reuse the multiplication: (x * packed_r_i)
					let prod = (*x) * packed_r_i;
					*x -= prod;
					*y = prod;
				});
		}
	}
	Ok(())
}

#[cfg(test)]
mod tests {
	use super::*;
	use binius_field::{
		packed::{iter_packed_slice, set_packed_slice},
		Field, PackedBinaryField4x32b,
	};
	use itertools::Itertools;

	/// Starts from the single scalar `1` (zero input variables) and expands it by two query
	/// coordinates, then compares every resulting scalar against the explicitly written
	/// tensor product `(1 - r0, r0) ⊗ (1 - r1, r1)`.
	#[test]
	fn test_tensor_prod_eq_ind() {
		type P = PackedBinaryField4x32b;
		type F = <P as PackedField>::Scalar;

		let r0 = F::new(1);
		let r1 = F::new(2);
		let coordinates = vec![r0, r1];

		// Buffer sized for 2^(number of coordinates) scalars, seeded with ONE at index 0.
		let packed_len: usize = 1 << (coordinates.len() - P::LOG_WIDTH);
		let mut packed = vec![P::default(); packed_len];
		set_packed_slice(&mut packed, 0, F::ONE);

		tensor_prod_eq_ind(0, &mut packed, &coordinates).unwrap();

		// The first coordinate varies fastest in the output ordering.
		let expected = vec![
			(F::ONE - r0) * (F::ONE - r1),
			r0 * (F::ONE - r1),
			(F::ONE - r0) * r1,
			r0 * r1,
		];
		let actual = iter_packed_slice(&packed).collect_vec();
		assert_eq!(actual, expected);
	}
}