-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathlogger.py
151 lines (118 loc) · 4.7 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python
import os
import re
import time
import numpy as np
from sklearn.model_selection import train_test_split
from mmwave_gesture.data import GESTURE, DataLoader
from mmwave_gesture.utils.thread_safe_print import print, warning
import colorama
from colorama import Fore
# Initialise colorama once at import time. autoreset=True is requested so that
# colour codes such as Fore.GREEN used below are not expected to leak into
# later output (per colorama's documented autoreset behaviour).
colorama.init(autoreset=True)
class Logger:
    """Record detected-point frames into gesture samples and manage sample files.

    A sample is built frame by frame through :meth:`log`. Recording of a
    sample ends when no detection has been seen for longer than the active
    timeout, or when an optional frame limit is reached. Finished samples
    can be validated (:meth:`check_len`), stored (:meth:`save`) and later
    enumerated (:meth:`get_paths`, :meth:`get_data`, :meth:`split`).
    """

    def __init__(self, timeout=.5, data_dir=None):
        """Create a logger.

        Args:
            timeout: Seconds without a detection after which a running
                sample is considered finished. Before the first detection
                of a sample, three times this value is used instead.
            data_dir: Optional directory override handed to
                :meth:`get_gesture` when saving or discarding samples.
        """
        self.start_timeout = 3*timeout
        self.data_dir = data_dir
        self.end_timeout = timeout
        self.reset()

    def reset(self):
        """Drop any sample in progress and restore the initial timeout."""
        self.data = None
        self.detected_time = 0
        self.empty_frames_cnt = 0
        self.timeout = self.start_timeout

    def _pop_sample(self):
        """Return the frames recorded so far and reset the recorder state."""
        sample = self.data
        self.reset()
        return sample

    def log(self, frame, echo=False, max_frames=None):
        """Feed one frame into the sample currently being recorded.

        Args:
            frame: Parsed frame mapping (or a falsy value). A frame counts
                as a detection when ``frame['tlvs']['detectedPoints']`` is
                truthy; its ``'objs'`` entry is what gets stored.
            echo: If true, print a progress message when a new sample starts.
            max_frames: Optional cap on the number of stored entries; the
                sample is returned as soon as this length is reached.

        Returns:
            The list of recorded entries when the sample is finished,
            otherwise ``None``. Frames without detections that lie between
            two detections are stored as ``None`` placeholders; frames
            without detections before the first detection are not stored.
        """
        if self.data is None:
            # First call of a new sample: open the buffer and start the clock.
            self.data = []
            self.detected_time = time.perf_counter()
            self.timeout = self.start_timeout
            self.empty_frames_cnt = 0
            if echo:
                print('Saving sample...', end='')

        # Too long since the last detection (or since the sample started):
        # hand back what was collected. The current frame is not recorded.
        if time.perf_counter() - self.detected_time > self.timeout:
            return self._pop_sample()

        if not (frame and frame.get('tlvs', {}).get('detectedPoints')):
            # Nothing detected; only remember how many frames were skipped.
            self.empty_frames_cnt += 1
            return None

        # A detection switches to the shorter end-of-gesture timeout.
        self.timeout = self.end_timeout
        self.detected_time = time.perf_counter()

        # Gaps are back-filled only once the sample already has content,
        # so a sample never starts with placeholders.
        if self.data:
            self.data.extend([None]*self.empty_frames_cnt)
        self.data.append(frame['tlvs']['detectedPoints']['objs'])

        if max_frames is not None and len(self.data) >= max_frames:
            return self._pop_sample()

        self.empty_frames_cnt = 0
        return None

    def check_len(self, sample, echo=True, min_frames=3):
        """Tell whether *sample* holds enough non-empty frames.

        Args:
            sample: Iterable of recorded entries; ``None`` marks an empty frame.
            echo: If true, emit a warning for a too-short sample that
                contains at least one non-empty frame.
            min_frames: Minimum number of non-``None`` entries required.

        Returns:
            ``True`` if the sample has at least *min_frames* non-``None``
            entries, ``False`` otherwise.
        """
        # Identity checks are used on purpose: ``== None`` would be evaluated
        # element-wise for array-like frames.
        if sum(1 for frame in sample if frame is not None) < min_frames:
            if echo and any(frame is not None for frame in sample):
                warning('Gesture too short.\n')
            return False
        return True

    @staticmethod
    def get_gesture(gesture, dir):
        """Resolve *gesture* to a ``GESTURE`` member, optionally setting its dir.

        Args:
            gesture: A ``GESTURE`` member, or a key usable as ``GESTURE[key]``.
            dir: If not ``None``, assigned to the member's ``dir`` attribute.
                NOTE(review): this mutates the ``GESTURE`` member itself, so
                the override presumably persists for later users - confirm.

        Returns:
            The resolved ``GESTURE`` member.
        """
        gesture = gesture if isinstance(gesture, GESTURE) else GESTURE[gesture]
        if dir is not None:
            gesture.dir = dir
        return gesture

    def save(self, data, gesture):
        """Store a recorded sample as a compressed ``.npz`` file.

        Nothing is written when *data* is empty or fails :meth:`check_len`.

        Args:
            data: Recorded entries as returned by :meth:`log`.
            gesture: Gesture (member or key) the sample belongs to.
        """
        if not data:
            warning('Nothing to save.')
            return

        if not self.check_len(data):
            return

        gesture = self.get_gesture(gesture, self.data_dir)
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(gesture.dir, exist_ok=True)

        # dtype=object because entries may be None or of differing lengths.
        np.savez_compressed(gesture.next_file(), data=np.array(data, dtype=object))
        print(f'{Fore.GREEN}Done.')

    def discard_last_sample(self, gesture):
        """Delete the file reported by ``last_file()`` for *gesture*, if any."""
        last_sample = self.get_gesture(gesture, self.data_dir).last_file()
        if last_sample is None:
            print('No files.')
            return

        os.remove(last_sample)
        print('File deleted.')

    @staticmethod
    def get_data(gesture=None):
        """Load every sample found by :meth:`get_paths`.

        Args:
            gesture: Restrict loading to one gesture; ``None`` loads all.

        Returns:
            Tuple ``(X, y)`` of loaded samples and their labels.
        """
        X_paths, y = Logger.get_paths(gesture)
        X = [DataLoader(path).load() for path in X_paths]
        return X, y

    @staticmethod
    def _get_paths(gesture):
        """Collect sample file paths and labels below ``gesture.dir``.

        Args:
            gesture: A resolved ``GESTURE`` member.

        Returns:
            Tuple ``(paths, y)``; both lists are empty when the directory
            does not exist. Each label is ``gesture.value``.
        """
        paths, y = [], []
        if not os.path.exists(gesture.dir):
            return paths, y

        extensions = DataLoader.get_extensions()
        # Compiled once instead of per file. Extensions are inserted verbatim
        # (not escaped) - presumably plain suffixes such as "npz"; verify.
        pattern = re.compile(fr'.*\.({"|".join(extensions)})$')
        for root, _, files in os.walk(gesture.dir):
            for f in files:
                if pattern.match(f):
                    paths.append(os.path.join(root, f))
                    y.append(gesture.value)
        return paths, y

    @staticmethod
    def get_paths(gesture=None, dir=None):
        """Return sample paths and labels for one gesture or for all of them.

        Args:
            gesture: Gesture (member or key); ``None`` walks every member
                of ``GESTURE``.
            dir: Optional directory override forwarded to :meth:`get_gesture`.

        Returns:
            Tuple ``(paths, y)`` of file paths and matching labels.
        """
        if gesture is not None:
            return Logger._get_paths(Logger.get_gesture(gesture, dir))

        # Get all gestures instead
        paths, y = [], []
        for member in GESTURE:
            gesture_paths, labels = Logger._get_paths(Logger.get_gesture(member, dir))
            if not gesture_paths or not labels:
                continue
            paths.extend(gesture_paths)
            y.extend(labels)
        return paths, y

    @staticmethod
    def split(paths, y, test_size=1/3):
        """Split paths and labels into stratified train/validation/test sets.

        *test_size* of the data is held out first and then divided in half
        into validation and test parts. Both splits are stratified by label
        and use a fixed ``random_state`` of 12.

        Args:
            paths: Sample paths.
            y: Labels aligned with *paths*.
            test_size: Fraction held out for validation and test together.

        Returns:
            ``((train_paths, y_train), (val_paths, y_val), (test_paths, y_test))``.
        """
        train_paths, test_paths, y_train, y_test = train_test_split(paths, y, stratify=y,
                                                                    test_size=test_size,
                                                                    random_state=12)
        val_paths, test_paths, y_val, y_test = train_test_split(test_paths, y_test,
                                                                stratify=y_test,
                                                                test_size=.5,
                                                                random_state=12)
        return (train_paths, y_train), (val_paths, y_val), (test_paths, y_test)