-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict_pretrained.py
42 lines (32 loc) · 1.49 KB
/
predict_pretrained.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding: utf-8 -*-
"""
This file should not be modified; it is provided by us and will be reset.
The script tests the pretrained model.
@author: Maurice Rohr
"""
# import socket
# def guard(*args, **kwargs):
# raise Exception("Internet Access Forbidden")
# socket.socket = guard
from predict import predict_labels
from wettbewerb import load_references, save_predictions
import argparse
import time
import logging
# Named logger for this script; configured only when run as the entry point.
logger = logging.getLogger("main_log")

if __name__ == '__main__':
    # Script entry point: load the test recordings, run the pretrained model
    # on them, store the predictions and report how long prediction took.

    # Log to a file that is truncated on every run (filemode='w').
    logging.basicConfig(filename="logfile_predict.log",
                        format='%(asctime)s %(message)s',
                        filemode='w')
    # Set the threshold of the logger to DEBUG.
    logger.setLevel(logging.DEBUG)

    # Command-line interface:
    #   --test_dir    directory passed to load_references
    #   --model_name  model identifier forwarded to predict_labels
    parser = argparse.ArgumentParser(description='Predict given Model')
    parser.add_argument('--test_dir', action='store', type=str, default='../test/')
    parser.add_argument('--model_name', action='store', type=str, default='international_CO1')
    args = parser.parse_args()

    # Load the ECG recordings, their diagnoses, the sampling frequency (Hz)
    # and the record names. The original note states a sampling frequency of
    # 300 Hz. The labels are loaded but not passed to the prediction call.
    ecg_leads, ecg_labels, fs, ecg_names = load_references(args.test_dir)

    # Time only the prediction call; loading and saving are excluded.
    # perf_counter() is used instead of time() because it is monotonic and
    # therefore unaffected by system clock adjustments during the run.
    start_time = time.perf_counter()
    predictions = predict_labels(ecg_leads, fs, ecg_names, model_name=args.model_name)
    pred_time = time.perf_counter() - start_time

    # Persist the predictions (to a CSV file, per the original comment).
    save_predictions(predictions)
    print("Runtime", pred_time, "s")