fix bug in dit1d
zjowowen committed Jan 1, 2025
1 parent 829be8b commit 8c7814b
Showing 1 changed file with 1 addition and 1 deletion.
grl/neural_network/transformers/dit.py: 1 addition & 1 deletion
@@ -1266,7 +1266,6 @@ def _basic_init(module):
         # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
         w = self.x_embedder.weight.data
         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
-        nn.init.constant_(self.x_embedder.proj.bias, 0)
 
         # Initialize timestep embedding MLP:
         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
@@ -1301,6 +1300,7 @@ def forward(
         # x is of shape (N, T, C), reshape to (N, C, T)
         x = torch.einsum("ntc->nct", x)
         x = self.x_embedder(x) + torch.einsum("th->ht", self.pos_embed)
+        x = torch.einsum("nht->nth", x)  # (N, total_patches, hidden_size)
 
         t = self.t_embedder(t)  # (N, hidden_size)
 
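Why the change works: the surrounding init code reads self.x_embedder.weight directly, so the 1-D embedder apparently has no .proj submodule, and the deleted nn.init.constant_(self.x_embedder.proj.bias, 0) call would raise an AttributeError; it looks like a leftover from the 2-D DiT, whose patch embedder wraps a conv in .proj. The added einsum transposes the embedded tokens from (N, hidden_size, T) back to (N, T, hidden_size) before they reach the transformer blocks. Below is a minimal sketch of the fixed shape flow, assuming a Conv1d token embedder and a (num_patches, hidden_size) pos_embed; the module and sizes are illustrative, not the repository's actual definitions.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration: batch, tokens, input channels, hidden size.
N, T, C, H = 2, 16, 4, 32

x_embedder = nn.Conv1d(C, H, kernel_size=1)  # assumed 1-D token embedder
pos_embed = torch.randn(T, H)                # assumed (num_patches, hidden_size)

x = torch.randn(N, T, C)                     # input is (N, T, C)
x = torch.einsum("ntc->nct", x)              # Conv1d expects (N, C, T)
x = x_embedder(x) + torch.einsum("th->ht", pos_embed)  # (N, H, T) after broadcast
x = torch.einsum("nht->nth", x)              # the fixed line: back to (N, T, H)

assert x.shape == (N, T, H)  # the (batch, tokens, hidden) layout attention blocks expect

Without the final transpose, downstream blocks would treat the hidden dimension as the token axis rather than the time axis.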
