Skip to content

Commit

Permalink
Remove tf.function from helper method used inside `tff.tf_computati…
Browse files Browse the repository at this point in the history
…on`.

This results in the logic being inlined in the outer scope, rather than a PartitionedCall op being introduced.

PiperOrigin-RevId: 452318086
  • Loading branch information
ZacharyGarrett authored and tensorflow-copybara committed Jun 1, 2022
1 parent 9b31f9d commit b267217
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tensorflow_federated/python/aggregators/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ def _normalize_secure_quantized_sum_args(client_value, lower_bound,
return client_value, lower_bound, upper_bound


@tf.function
def _client_tensor_shift_for_secure_sum(value, lower_bound, upper_bound):
"""Mapping to be applied to every tensor before secure sum.
Expand All @@ -452,7 +451,11 @@ def _client_tensor_shift_for_secure_sum(value, lower_bound, upper_bound):
Returns:
Shifted value of dtype `tf.int64`.
"""
tf.Assert(lower_bound <= upper_bound, [lower_bound, upper_bound])
tf.debugging.assert_less_equal(
lower_bound,
upper_bound,
message='lower_bound must be smaller than upper_bound for secagg '
'quantization')
if value.dtype == tf.int32:
clipped_val = tf.clip_by_value(value, lower_bound, upper_bound)
# Cast BEFORE shift in order to avoid overflow if full int32 range is used.
Expand All @@ -462,7 +465,7 @@ def _client_tensor_shift_for_secure_sum(value, lower_bound, upper_bound):
range_span = upper_bound - lower_bound
scale_factor = tf.math.floordiv(range_span, _SECAGG_MAX) + 1
shifted_value = tf.cond(
scale_factor > 1,
tf.greater(scale_factor, 1),
lambda: tf.math.floordiv(clipped_val - lower_bound, scale_factor),
lambda: clipped_val - lower_bound)
return shifted_value
Expand Down

0 comments on commit b267217

Please sign in to comment.