From 631e1cf66fd17a62c0155c97ae4e910fb2186c13 Mon Sep 17 00:00:00 2001 From: hank0626 Date: Thu, 25 Jul 2024 14:34:18 +0800 Subject: [PATCH] update get_spatial_pos_embed --- opensora/models/dit/dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opensora/models/dit/dit.py b/opensora/models/dit/dit.py index 68a3eb7f..d4a25ed1 100644 --- a/opensora/models/dit/dit.py +++ b/opensora/models/dit/dit.py @@ -158,7 +158,7 @@ def __init__( def get_spatial_pos_embed(self): pos_embed = get_2d_sincos_pos_embed( self.hidden_size, - self.input_size[1] // self.patch_size[1], + (self.input_size[1] // self.patch_size[1], self.input_size[2] // self.patch_size[2]) ) pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).requires_grad_(False) return pos_embed