diff --git a/models.py b/models.py index c90eeba..d505659 100644 --- a/models.py +++ b/models.py @@ -227,7 +227,7 @@ def unpatchify(self, x): x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) x = torch.einsum('nhwpqc->nchpwq', x) - imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) return imgs def forward(self, x, t, y):