Skip to content

Commit

Permalink
Merge pull request #16 from unicef/feature/pickle
Browse files Browse the repository at this point in the history
chg ! do not use pickle
  • Loading branch information
saxix authored May 21, 2024
2 parents d9bcd93 + 08eed9c commit 376c029
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 14 deletions.
1 change: 0 additions & 1 deletion bandit.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
exclude_dirs: ['tests',]
#tests = ["B201", "B301"]
#skips = ["B101", "B601"]
skips: ["B403", "B301"]

Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
import pickle
from typing import Dict, List, Tuple

from django.conf import settings
Expand Down Expand Up @@ -36,7 +35,7 @@ def __init__(self, filename: str) -> None:
self.net.setPreferableTarget(settings.DNN_TARGET)

self.filename: str = filename
self.encodings_filename = f"{self.filename}.pkl"
self.encodings_filename = f"{self.filename}.npy"

self.confidence: float = settings.FACE_DETECTION_CONFIDENCE
self.threshold: float = settings.DISTANCE_THRESHOLD
Expand Down Expand Up @@ -72,9 +71,9 @@ def _load_encodings_all(self) -> Dict[str, List[np.ndarray]]:
try:
_, files = self.storages["encoded"].listdir("")
for file in files:
if file.endswith(".pkl"):
if file.endswith(".npy"):
with self.storages["encoded"].open(file, "rb") as f:
data[os.path.splitext(file)[0]] = pickle.load(f)
data[os.path.splitext(file)[0]] = np.load(f, allow_pickle=False)
except Exception as e:
self.logger.exception(f"Error loading encodings: {e}", exc_info=True)
return data
Expand All @@ -93,7 +92,7 @@ def _encode_face(self) -> None:
else:
self.logger.error(f"Invalid face region {region}")
with self.storages["encoded"].open(self.encodings_filename, "wb") as f:
pickle.dump(encodings, f)
np.save(f, encodings)
except Exception as e:
self.logger.exception(f"Error processing face encodings for image {self.filename}", exc_info=e)

Expand Down
13 changes: 5 additions & 8 deletions tests/faces/test_duplication_detector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import pickle
from unittest.mock import MagicMock, mock_open, patch

from django.conf import settings
Expand All @@ -18,7 +17,7 @@ def test_duplication_detector_initialization(dd):
assert dd.confidence == settings.FACE_DETECTION_CONFIDENCE
assert dd.threshold == settings.DISTANCE_THRESHOLD
assert dd.filename == FILENAME
assert dd.encodings_filename == f"{FILENAME}.pkl"
assert dd.encodings_filename == f"{FILENAME}.npy"
for storage_name, storage in dd.storages.items():
assert isinstance(storage, MagicMock)
if storage_name == "cv2dnn":
Expand Down Expand Up @@ -92,24 +91,22 @@ def test_load_encodings_all_no_files(dd):


def test_load_encodings_all_with_files(dd):
mock_encoded_data = {f"{filename}.pkl": [np.array([1, 2, 3]), np.array([4, 5, 6])] for filename in FILENAMES}
mock_encoded_data = {f"{filename}.npy": [np.array([1, 2, 3]), np.array([4, 5, 6])] for filename in FILENAMES}
encoded_data = {os.path.splitext(key)[0]: value for key, value in mock_encoded_data.items()}
print(f"\n{mock_encoded_data=}\n{encoded_data=}")

# Mock the storage's listdir method to return the file names
with patch.object(
dd.storages["encoded"],
"listdir",
return_value=(None, [f"{filename}.pkl" for filename in FILENAMES]),
return_value=(None, [f"{filename}.npy" for filename in FILENAMES]),
):
print(f"{dd.storages['encoded'].listdir()[1]=}")
# Mock the storage's open method to return the data for each file
with patch(
"builtins.open",
side_effect=lambda f: mock_open(read_data=pickle.dumps(mock_encoded_data[f])).return_value,
side_effect=lambda f: mock_open(read_data=np.save(mock_encoded_data[f])).return_value,
):
mo = mock_open()
mo.return_value = pickle.dumps(mock_encoded_data)
dd._load_encodings_all()
# Assert that the returned encodings match the expected data
# TODO: Fix
Expand Down Expand Up @@ -184,7 +181,7 @@ def test_find_duplicates_successful(dd, mock_hde_azure_storage):

dd._encode_face.assert_not_called()
dd._load_encodings_all.assert_called_once()
mock_hde_azure_storage.exists.assert_called_once_with(f"{FILENAME}.pkl")
mock_hde_azure_storage.exists.assert_called_once_with(f"{FILENAME}.npy")


def test_find_duplicates_calls_encode_face_when_no_encodings(dd):
Expand Down

0 comments on commit 376c029

Please sign in to comment.