# Patch for DiT.get_spatial_pos_embed in opensora/models/dit/dit.py.
#
# The second argument to get_2d_sincos_pos_embed changes from a single value,
# input_size[1] // patch_size[1], to a two-element tuple that also includes
# input_size[2] // patch_size[2].
#
# NOTE(review): presumably this lets the two spatial axes have different patch
# counts; confirm that get_2d_sincos_pos_embed accepts a tuple grid size.
diff --git a/opensora/models/dit/dit.py b/opensora/models/dit/dit.py
index 68a3eb7f..d4a25ed1 100644
--- a/opensora/models/dit/dit.py
+++ b/opensora/models/dit/dit.py
@@ -158,7 +158,7 @@ def __init__(
     def get_spatial_pos_embed(self):
         pos_embed = get_2d_sincos_pos_embed(
             self.hidden_size,
-            self.input_size[1] // self.patch_size[1],
+            (self.input_size[1] // self.patch_size[1], self.input_size[2] // self.patch_size[2])
         )
         pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False)
         return pos_embed