Skip to content

Commit

Permalink
Polish diffusion sample API
Browse files Browse the repository at this point in the history
  • Loading branch information
zjowowen committed Jul 1, 2024
1 parent f22277f commit 6fda30d
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions grl/generative_models/diffusion_model/diffusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,10 @@ def sample_forward_process(
assert False, "Invalid batch size"

if x_0 is not None and condition is not None:
assert (
x_0.shape[0] == condition.shape[0]
), "The batch size of x_0 and condition must be the same"
if type(x_0) == type(condition):
assert (
x_0.shape[0] == condition.shape[0]
), "The batch size of x_0 and condition must be the same"
data_batch_size = x_0.shape[0]
elif x_0 is not None:
data_batch_size = x_0.shape[0]
Expand Down Expand Up @@ -225,10 +226,20 @@ def sample_forward_process(
# x.shape = (B*N, D)

if condition is not None:
condition = torch.repeat_interleave(
condition, torch.prod(extra_batch_size), dim=0
)
# condition.shape = (B*N, D)
if isinstance(condition, torch.Tensor):
condition = torch.repeat_interleave(
condition, torch.prod(extra_batch_size), dim=0
)
# condition.shape = (B*N, D)
elif isinstance(condition, treetensor.torch.Tensor):
for key in condition.keys():
condition[key] = torch.repeat_interleave(
condition[key], torch.prod(extra_batch_size), dim=0
)
# condition.shape = (B*N, D)
else:
raise NotImplementedError("Not implemented")


if isinstance(solver, DPMSolver):
# Note: DPMSolver does not support t_span argument assignment
Expand Down

0 comments on commit 6fda30d

Please sign in to comment.