
Commit

chg ! use chunks
vitali-yanushchyk-valor committed Dec 31, 2024
1 parent ca3a76e commit fbea6b1
Showing 22 changed files with 581 additions and 515 deletions.
1 change: 1 addition & 0 deletions src/hope_dedup_engine/apps/api/admin/__init__.py
@@ -2,5 +2,6 @@
from .deduplicationset import DeduplicationSetAdmin # noqa
from .finding import FindingAdmin # noqa
from .hdetoken import HDETokenAdmin # noqa
from .ignored_pair import IgnoredFilenamePairAdmin, IgnoredReferencePkPairAdmin # noqa
from .image import ImageAdmin # noqa
from .jobs import DedupJob # noqa
2 changes: 2 additions & 0 deletions src/hope_dedup_engine/apps/api/admin/finding.py
@@ -13,12 +13,14 @@ class FindingAdmin(AdminFiltersMixin, ModelAdmin):
"id",
"deduplication_set",
"score",
"error",
"first_reference_pk",
"second_reference_pk",
)
list_filter = (
("deduplication_set", AutoCompleteFilter),
("score", NumberFilter),
("error", NumberFilter),
DjangoLookupFilter,
)

25 changes: 25 additions & 0 deletions src/hope_dedup_engine/apps/api/admin/ignored_pair.py
@@ -0,0 +1,25 @@
from django.contrib import admin

from adminfilters.autocomplete import AutoCompleteFilter
from adminfilters.mixin import AdminFiltersMixin

from hope_dedup_engine.apps.api.models import (
IgnoredFilenamePair,
IgnoredReferencePkPair,
)


class IgnoredPairBaseAdmin(AdminFiltersMixin, admin.ModelAdmin):
list_display = ("id", "first", "second", "deduplication_set")
list_filter = (("deduplication_set", AutoCompleteFilter),)
search_fields = ("first", "second")


@admin.register(IgnoredReferencePkPair)
class IgnoredReferencePkPairAdmin(IgnoredPairBaseAdmin):
pass


@admin.register(IgnoredFilenamePair)
class IgnoredFilenamePairAdmin(IgnoredPairBaseAdmin):
pass
67 changes: 32 additions & 35 deletions src/hope_dedup_engine/apps/api/deduplication/adapters.py
@@ -1,11 +1,7 @@
from collections.abc import Callable, Generator

from constance import config

from hope_dedup_engine.apps.api.deduplication.registry import DuplicateKeyPair
from hope_dedup_engine.apps.api.models import DeduplicationSet
from hope_dedup_engine.apps.faces.services import FacialDetector
from hope_dedup_engine.constants import is_facial_error


class DuplicateFaceFinder:
@@ -18,34 +14,35 @@ def __init__(self, deduplication_set: DeduplicationSet):
def run(
self, tracker: Callable[[int], None] | None = None
) -> Generator[DuplicateKeyPair, None, None]:
filename_to_reference_pk = {
filename: reference_pk
for reference_pk, filename in self.deduplication_set.image_set.values_list(
"reference_pk", "filename"
)
}
options = {
"detector_backend": config.FACE_DETECTOR_MODEL,
"model_name": config.FACIAL_RECOGNITION_MODEL,
}
# options = ConfigDefaults()
# if self.deduplication_set.config:
# options.apply_config_overrides(self.deduplication_set.config.settings)
# ignored key pairs are not handled correctly in DuplicationDetector
detector = FacialDetector(
self.deduplication_set.pk,
tuple[str](filename_to_reference_pk.keys()),
options=options,
)
for first_filename, second_filename, distance in detector.find_duplicates(
# tracker
):
yield (
filename_to_reference_pk[first_filename],
(
filename_to_reference_pk[second_filename]
if second_filename in filename_to_reference_pk
else second_filename
),
distance if is_facial_error(distance) else (1 - distance),
)
...
# filename_to_reference_pk = {
# filename: reference_pk
# for reference_pk, filename in self.deduplication_set.image_set.values_list(
# "reference_pk", "filename"
# )
# }
# options = {
# "detector_backend": config.FACE_DETECTOR_MODEL,
# "model_name": config.FACIAL_RECOGNITION_MODEL,
# }
# # options = ConfigDefaults()
# # if self.deduplication_set.config:
# # options.apply_config_overrides(self.deduplication_set.config.settings)
# # ignored key pairs are not handled correctly in DuplicationDetector
# detector = FacialDetector(
# self.deduplication_set.pk,
# tuple[str](filename_to_reference_pk.keys()),
# options=options,
# )
# for first_filename, second_filename, distance in detector.find_duplicates(
# # tracker
# ):
# yield (
# filename_to_reference_pk[first_filename],
# (
# filename_to_reference_pk[second_filename]
# if second_filename in filename_to_reference_pk
# else second_filename
# ),
# distance if is_facial_error(distance) else (1 - distance),
# )
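
For context, the scoring rule the stubbed-out body used to apply can be read from its last commented line: error sentinels pass through unchanged, while genuine face distances are inverted into similarity scores. A minimal sketch of that rule, assuming the is_facial_error helper from hope_dedup_engine.constants behaves as its name suggests:

def distance_to_score(distance: float) -> float:
    # Error sentinels (see is_facial_error) are forwarded as-is,
    # presumably surfaced via the new Finding.error column; real
    # distances are inverted so that closer faces score higher.
    return distance if is_facial_error(distance) else 1 - distance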
105 changes: 49 additions & 56 deletions src/hope_dedup_engine/apps/api/deduplication/config.py
@@ -1,70 +1,63 @@
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Self
from uuid import UUID

from constance import config as constance_cfg

@dataclass
class DetectionConfig:
...
# dnn_files_source: str = field(
# default_factory=lambda: constance_cfg.DNN_FILES_SOURCE
# )
# dnn_backend: int = field(default_factory=lambda: constance_cfg.DNN_BACKEND)
# dnn_target: int = field(default_factory=lambda: constance_cfg.DNN_TARGET)
# blob_from_image_scale_factor: float = field(
# default_factory=lambda: constance_cfg.BLOB_FROM_IMAGE_SCALE_FACTOR
# )
# blob_from_image_mean_values: tuple[float, float, float] = field(
# default_factory=lambda: tuple(
# map(float, constance_cfg.BLOB_FROM_IMAGE_MEAN_VALUES.split(", "))
# )
# )
# confidence: float = field(
# default_factory=lambda: constance_cfg.FACE_DETECTION_CONFIDENCE
# )
# nms_threshold: float = field(default_factory=lambda: constance_cfg.NMS_THRESHOLD)
from hope_dedup_engine.apps.api.models import DeduplicationSet


@dataclass
class RecognitionConfig:
...
# num_jitters: int = field(
# default_factory=lambda: constance_cfg.FACE_ENCODINGS_NUM_JITTERS
# )
# model: Literal["small", "large"] = field(
# default_factory=lambda: constance_cfg.FACE_ENCODINGS_MODEL
# )
# preprocessors: list[str] = field(default_factory=list)
class ModelOptions:
model_name: str = field(
default_factory=lambda: constance_cfg.FACE_RECOGNITION_MODEL
)
detector_backend: str = field(
default_factory=lambda: constance_cfg.FACE_DETECTOR_BACKEND
)

def update(self, overrides: dict[str, Any]) -> None:
for k, v in overrides.items():
if hasattr(self, k):
setattr(self, k, v)


@dataclass
class DuplicatesConfig:
...
# tolerance: float = field(
# default_factory=lambda: constance_cfg.FACE_DISTANCE_THRESHOLD
# )
class EncodingOptions(ModelOptions):
pass


@dataclass
class ConfigDefaults:
detection: DetectionConfig = field(default_factory=DetectionConfig)
recognition: RecognitionConfig = field(default_factory=RecognitionConfig)
duplicates: DuplicatesConfig = field(default_factory=DuplicatesConfig)
class DeduplicateOptions(ModelOptions):
threshold: float = field(
default_factory=lambda: constance_cfg.FACE_DISTANCE_THRESHOLD
)
silent: bool = True

def apply_config_overrides(
self, config_settings: dict[str, Any] | None = None
) -> None:
"""
Updates the instance with values from the provided config settings.

Parameters:
config_settings (dict | None): Optional dictionary of configuration overrides, structured to match
sections in ConfigDefaults (e.g., "detection", "recognition", "duplicates"). Only matching attributes
are updated. No changes are made if `config_settings` is `None` or empty.
"""
if config_settings:
for section_name, section_data in config_settings.items():
dataclass_section = getattr(self, section_name, None)
if dataclass_section and isinstance(section_data, dict):
for k, v in section_data.items():
if hasattr(dataclass_section, k):
setattr(dataclass_section, k, v)
@dataclass
class DeduplicationSetConfig:
deduplication_set_id: UUID | None = None
encoding: EncodingOptions = field(default_factory=EncodingOptions)
deduplicate: DeduplicateOptions = field(default_factory=DeduplicateOptions)

def update(self, overrides: dict[str, Any]) -> None:
if not isinstance(overrides, dict):
raise ValueError("Overrides values must be a dictionary.")
for k, v in overrides.items():
match k:
case "encoding" if isinstance(v, dict):
self.encoding.update(v)
case "deduplicate" if isinstance(v, dict):
self.deduplicate.update(v)
case _ if hasattr(self, k):
setattr(self, k, v)
case _:
raise KeyError(f"Unknown config key: {k}")

@classmethod
def from_deduplication_set(cls, deduplication_set: DeduplicationSet) -> Self:
instance = cls(deduplication_set_id=deduplication_set.pk)
if deduplication_set.config:
instance.update(deduplication_set.config.settings)
return instance
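
The new DeduplicationSetConfig replaces the section-based ConfigDefaults: overrides are dispatched by key, with the "encoding" and "deduplicate" sections routed to their nested dataclasses and unknown keys rejected. A minimal usage sketch (the override values are illustrative, not defaults from the codebase):

from dataclasses import asdict

cfg = DeduplicationSetConfig.from_deduplication_set(deduplication_set)
cfg.update({
    "deduplicate": {"threshold": 0.4},      # routed to DeduplicateOptions
    "encoding": {"model_name": "Facenet"},  # routed to EncodingOptions
})
# cfg.update({"unknown": 1})  # would raise KeyError("Unknown config key: unknown")
options = asdict(cfg)  # plain serializable dict, as passed to the celery tasks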
68 changes: 45 additions & 23 deletions src/hope_dedup_engine/apps/api/deduplication/process.py
@@ -1,15 +1,25 @@
from functools import partial
from dataclasses import asdict

from django.db.models import F

from celery import shared_task
from celery import chord, shared_task

from hope_dedup_engine.apps.api.deduplication.registry import ( # DuplicateFinder,; DuplicateKeyPair,
get_finders,
)
from hope_dedup_engine.apps.api.deduplication.config import DeduplicationSetConfig

# from hope_dedup_engine.apps.api.deduplication.registry import ( # DuplicateFinder,; DuplicateKeyPair,
# get_finders,
# )
from hope_dedup_engine.apps.api.models import DedupJob, DeduplicationSet, Finding
from hope_dedup_engine.apps.api.utils.notification import send_notification
from hope_dedup_engine.apps.api.utils.progress import track_progress_multi

# from hope_dedup_engine.apps.api.utils.progress import track_progress_multi
from hope_dedup_engine.apps.faces.celery_tasks import (
callback_encodings,
encode_chunk,
get_chunks,
)

CHUNK_SIZE = 100

# def _sort_keys(pair: DuplicateKeyPair) -> DuplicateKeyPair:
# first, second, score = pair
@@ -66,7 +76,7 @@

def update_job_progress(job: DedupJob, progress: int) -> None:
job.progress = progress
job.save()
job.save(update_fields=["progress"])


@shared_task(soft_time_limit=0.5 * HOUR, time_limit=1 * HOUR)
@@ -76,30 +86,42 @@ def find_duplicates(dedup_job_id: int, version: int) -> None:
deduplication_set = dedup_job.deduplication_set

deduplication_set.state = DeduplicationSet.State.DIRTY
deduplication_set.save()
deduplication_set.save(update_fields=["state"])
send_notification(deduplication_set.notification_url)

config = asdict(
DeduplicationSetConfig.from_deduplication_set(deduplication_set)
)

# clean results
Finding.objects.filter(deduplication_set=deduplication_set).delete()
dedup_job.progress = 0
dedup_job.save(update_fields=["progress"])

weight_total = 0
# weight_total = 0
# for finder, tracker in zip(
for finder, _ in zip(
get_finders(deduplication_set),
track_progress_multi(partial(update_job_progress, dedup_job)),
):
# _save_duplicates(finder, deduplication_set, tracker)
weight_total += finder.weight

# for finder, _ in zip(
# get_finders(deduplication_set),
# track_progress_multi(partial(update_job_progress, dedup_job)),
# ):
# # _save_duplicates(finder, deduplication_set, tracker)
# weight_total += finder.weight

weight_total = 1
deduplication_set.finding_set.update(score=F("score") / weight_total)

for finder, tracker in zip(
get_finders(deduplication_set),
track_progress_multi(partial(update_job_progress, dedup_job)),
):
for first, second, score in finder.run(tracker):
finding = (first, second, score * finder.weight)
deduplication_set.update_findings(finding)
files = deduplication_set.image_set.values_list("filename", flat=True)
chunks = get_chunks(files, (len(files) // CHUNK_SIZE) + 1)
tasks = [encode_chunk.s(chunk, config) for n, chunk in enumerate(chunks)]
chord(tasks)(callback_encodings.s(config=config))

# for finder, tracker in zip(
# get_finders(deduplication_set),
# track_progress_multi(partial(update_job_progress, dedup_job)),
# ):
# for first, second, score in finder.run(tracker):
# finding = (first, second, score * finder.weight)
# deduplication_set.update_findings(finding)

deduplication_set.state = deduplication_set.State.CLEAN
deduplication_set.save(update_fields=["state"])