From f7800afd716f09db31311fc9b85a24679136c78b Mon Sep 17 00:00:00 2001 From: james-choncholas Date: Wed, 8 Jan 2025 00:52:58 +0000 Subject: [PATCH] Revert and fix "Typo fix preventing training of post scale models across multiple GPUs." --- tf_shell_ml/dpsgd_sequential_model.py | 7 +++++++ tf_shell_ml/postscale_sequential_model.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/tf_shell_ml/dpsgd_sequential_model.py b/tf_shell_ml/dpsgd_sequential_model.py index 1458120..558ce59 100644 --- a/tf_shell_ml/dpsgd_sequential_model.py +++ b/tf_shell_ml/dpsgd_sequential_model.py @@ -76,6 +76,13 @@ def compute_grads(self, features, enc_labels): max_two_norms_list.append(self.jacobian_max_two_norm(jacobians)) with tf.device(self.features_party_dev): + # For some reason, when running the jacobian on an accelerator, the + # weights must be touched otherwise training loss goes to NaN. Maybe + # it is to ensure the weights are on assigned to features_party + # device for later, when the final gradient is added to weights (on + # CPU)? + tf.print(self.trainable_variables, output_stream="file:///dev/null") + predictions = tf.concat(predictions_list, axis=0) max_two_norm = tf.reduce_max(max_two_norms_list) diff --git a/tf_shell_ml/postscale_sequential_model.py b/tf_shell_ml/postscale_sequential_model.py index b36a823..1b5c0d0 100644 --- a/tf_shell_ml/postscale_sequential_model.py +++ b/tf_shell_ml/postscale_sequential_model.py @@ -67,6 +67,13 @@ def compute_grads(self, features, enc_labels): max_two_norms_list.append(self.jacobian_max_two_norm(jacobians)) with tf.device(self.features_party_dev): + # For some reason, when running the jacobian on an accelerator, the + # weights must be touched otherwise training loss goes to NaN. Maybe + # it is to ensure the weights are on assigned to features_party + # device for later, when the final gradient is added to weights (on + # CPU)? + tf.print(self.trainable_variables, output_stream="file:///dev/null") + predictions = tf.concat(predictions_list, axis=0) max_two_norm = tf.reduce_max(max_two_norms_list) jacobians = [tf.concat(j, axis=0) for j in zip(*jacobians_list)]