-
Notifications
You must be signed in to change notification settings - Fork 11
/
inference.py
74 lines (58 loc) · 2.05 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from absl import app, flags
from modules.models import PPG2ECG
# Command-line configuration; absl parses these into FLAGS before main() runs.
# --weights: checkpoint file handed to torch.load by main().
flags.DEFINE_string(
    name="weights",
    default="",
    help="model weights for inferencing",
)
# --input: path of the PPG array read with np.load by main().
flags.DEFINE_string(
    name="input",
    default="example/PPG.npy",
    help="input data (numpy array)",
)
FLAGS = flags.FLAGS
# Number of PPG samples per model input; must match PPG2ECG(input_size=...).
_WINDOW_SIZE = 200
# Hop between the start indices of consecutive windows.
_STEP = 100
# Samples dropped from each side of every output window. The model performs
# better in the middle of its window, and keeping the centre
# [_MARGIN, _MARGIN + _STEP) lets consecutive windows tile without overlap.
_MARGIN = (_WINDOW_SIZE - _STEP) // 2


def _load_model(weights_path, device):
    """Build PPG2ECG and load trained weights from ``weights_path``.

    Args:
        weights_path: Checkpoint file; its ``"net"`` entry is loaded as the
            model's state dict.
        device: ``torch.device`` the model is moved to.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    model = PPG2ECG(
        input_size=_WINDOW_SIZE,
        use_stn=True,
        use_attention=True).to(device)
    # map_location remaps the stored tensors onto ``device`` so a checkpoint
    # saved on a GPU still loads on a CPU-only machine.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(Path(weights_path), map_location=device)
    model.load_state_dict(checkpoint["net"])
    model.eval()
    return model


def _sliding_window_ecg(model, ppg, device):
    """Run ``model`` over overlapping PPG windows and stitch the ECG output.

    Windows of ``_WINDOW_SIZE`` samples start every ``_STEP`` samples; a
    trailing window that would run past the end of ``ppg`` is skipped. Only
    the centre slice ``[_MARGIN, _MARGIN + _STEP)`` of each output is kept.

    Args:
        model: Network whose forward pass returns a dict with an ``"output"``
            tensor.
        ppg: PPG samples as returned by ``np.load``.
        device: Device the model lives on.

    Returns:
        CPU tensor of the stitched ECG, aligned with
        ``ppg[_MARGIN:_MARGIN + len(result)]``. Empty when ``ppg`` is shorter
        than one window.
    """
    segments = []
    with torch.no_grad():
        for start in range(0, len(ppg) - _WINDOW_SIZE + 1, _STEP):
            # Flatten the window into one row: [1, _WINDOW_SIZE] for 1-D input.
            window = ppg[start:start + _WINDOW_SIZE].reshape((1, -1))
            window = torch.from_numpy(window).float().to(device)
            # unsqueeze(0) adds a leading dim -> [1, 1, _WINDOW_SIZE];
            # presumably (batch, channel, length) -- TODO confirm for PPG2ECG.
            output = model(window.unsqueeze(0))["output"].cpu()
            # Assumes output is [1, 1, _WINDOW_SIZE] -- TODO confirm.
            segments.append(output[0, 0, _MARGIN:_MARGIN + _STEP])
    if not segments:
        return torch.empty(0)
    return torch.cat(segments)


def main(argv):
    """Reconstruct ECG from the PPG in ``--input`` and plot both signals.

    Args:
        argv: Positional command-line arguments left over by absl; unused.

    Raises:
        app.UsageError: If ``--weights`` is not given.
        ValueError: If the PPG is shorter than one model window.
    """
    del argv  # Unused; all configuration comes from absl flags.
    if not FLAGS.weights:
        raise app.UsageError("--weights is required")

    # prepare the parameters and the model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _load_model(FLAGS.weights, device)

    # prepare the inference data (PPG data)
    ppg = np.load(Path(FLAGS.input))
    print("ppg shape: {}".format(ppg.shape))

    ecg = _sliding_window_ecg(model, ppg, device)
    if len(ecg) == 0:
        raise ValueError(
            "ppg has {} samples; at least {} are needed for one window".format(
                len(ppg), _WINDOW_SIZE))

    # The stitched ECG covers exactly ppg[_MARGIN:_MARGIN + len(ecg)]; trim
    # the PPG to that span so both traces line up sample-for-sample even when
    # the recording length is not a whole number of steps.
    ppg = ppg[_MARGIN:_MARGIN + len(ecg)]

    # show the plot
    plt.plot(ppg, label="ppg")
    plt.plot(ecg.numpy(), label="ecg")
    plt.legend()
    plt.show()
if __name__ == "__main__":
    # Hand control to absl: it parses the command-line flags into FLAGS and
    # then invokes main with the remaining arguments.
    app.run(main=main)