Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Taka008 committed Aug 8, 2023
1 parent 3073e76 commit 34ddade
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/kwja/modules/seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
else None,
**self.hparams.decoding,
)
if isinstance(generations, torch.Tensor):
seq2seq_predictions = generations
else:
seq2seq_predictions = generations["sequences"]

Check warning on line 109 in src/kwja/modules/seq2seq.py

View check run for this annotation

Codecov / codecov/patch

src/kwja/modules/seq2seq.py#L109

Added line #L109 was not covered by tests
return {
"example_ids": batch["example_ids"] if "example_ids" in batch else [],
"seq2seq_predictions": generations,
"seq2seq_predictions": seq2seq_predictions,
}

0 comments on commit 34ddade

Please sign in to comment.