From c0ade86b9cca3c4a831bcf559f84228296b329c9 Mon Sep 17 00:00:00 2001 From: JunqiHu Date: Thu, 31 Aug 2023 10:14:49 +0800 Subject: [PATCH] [Op] Prevent inconsistent number of Ops and devices during distributed training. Signed-off-by: JunqiHu --- tensorflow/python/training/saver.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 981d01dd7be..43e10009532 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -550,8 +550,14 @@ def _GroupByDevices(self, saveables): """ per_device = collections.defaultdict(lambda: []) for saveable in saveables: - canonical_device = set( - pydev.canonical_name(spec.tensor.device) for spec in saveable.specs) + canonical_device = set() + for spec in saveable.specs: + device_name = pydev.canonical_name(spec.tensor.device) + device_idx = device_name.find("/device") + if device_idx != -1: + canonical_device.add(device_name[:device_idx]) + else: + canonical_device.add(device_name) if len(canonical_device) != 1: raise ValueError("All tensors of a saveable object must be " "on the same device: %s" % saveable.name)