diff --git a/grl/generative_models/diffusion_model/diffusion_model.py b/grl/generative_models/diffusion_model/diffusion_model.py index 093f7e8c..f2aee116 100644 --- a/grl/generative_models/diffusion_model/diffusion_model.py +++ b/grl/generative_models/diffusion_model/diffusion_model.py @@ -181,9 +181,10 @@ def sample_forward_process( assert False, "Invalid batch size" if x_0 is not None and condition is not None: - assert ( - x_0.shape[0] == condition.shape[0] - ), "The batch size of x_0 and condition must be the same" + if type(x_0) == type(condition): + assert ( + x_0.shape[0] == condition.shape[0] + ), "The batch size of x_0 and condition must be the same" data_batch_size = x_0.shape[0] elif x_0 is not None: data_batch_size = x_0.shape[0] @@ -225,10 +226,20 @@ def sample_forward_process( # x.shape = (B*N, D) if condition is not None: - condition = torch.repeat_interleave( - condition, torch.prod(extra_batch_size), dim=0 - ) - # condition.shape = (B*N, D) + if isinstance(condition, torch.Tensor): + condition = torch.repeat_interleave( + condition, torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + elif isinstance(condition, treetensor.torch.Tensor): + for key in condition.keys(): + condition[key] = torch.repeat_interleave( + condition[key], torch.prod(extra_batch_size), dim=0 + ) + # condition.shape = (B*N, D) + else: + raise NotImplementedError("Not implemented") + if isinstance(solver, DPMSolver): # Note: DPMSolver does not support t_span argument assignment