Only split evaluation points into tree and block indices for packed types (elements_per_block > 1)
PiperOrigin-RevId: 394423967
diff --git a/dpf/distributed_point_function.h b/dpf/distributed_point_function.h
index 094cd6b..a0635eb 100644
--- a/dpf/distributed_point_function.h
+++ b/dpf/distributed_point_function.h
@@ -788,16 +788,20 @@
return correction_ints.status();
}
- // Split up evaluation_points into tree indices and block indices.
+ // Split up evaluation_points into tree indices and block indices, if we're
+ // operating on a packed type.
std::vector<absl::uint128> tree_indices;
std::vector<int> block_indices;
- tree_indices.reserve(num_evaluation_points);
- block_indices.reserve(num_evaluation_points);
- for (int64_t i = 0; i < num_evaluation_points; ++i) {
- tree_indices.push_back(
- DomainToTreeIndex(evaluation_points[i], hierarchy_level));
- block_indices.push_back(
- DomainToBlockIndex(evaluation_points[i], hierarchy_level));
+ if constexpr (elements_per_block > 1) {
+ tree_indices.reserve(num_evaluation_points);
+ block_indices.reserve(num_evaluation_points);
+ for (int64_t i = 0; i < num_evaluation_points; ++i) {
+ tree_indices.push_back(
+ DomainToTreeIndex(evaluation_points[i], hierarchy_level));
+ block_indices.push_back(
+ DomainToBlockIndex(evaluation_points[i], hierarchy_level));
+ }
+ evaluation_points = absl::MakeConstSpan(tree_indices);
}
// Extract seed and party for DPF evaluation.
@@ -812,7 +816,7 @@
auto correction_words =
absl::MakeConstSpan(key.correction_words()).subspan(0, stop_level);
absl::StatusOr<DpfExpansion> evaluated_inputs =
- EvaluateSeeds(std::move(inputs), tree_indices, correction_words);
+ EvaluateSeeds(std::move(inputs), evaluation_points, correction_words);
if (!evaluated_inputs.ok()) {
return evaluated_inputs.status();
}
@@ -836,9 +840,13 @@
absl::string_view(reinterpret_cast<const char*>(
&(*hashed_expansion)[i * blocks_needed]),
blocks_needed * sizeof(absl::uint128)));
- result.push_back(current_elements[block_indices[i]]);
+ int block_index = 0;
+ if constexpr (elements_per_block > 1) {
+ block_index = block_indices[i];
+ }
+ result.push_back(current_elements[block_index]);
if (evaluated_inputs->control_bits[i]) {
- result[i] += (*correction_ints)[block_indices[i]];
+ result[i] += (*correction_ints)[block_index];
}
if (party == 1) {
result[i] = -result[i];