Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 3, 2024
1 parent 342b48a commit e25173b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 15 deletions.
8 changes: 2 additions & 6 deletions src/scvi/external/decipher/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ def setup_anndata(
anndata_fields = [
LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True),
]
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
adata_manager.register_fields(adata, **kwargs)
cls.register_manager(adata_manager)

Expand Down Expand Up @@ -140,9 +138,7 @@ def get_latent_representation(
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)
scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
latent_locs = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
Expand Down
12 changes: 3 additions & 9 deletions src/scvi/external/decipher/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ def device(self):
return self._dummy_param.device

@staticmethod
def _get_fn_args_from_batch(
tensor_dict: dict[str, torch.Tensor]
) -> Iterable | dict:
def _get_fn_args_from_batch(tensor_dict: dict[str, torch.Tensor]) -> Iterable | dict:
x = tensor_dict[REGISTRY_KEYS.X_KEY]
return (x,), {}

Expand Down Expand Up @@ -116,9 +114,7 @@ def model(self, x: torch.Tensor):
self.theta + self._epsilon
)
# noinspection PyUnresolvedReferences
x_dist = dist.NegativeBinomial(
total_count=self.theta + self._epsilon, logits=logit
)
x_dist = dist.NegativeBinomial(total_count=self.theta + self._epsilon, logits=logit)
pyro.sample("x", x_dist.to_event(1), obs=x)

@auto_move_data
Expand Down Expand Up @@ -174,9 +170,7 @@ def predictive_log_likelihood(self, x: torch.Tensor, n_samples=5):
model_trace = poutine.trace(
poutine.replay(self.model, trace=guide_trace)
).get_trace(x)
log_weights.append(
model_trace.log_prob_sum() - guide_trace.log_prob_sum()
)
log_weights.append(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

finally:
self.beta = old_beta
Expand Down

0 comments on commit e25173b

Please sign in to comment.