diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 8f07e0494..acda7827c 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -45,10 +45,10 @@ def to_torch( x.dtype.type, np.bool_ | np.number, ): # most often case - x = torch.from_numpy(x).to(device) + x = torch.from_numpy(x) if dtype is not None: x = x.type(dtype) - return x + return x.to(device) if isinstance(x, torch.Tensor): # second often case if dtype is not None: x = x.type(dtype)