-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlite_classifier.py
46 lines (35 loc) · 1.43 KB
/
lite_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import cv2
import pydload
import numpy as np
from image_utils import load_images
class LiteClassifier:
    """Safe/unsafe image classifier backed by NudeNet's lite ONNX checkpoint.

    On construction the checkpoint is downloaded into ``~/.NudeNet/`` (only
    if not already present) and loaded with OpenCV's DNN module.
    """

    # Release asset holding the lite classifier checkpoint.
    MODEL_URL = (
        "https://github.com/notAI-tech/NudeNet/releases/download/v0/"
        "classifier_lite.onnx"
    )

    def __init__(self):
        """Ensure the checkpoint is cached locally, then load it into ``self.lite_model``."""
        model_folder = os.path.join(os.path.expanduser("~"), ".NudeNet/")
        # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
        # creates any missing parent directories.
        os.makedirs(model_folder, exist_ok=True)
        model_path = os.path.join(model_folder, os.path.basename(self.MODEL_URL))
        if not os.path.exists(model_path):
            print("Downloading the checkpoint to", model_path)
            # NOTE(review): an interrupted download may leave a truncated file
            # at model_path that is never re-fetched; consider downloading to a
            # temporary name and renaming on success once pydload's failure
            # semantics are confirmed.
            pydload.dload(self.MODEL_URL, save_to_path=model_path, max_time=None)
        self.lite_model = cv2.dnn.readNet(model_path)

    def classify(self, image_paths, size=(256, 256)):
        """Classify a single image, or each image in a list/tuple.

        Args:
            image_paths: One image (a file path string, or any other single
                object ``load_images`` accepts) or a list/tuple of them.
            size: Target size forwarded to ``load_images``.

        Returns:
            A dict mapping each input's position (``0`` for a single image)
            to ``{"unsafe": pred[0][0], "safe": pred[0][1]}`` taken from the
            model's forward output.
        """
        # Only real sequences are batches. Strings and any other object
        # (e.g. an already-decoded image array) are treated as one image, so
        # a list is iterated instead of being passed whole as a single "image".
        if isinstance(image_paths, (list, tuple)):
            images = image_paths
        else:
            images = [image_paths]

        result = {}
        for index, image in enumerate(images):
            loaded_images, _ = load_images([image], size, image_names=[index])
            # Move axis 3 to position 1 -- presumably NHWC -> NCHW as the
            # ONNX graph expects; confirm against load_images' output layout.
            loaded_images = np.rollaxis(loaded_images, 3, 1)
            self.lite_model.setInput(loaded_images)
            pred = self.lite_model.forward()
            # Key by input position so multiple images don't overwrite each other.
            result[index] = {
                "unsafe": pred[0][0],
                "safe": pred[0][1],
            }
        return result