-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEditDistanceMetric.py
41 lines (33 loc) · 1.31 KB
/
EditDistanceMetric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from logging import log
from ignite.metrics import Metric
from ignite.metrics.metric import sync_all_reduce, reinit__is_reduced
from ignite.exceptions import NotComputableError
import torch
import difflib
import logging
class EditDistanceMetric(Metric):
    """Ignite metric averaging ``difflib.SequenceMatcher`` similarity over a run.

    Despite the name, the accumulated score is a *similarity*, not a distance:
    each (prediction, label) pair contributes
    ``difflib.SequenceMatcher(None, label, prediction).ratio()``, a float in
    ``[0, 1]`` where ``1.0`` means the two sequences are identical (higher is
    better). ``compute`` returns the mean of these ratios over every pair seen
    since the last ``reset``.

    Args:
        output_transform: Callable applied by ignite to the engine output
            before ``update``; it must yield a ``(y_pred, y)`` pair.
        device: Device forwarded to :class:`ignite.metrics.Metric`.
    """

    def __init__(self, output_transform=lambda x: x, device=None):
        super().__init__(output_transform=output_transform, device=device)
        # Running sum of per-pair similarity ratios. The attribute name is kept
        # for compatibility with existing code and with the names listed in
        # ``compute``'s ``sync_all_reduce`` decorator.
        self._edit_distances = 0.0
        # Number of (prediction, label) pairs folded into the sum.
        self._num_examples = 0

    @staticmethod
    def _to_sequences(batch):
        """Return *batch* as a list of per-example sequences.

        Tensors are converted with ``tolist()`` so their elements become plain
        Python scalars. ``SequenceMatcher`` indexes sequence elements in a
        dict, and tensors hash by identity, so comparing tensor elements
        directly makes those lookups fail and silently deflates the ratio.
        """
        if isinstance(batch, torch.Tensor):
            # Assumes a 2-D (batch, seq_len) tensor, e.g. token ids; TODO confirm
            # against callers. Padding values, if any, are compared as-is.
            return batch.detach().cpu().tolist()
        return list(batch)

    @reinit__is_reduced
    def reset(self):
        """Clear the accumulated similarity sum and example count."""
        self._edit_distances = 0.0
        self._num_examples = 0

    @reinit__is_reduced
    def update(self, output):
        """Accumulate similarity ratios for one batch.

        Args:
            output: ``(y_pred, y)`` pair. Each side is either a tensor
                (converted row-wise via ``tolist()``) or an iterable of
                sequences such as strings or token lists.

        Raises:
            ValueError: If ``y_pred`` and ``y`` contain different numbers of
                examples, since the mean would otherwise be taken over pairs
                that do not line up with the example count.
        """
        y_pred, y = output
        predictions = self._to_sequences(y_pred)
        labels = self._to_sequences(y)
        if len(predictions) != len(labels):
            raise ValueError(
                f"EditDistanceMetric.update: got {len(predictions)} predictions "
                f"but {len(labels)} labels."
            )
        # The label is passed as ``a`` and the prediction as ``b``. ratio() can
        # depend on argument order, so this order is preserved deliberately.
        self._edit_distances += sum(
            difflib.SequenceMatcher(None, label, prediction).ratio()
            for prediction, label in zip(predictions, labels)
        )
        self._num_examples += len(labels)

    @sync_all_reduce("_num_examples", "_edit_distances")
    def compute(self):
        """Return the mean similarity ratio over all pairs since the last reset.

        Raises:
            NotComputableError: If no example has been accumulated yet.
        """
        if self._num_examples == 0:
            raise NotComputableError(
                "EditDistanceMetric must have at least one example before it can be computed."
            )
        return self._edit_distances / self._num_examples