diff --git a/normflows/core.py b/normflows/core.py index b4b0c88..9920193 100644 --- a/normflows/core.py +++ b/normflows/core.py @@ -585,12 +585,12 @@ def sample(self, num_samples=1, y=None, temperature=None): self.reset_temperature() return z, log_q - def log_prob(self, x, y): + def log_prob(self, x, y=None): """Get log probability for batch Args: x: Batch - y: Classes of x + y: Classes of x. Must be passed in if `class_cond` is True. Returns: log probability