Skip to content

Commit

Permalink
Merge pull request #216 from taras-sereda/master
Browse files Browse the repository at this point in the history
simplified batch softmax
  • Loading branch information
Sean Naren authored Jan 12, 2018
2 parents 87fd100 + c31416e commit e2c2d83
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def __repr__(self):
class InferenceBatchSoftmax(nn.Module):
def forward(self, input_):
if not self.training:
batch_size = input_.size()[0]
return torch.stack([F.softmax(input_[i], dim=1) for i in range(batch_size)], 0)
return F.softmax(input_, dim=-1)
else:
return input_

Expand Down

0 comments on commit e2c2d83

Please sign in to comment.