Skip to content

Commit

Permalink
Revert and fix "Typo fix preventing training of post scale models acr…
Browse files Browse the repository at this point in the history
…oss multiple GPUs."
  • Loading branch information
james-choncholas committed Jan 8, 2025
1 parent a5745f2 commit f7800af
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tf_shell_ml/dpsgd_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def compute_grads(self, features, enc_labels):
max_two_norms_list.append(self.jacobian_max_two_norm(jacobians))

with tf.device(self.features_party_dev):
# For some reason, when running the jacobian on an accelerator, the
# weights must be touched otherwise training loss goes to NaN. Maybe
# it is to ensure the weights are assigned to the features_party
# device for later, when the final gradient is added to weights (on
# CPU)?
tf.print(self.trainable_variables, output_stream="file:///dev/null")

predictions = tf.concat(predictions_list, axis=0)
max_two_norm = tf.reduce_max(max_two_norms_list)

Expand Down
7 changes: 7 additions & 0 deletions tf_shell_ml/postscale_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def compute_grads(self, features, enc_labels):
max_two_norms_list.append(self.jacobian_max_two_norm(jacobians))

with tf.device(self.features_party_dev):
# For some reason, when running the jacobian on an accelerator, the
# weights must be touched otherwise training loss goes to NaN. Maybe
# it is to ensure the weights are assigned to the features_party
# device for later, when the final gradient is added to weights (on
# CPU)?
tf.print(self.trainable_variables, output_stream="file:///dev/null")

predictions = tf.concat(predictions_list, axis=0)
max_two_norm = tf.reduce_max(max_two_norms_list)
jacobians = [tf.concat(j, axis=0) for j in zip(*jacobians_list)]
Expand Down

0 comments on commit f7800af

Please sign in to comment.