-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetric.py
26 lines (19 loc) · 815 Bytes
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# coding=utf-8
import mxnet as mx
class siamise_metric(mx.metric.EvalMetric):
    """Thresholded accuracy metric for a siamese network.

    The first network output, ``pred[0]``, is compared against ``threshold``:
    scores strictly below it count as a hit when the label is 1, and scores at
    or above it count as a hit when the label is 0. The reported value is
    hits / samples seen.

    NOTE(review): this presumes ``pred[0]`` is a distance-like score (small
    means "same pair") and that ``label[0]`` holds 0/1 values with 1 meaning
    "same pair" -- confirm against the network/data iterator definition.
    """

    def __init__(self, name='siamise_acc', threshold=0.5):
        """Create the metric.

        :param name: metric name reported by ``get()``.
        :param threshold: decision boundary applied to ``pred[0]``; defaults
            to 0.5, the value that was previously hard-coded.
        """
        super(siamise_metric, self).__init__(name=name)
        self.threshold = threshold

    def update(self, label, pred):
        """Accumulate hit counts for one batch.

        :param label: list of label NDArrays; only ``label[0]`` is used.
        :param pred: list of output NDArrays; only ``pred[0]`` is used.
        """
        scores = pred[0].asnumpy().ravel()
        labels = label[0].asnumpy().ravel()
        # Two explicit masks (rather than one mask and its negation) so that a
        # NaN score, which fails both comparisons, is never counted as a hit.
        below = scores < self.threshold
        at_or_above = scores >= self.threshold
        # Below threshold: each label of 1 is a hit.
        # At/above threshold: hits are (count - sum of labels), i.e. the
        # number of 0 labels when labels are 0/1.
        hits = (labels[below].sum()
                + at_or_above.sum()
                - labels[at_or_above].sum())
        self.sum_metric += hits
        self.num_inst += len(labels)
class contrastive_loss(mx.metric.EvalMetric):
    """Running mean of the loss values produced by the network's second output.

    Reads ``pred[1]``, which is assumed to hold per-sample contrastive loss
    values (NOTE(review): confirm the output ordering in the symbol
    definition). Labels are ignored.
    """

    def __init__(self, name='contrastive_loss'):
        """Create the metric.

        :param name: metric name reported by ``get()``.
        """
        super(contrastive_loss, self).__init__(name=name)

    def update(self, label, pred):
        """Accumulate the loss total and element count for one batch.

        :param label: list of label NDArrays (unused).
        :param pred: list of output NDArrays; only ``pred[1]`` is used.
        """
        # Flatten so 0-d, (batch,) and (batch, 1) outputs are all handled.
        loss = pred[1].asnumpy().ravel()
        # Accumulate a Python scalar. Adding the raw array would turn
        # sum_metric into an ndarray, so get() would return an array, and a
        # batch of a different size would raise a broadcast error.
        self.sum_metric += float(loss.sum())
        self.num_inst += loss.size