From eae2df4162f5ebefb3836c2047f8f3e2bde61106 Mon Sep 17 00:00:00 2001 From: Lyaction Date: Mon, 15 Apr 2024 10:21:12 +0800 Subject: [PATCH] fix type of ev gather op Signed-off-by: Lyaction --- tensorflow/python/ops/kv_variable_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/ops/kv_variable_ops.py b/tensorflow/python/ops/kv_variable_ops.py index b7689acd075..524ce28213c 100644 --- a/tensorflow/python/ops/kv_variable_ops.py +++ b/tensorflow/python/ops/kv_variable_ops.py @@ -778,10 +778,10 @@ def sparse_read(self, indices, name=None, ev_init_value=None, counts=None): if self._trainable: tape.variable_accessed(self) if ev_init_value is not None: - default_value = ev_init_value + default_value = math_ops.cast(ev_init_value, self.dtype) is_use_default_value_tensor = True else: - default_value = ops.convert_to_tensor(1.0) + default_value = ops.convert_to_tensor(1.0, dtype=self.dtype) is_use_default_value_tensor = False if counts != None: value = gen_kv_variable_ops.kv_resource_gather_v1(self._handle,