diff --git a/turftopic/models/ctm.py b/turftopic/models/ctm.py index f3d6665..8781ac2 100644 --- a/turftopic/models/ctm.py +++ b/turftopic/models/ctm.py @@ -213,9 +213,7 @@ def fit( seed = self.random_state or random.randint(0, 10_000) torch.manual_seed(seed) pyro.set_rng_seed(seed) - device = torch.device( - "cuda:0" if torch.cuda.is_available() else "cpu" - ) + device = torch.device("cpu") pyro.clear_param_store() contextualized_size = embeddings.shape[1] if self.combined: