-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpredict.py
96 lines (73 loc) · 2.75 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#! /usr/bin/env python
import argparse
import os
import cv2
import numpy as np
from tqdm import tqdm
from preprocessing import parse_annotation
from utils import draw_boxes
from frontend import YOLO
import json
# Pin GPU enumeration to PCI bus order and expose only device 0 to CUDA.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)
# Command-line interface for the prediction script.  All three options are
# marked required because the script cannot run without any of them; argparse
# then reports a clear usage error instead of a later TypeError on None.
argparser = argparse.ArgumentParser(
    description='Predict bounding boxes with a trained YOLO_v2 model '
                'on an image or a video')

argparser.add_argument(
    '-c',
    '--conf',
    required=True,
    help='path to configuration file')

argparser.add_argument(
    '-w',
    '--weights',
    required=True,
    help='path to pretrained weights')

argparser.add_argument(
    '-i',
    '--input',
    required=True,
    help='path to an image or an video (mp4 format)')
def _main_(args):
    """Run YOLO detection on one image or one ``.mp4`` video and save the result.

    The annotated output is written next to the input, with ``_detected``
    inserted before the file extension.

    Args:
        args: Parsed command-line namespace providing ``conf`` (path to the
            JSON configuration file), ``weights`` (path to the trained
            weights) and ``input`` (path to the image or video).

    Raises:
        IOError: If the input video cannot be opened or the input image
            cannot be read.
    """
    config_path = args.conf
    weights_path = args.weights
    image_path = args.input

    with open(config_path) as config_buffer:
        config = json.load(config_buffer)

    model_config = config['model']
    labels = model_config['labels']

    ###############################
    #   Make the model
    ###############################

    yolo = YOLO(backend=model_config['backend'],
                input_size=model_config['input_size'],
                labels=labels,
                max_box_per_image=model_config['max_box_per_image'],
                anchors=model_config['anchors'])

    ###############################
    #   Load trained weights
    ###############################

    yolo.load_weights(weights_path)

    ###############################
    #   Predict bounding boxes
    ###############################

    output_path = _detected_path(image_path)

    # Compare the real extension, case-insensitively, so "clip.MP4" is
    # treated as a video rather than being handed to the image reader.
    if os.path.splitext(image_path)[1].lower() == '.mp4':
        _predict_video(yolo, labels, image_path, output_path)
    else:
        _predict_image(yolo, labels, image_path, output_path)


def _detected_path(path):
    """Return *path* with ``_detected`` inserted before its file extension."""
    # splitext copes with extensions of any length (".jpeg", ".tiff"), unlike
    # slicing off a fixed four characters.
    stem, ext = os.path.splitext(path)
    return stem + '_detected' + ext


def _predict_video(yolo, labels, video_path, output_path):
    """Annotate every readable frame of *video_path* and write *output_path*."""
    video_reader = cv2.VideoCapture(video_path)
    video_writer = None

    try:
        if not video_reader.isOpened():
            raise IOError('Cannot open video: ' + video_path)

        nb_frames = int(video_reader.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_h = int(video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_w = int(video_reader.get(cv2.CAP_PROP_FRAME_WIDTH))

        # Keep the source frame rate so the output plays at the original
        # speed; fall back to 50.0 when the container does not report a
        # usable value (0 or NaN).
        fps = video_reader.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 50.0

        video_writer = cv2.VideoWriter(output_path,
                                       cv2.VideoWriter_fourcc(*'MPEG'),
                                       fps,
                                       (frame_w, frame_h))

        for _ in tqdm(range(nb_frames)):
            grabbed, image = video_reader.read()

            # The reported frame count can exceed the number of decodable
            # frames; stop cleanly instead of predicting on None.
            if not grabbed:
                break

            boxes = yolo.predict(image)
            image = draw_boxes(image, boxes, labels)

            video_writer.write(np.uint8(image))
    finally:
        # Always release the capture and writer handles, even when
        # prediction fails part-way through the video.
        video_reader.release()
        if video_writer is not None:
            video_writer.release()


def _predict_image(yolo, labels, image_path, output_path):
    """Annotate the image at *image_path* and write it to *output_path*."""
    image = cv2.imread(image_path)

    # cv2.imread signals a missing or undecodable file by returning None
    # rather than raising.
    if image is None:
        raise IOError('Cannot read image: ' + image_path)

    boxes = yolo.predict(image)
    image = draw_boxes(image, boxes, labels)

    print(len(boxes), 'objects are found')

    cv2.imwrite(output_path, image)
# Script entry point: parse the command line and run prediction.
if __name__ == '__main__':
    _main_(argparser.parse_args())