From 9b2cb9e8c6dd6b99ead048e2736b28c863a89fc0 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 3 Feb 2025 19:57:29 +0200 Subject: [PATCH 1/4] Refactor bulk_create calls --- cvat/apps/dataset_manager/bindings.py | 2 +- cvat/apps/dataset_manager/project.py | 3 +- cvat/apps/dataset_manager/task.py | 15 +---- cvat/apps/dataset_manager/util.py | 57 ----------------- cvat/apps/engine/default_settings.py | 2 + cvat/apps/engine/model_utils.py | 89 ++++++++++++++++++++++++++- cvat/apps/engine/serializers.py | 28 +++++---- cvat/apps/engine/task.py | 22 ++++--- cvat/apps/engine/views.py | 2 +- 9 files changed, 125 insertions(+), 95 deletions(-) diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index fc74ee53d0b7..af6c558488d8 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -27,10 +27,10 @@ from django.utils import timezone from cvat.apps.dataset_manager.formats.utils import get_label_color -from cvat.apps.dataset_manager.util import add_prefetch_fields from cvat.apps.engine import models from cvat.apps.engine.frame_provider import FrameOutputType, FrameQuality, TaskFrameProvider from cvat.apps.engine.lazy_list import LazyList +from cvat.apps.engine.model_utils import add_prefetch_fields from cvat.apps.engine.models import ( AttributeSpec, AttributeType, diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py index 162a1ef8a5bd..ae03e480aa25 100644 --- a/cvat/apps/dataset_manager/project.py +++ b/cvat/apps/dataset_manager/project.py @@ -17,6 +17,7 @@ from cvat.apps.dataset_manager.util import TmpDirManager from cvat.apps.engine import models from cvat.apps.engine.log import DatasetLogManager +from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.rq_job_handler import RQJobMetaField from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer from cvat.apps.engine.task import _create_thread as create_task @@ -128,7 +129,7 @@ def add_labels(self, labels: list[models.Label], attributes: list[tuple[str, mod label, = filter(lambda l: l.name == label_name, labels) attribute.label = label if attributes: - models.AttributeSpec.objects.bulk_create([a[1] for a in attributes]) + bulk_create(models.AttributeSpec, [a[1] for a in attributes]) def init_from_db(self): self.reset() diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index 147772593ac2..ded9a77cccff 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -25,15 +25,10 @@ TaskData, ) from cvat.apps.dataset_manager.formats.registry import make_exporter, make_importer -from cvat.apps.dataset_manager.util import ( - TmpDirManager, - add_prefetch_fields, - bulk_create, - faster_deepcopy, - get_cached, -) +from cvat.apps.dataset_manager.util import TmpDirManager, faster_deepcopy from cvat.apps.engine import models, serializers from cvat.apps.engine.log import DatasetLogManager +from cvat.apps.engine.model_utils import add_prefetch_fields, bulk_create, get_cached from cvat.apps.engine.plugins import plugin_decorator from cvat.apps.engine.utils import take_by from cvat.apps.events.handlers import handle_annotations_change @@ -296,7 +291,6 @@ def create_tracks(tracks, parent_track=None): bulk_create( db_model=models.LabeledTrackAttributeVal, objects=db_track_attr_vals, - flt_param={} ) for db_shape in db_shapes: @@ -314,7 +308,6 @@ def create_tracks(tracks, parent_track=None): bulk_create( db_model=models.TrackedShapeAttributeVal, 
objects=db_shape_attr_vals, - flt_param={} ) shape_idx = 0 @@ -367,7 +360,6 @@ def create_shapes(shapes, parent_shape=None): bulk_create( db_model=models.LabeledShapeAttributeVal, objects=db_attr_vals, - flt_param={} ) for shape, db_shape in zip(shapes, db_shapes): @@ -410,8 +402,7 @@ def _save_tags_to_db(self, tags): bulk_create( db_model=models.LabeledImageAttributeVal, - objects=db_attr_vals, - flt_param={} + objects=db_attr_vals ) for tag, db_tag in zip(tags, db_tags): diff --git a/cvat/apps/dataset_manager/util.py b/cvat/apps/dataset_manager/util.py index c88292809785..3b1bbbcc15ad 100644 --- a/cvat/apps/dataset_manager/util.py +++ b/cvat/apps/dataset_manager/util.py @@ -22,7 +22,6 @@ from datumaro.util import to_snake_case from datumaro.util.os_util import make_file_name from django.conf import settings -from django.db import models from pottery import Redlock @@ -38,62 +37,6 @@ def make_zip_archive(src_path, dst_path): archive.write(path, osp.relpath(path, src_path)) -_ModelT = TypeVar("_ModelT", bound=models.Model) - -def bulk_create( - db_model: type[_ModelT], - objects: Iterable[_ModelT], - *, - flt_param: dict[str, Any] | None = None, - batch_size: int | None = 10000 -) -> list[_ModelT]: - if objects: - if flt_param: - if "postgresql" in settings.DATABASES["default"]["ENGINE"]: - return db_model.objects.bulk_create(objects, batch_size=batch_size) - else: - ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True)) - db_model.objects.bulk_create(objects, batch_size=batch_size) - - return list(db_model.objects.exclude(id__in=ids).filter(**flt_param)) - else: - return db_model.objects.bulk_create(objects, batch_size=batch_size) - - return [] - - -def is_prefetched(queryset: models.QuerySet, field: str) -> bool: - return field in queryset._prefetch_related_lookups - - -def add_prefetch_fields(queryset: models.QuerySet, fields: Sequence[str]) -> models.QuerySet: - for field in fields: - if not is_prefetched(queryset, field): - queryset = queryset.prefetch_related(field) - - return queryset - - -def get_cached(queryset: models.QuerySet, pk: int) -> models.Model: - """ - Like regular queryset.get(), but checks for the cached values first - instead of just making a request. - """ - - # Read more about caching insights: - # https://www.mattduck.com/2021-01-django-orm-result-cache.html - # The field is initialized on accessing the query results, eg. on iteration - if getattr(queryset, '_result_cache'): - result = next((obj for obj in queryset if obj.pk == pk), None) - else: - result = None - - if result is None: - result = queryset.get(id=pk) - - return result - - def faster_deepcopy(v): "A slightly optimized version of the default deepcopy, can be used as a drop-in replacement." 
# Default deepcopy is very slow, here we do shallow copy for primitive types and containers diff --git a/cvat/apps/engine/default_settings.py b/cvat/apps/engine/default_settings.py index e679c90aeab4..a7ee3a2e3ce3 100644 --- a/cvat/apps/engine/default_settings.py +++ b/cvat/apps/engine/default_settings.py @@ -95,3 +95,5 @@ MAX_CONSENSUS_REPLICAS = int(os.getenv("CVAT_MAX_CONSENSUS_REPLICAS", 11)) if MAX_CONSENSUS_REPLICAS < 1: raise ImproperlyConfigured(f"MAX_CONSENSUS_REPLICAS must be >= 1, got {MAX_CONSENSUS_REPLICAS}") + +DEFAULT_DB_BULK_CREATE_BATCH_SIZE = int(os.getenv("CVAT_DEFAULT_DB_BULK_CREATE_BATCH_SIZE", 1000)) diff --git a/cvat/apps/engine/model_utils.py b/cvat/apps/engine/model_utils.py index 45f3a16b9031..ee8c0f89186c 100644 --- a/cvat/apps/engine/model_utils.py +++ b/cvat/apps/engine/model_utils.py @@ -4,7 +4,11 @@ from __future__ import annotations -from typing import TypeVar, Union +from collections.abc import Iterable +from typing import Any, Sequence, TypeVar, Union + +from django.conf import settings +from django.db import models _T = TypeVar("_T") @@ -18,3 +22,86 @@ class Undefined: The reverse side of one-to-one relationship. May be undefined in the object, should be accessed via getattr(). """ + + +_ModelT = TypeVar("_ModelT", bound=models.Model) + + +def bulk_create( + db_model: type[_ModelT], + objs: Iterable[_ModelT], + batch_size: int | None = ..., + ignore_conflicts: bool = ..., + update_conflicts: bool | None = ..., + update_fields: Sequence[str] | None = ..., + unique_fields: Sequence[str] | None = ..., + *, + flt_param: dict[str, Any] | None = None, +) -> list[_ModelT]: + "Like Django's Model.objects.bulk_create(), but applies the default batch size" + + if batch_size is Ellipsis: + batch_size = settings.DEFAULT_DB_BULK_CREATE_BATCH_SIZE + + kwargs = {} + for k, v in { + "ignore_conflicts": ignore_conflicts, + "update_conflicts": update_conflicts, + "update_fields": update_fields, + "unique_fields": unique_fields, + }.items(): + if v is not Ellipsis: + kwargs[k] = v + + if not objs: + return [] + + flt_param = flt_param or {} + + if flt_param: + if "postgresql" in settings.DATABASES["default"]["ENGINE"]: + return db_model.objects.bulk_create(objs, batch_size=batch_size, **kwargs) + else: + # imitate RETURNING + ids = list(db_model.objects.filter(**flt_param).values_list("id", flat=True)) + db_model.objects.bulk_create(objs, batch_size=batch_size, **kwargs) + + return list(db_model.objects.exclude(id__in=ids).filter(**flt_param)) + else: + return db_model.objects.bulk_create(objs, batch_size=batch_size, **kwargs) + + +def is_prefetched(queryset: models.QuerySet, field: str) -> bool: + "Checks if a field is being prefetched in the queryset" + return field in queryset._prefetch_related_lookups + + +_QuerysetT = TypeVar("_QuerysetT", bound=models.QuerySet) + + +def add_prefetch_fields(queryset: _QuerysetT, fields: Sequence[str]) -> _QuerysetT: + for field in fields: + if not is_prefetched(queryset, field): + queryset = queryset.prefetch_related(field) + + return queryset + + +def get_cached(queryset: _QuerysetT, pk: int) -> _ModelT: + """ + Like regular queryset.get(), but checks for the cached values first + instead of just making a request. + """ + + # Read more about caching insights: + # https://www.mattduck.com/2021-01-django-orm-result-cache.html + # The field is initialized on accessing the query results, eg. 
on iteration + if getattr(queryset, "_result_cache"): + result = next((obj for obj in queryset if obj.pk == pk), None) + else: + result = None + + if result is None: + result = queryset.get(id=pk) + + return result diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 1d1124661f49..02849c5dbd67 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -39,6 +39,7 @@ from cvat.apps.engine.cloud_provider import Credentials, Status, get_cloud_storage_instance from cvat.apps.engine.frame_provider import FrameQuality, TaskFrameProvider from cvat.apps.engine.log import ServerLogManager +from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.permissions import TaskPermission from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField from cvat.apps.engine.task_validation import HoneypotFrameSelector @@ -1596,7 +1597,8 @@ def _update_frames_in_bulk( # The django generated bulk_update() query is too slow, so we use bulk_create() instead # NOTE: Silk doesn't show these queries in the list of queries # for some reason, but they can be seen in the profile - models.Image.objects.bulk_create( + bulk_create( + models.Image, list(bulk_context.updated_honeypots.values()), update_conflicts=True, update_fields=['path', 'width', 'height', 'real_frame'], @@ -1605,7 +1607,6 @@ def _update_frames_in_bulk( # https://docs.djangoproject.com/en/4.2/ref/models/querysets/#bulk-create 'id' ], - batch_size=1000, ) # Update related images in 2 steps: remove all m2m for honeypots, then add (copy) new ones @@ -1653,7 +1654,7 @@ def _update_frames_in_bulk( for m2m_obj in validation_frame_m2m_objects ) - models.RelatedFile.images.through.objects.bulk_create(new_m2m_objects, batch_size=1000) + bulk_create(models.RelatedFile.images.through, new_m2m_objects) def _clear_annotations_on_frames(self, db_task: models.Task, frames: Sequence[int]): models.clear_annotations_on_frames_in_honeypot_task(db_task, frames=frames) @@ -1960,9 +1961,9 @@ def create(self, validated_data: dict[str, Any]) -> models.ValidationParams: instance = super().create(validated_data) if frames: - models.ValidationFrame.objects.bulk_create( - models.ValidationFrame(validation_params=instance, path=frame) - for frame in frames + bulk_create( + models.ValidationFrame, + [models.ValidationFrame(validation_params=instance, path=frame) for frame in frames] ) return instance @@ -1978,9 +1979,9 @@ def update( if frames: models.ValidationFrame.objects.filter(validation_params=instance).delete() - models.ValidationFrame.objects.bulk_create( - models.ValidationFrame(validation_params=instance, path=frame) - for frame in frames + bulk_create( + models.ValidationFrame, + [models.ValidationFrame(validation_params=instance, path=frame) for frame in frames] ) return instance @@ -2210,8 +2211,9 @@ def _create_files(self, instance, files): (models.ClientFile, models.ServerFile, models.RemoteFile), ): if files_type in files: - files_model.objects.bulk_create( - files_model(data=instance, **f) for f in files[files_type] + bulk_create( + files_model, + [files_model(data=instance, **f) for f in files[files_type]] ) class TaskReadSerializer(serializers.ModelSerializer): @@ -3215,7 +3217,7 @@ def create(self, validated_data): db_storage.save() manifest_file_instances = [models.Manifest(filename=manifest, cloud_storage=db_storage) for manifest in manifests] - models.Manifest.objects.bulk_create(manifest_file_instances) + bulk_create(models.Manifest, manifest_file_instances) cloud_storage_path = 
db_storage.get_storage_dirname() if os.path.isdir(cloud_storage_path): @@ -3294,7 +3296,7 @@ def update(self, instance, validated_data): # check manifest files existing self._manifests_validation(storage, delta_to_create) manifest_instances = [models.Manifest(filename=f, cloud_storage=instance) for f in delta_to_create] - models.Manifest.objects.bulk_create(manifest_instances) + bulk_create(models.Manifest, manifest_instances) if temporary_file: # so, gcs key file is valid and we need to set correct path to the file real_path_to_key_file = instance.get_key_file_path() diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index db4b951da98d..eb9c882e818c 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -46,6 +46,7 @@ load_image, sort, ) +from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.models import RequestAction, RequestTarget from cvat.apps.engine.rq_job_handler import RQId from cvat.apps.engine.task_validation import HoneypotFrameSelector @@ -1358,7 +1359,7 @@ def _update_status(msg: str) -> None: )) if db_task.mode == 'annotation': - images = models.Image.objects.bulk_create(images) + images = bulk_create(models.Image, images) db_related_files = [ models.RelatedFile( @@ -1367,20 +1368,23 @@ def _update_status(msg: str) -> None: ) for related_file_path in set(itertools.chain.from_iterable(related_images.values())) ] - db_related_files = models.RelatedFile.objects.bulk_create(db_related_files) + db_related_files = bulk_create(models.RelatedFile, db_related_files) db_related_files_by_path = { os.path.relpath(rf.path.path, upload_dir): rf for rf in db_related_files } ThroughModel = models.RelatedFile.images.through - models.RelatedFile.images.through.objects.bulk_create(( - ThroughModel( - relatedfile_id=db_related_files_by_path[related_file_path].id, - image_id=image.id + bulk_create( + ThroughModel, + ( + ThroughModel( + relatedfile_id=db_related_files_by_path[related_file_path].id, + image_id=image.id + ) + for image in images + for related_file_path in related_images.get(image.path, []) ) - for image in images - for related_file_path in related_images.get(image.path, []) - )) + ) else: models.Video.objects.create( data=db_data, diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 084945e7681d..1be85a4b9945 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -1153,7 +1153,7 @@ def _append_upload_info_entries(self, client_files: list[dict[str, Any]]): task_data.client_files.bulk_create([ ClientFile(file=self._prepare_upload_info_entry(cf['file'].name), data=task_data) for cf in client_files - ]) + ], batch_size=settings.DEFAULT_DB_BULK_CREATE_BATCH_SIZE) def _sort_uploaded_files(self, uploaded_files: list[str], ordering: list[str]) -> list[str]: """ From 4fb64ec889c0ba3e65a52c4dacda7e1e3f35d5b6 Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Mon, 3 Feb 2025 20:01:54 +0200 Subject: [PATCH 2/4] Fix linter errors --- cvat/apps/dataset_manager/task.py | 16 ++++++++-------- cvat/apps/dataset_manager/util.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py index ded9a77cccff..b19f421438af 100644 --- a/cvat/apps/dataset_manager/task.py +++ b/cvat/apps/dataset_manager/task.py @@ -281,7 +281,7 @@ def create_tracks(tracks, parent_track=None): db_tracks = bulk_create( db_model=models.LabeledTrack, - objects=db_tracks, + objs=db_tracks, flt_param={"job_id": self.db_job.id} ) @@ -290,7 +290,7 @@ def 
create_tracks(tracks, parent_track=None): bulk_create( db_model=models.LabeledTrackAttributeVal, - objects=db_track_attr_vals, + objs=db_track_attr_vals, ) for db_shape in db_shapes: @@ -298,7 +298,7 @@ def create_tracks(tracks, parent_track=None): db_shapes = bulk_create( db_model=models.TrackedShape, - objects=db_shapes, + objs=db_shapes, flt_param={"track__job_id": self.db_job.id} ) @@ -307,7 +307,7 @@ def create_tracks(tracks, parent_track=None): bulk_create( db_model=models.TrackedShapeAttributeVal, - objects=db_shape_attr_vals, + objs=db_shape_attr_vals, ) shape_idx = 0 @@ -350,7 +350,7 @@ def create_shapes(shapes, parent_shape=None): db_shapes = bulk_create( db_model=models.LabeledShape, - objects=db_shapes, + objs=db_shapes, flt_param={"job_id": self.db_job.id} ) @@ -359,7 +359,7 @@ def create_shapes(shapes, parent_shape=None): bulk_create( db_model=models.LabeledShapeAttributeVal, - objects=db_attr_vals, + objs=db_attr_vals, ) for shape, db_shape in zip(shapes, db_shapes): @@ -393,7 +393,7 @@ def _save_tags_to_db(self, tags): db_tags = bulk_create( db_model=models.LabeledImage, - objects=db_tags, + objs=db_tags, flt_param={"job_id": self.db_job.id} ) @@ -402,7 +402,7 @@ def _save_tags_to_db(self, tags): bulk_create( db_model=models.LabeledImageAttributeVal, - objects=db_attr_vals + objs=db_attr_vals ) for tag, db_tag in zip(tags, db_tags): diff --git a/cvat/apps/dataset_manager/util.py b/cvat/apps/dataset_manager/util.py index 3b1bbbcc15ad..4998a5d9dca2 100644 --- a/cvat/apps/dataset_manager/util.py +++ b/cvat/apps/dataset_manager/util.py @@ -9,13 +9,13 @@ import re import tempfile import zipfile -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Generator from contextlib import contextmanager from copy import deepcopy from datetime import timedelta from enum import Enum from threading import Lock -from typing import Any, TypeVar +from typing import Any import attrs import django_rq From 0df394c4896c7e43e83f0a323a0816de28c1770e Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 4 Feb 2025 12:10:39 +0200 Subject: [PATCH 3/4] Fix import --- cvat/apps/quality_control/quality_reports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index 8d16c47c7af3..39e9390238f5 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -39,9 +39,9 @@ ) from cvat.apps.dataset_manager.formats.registry import dm_env from cvat.apps.dataset_manager.task import JobAnnotation -from cvat.apps.dataset_manager.util import bulk_create from cvat.apps.engine import serializers as engine_serializers from cvat.apps.engine.frame_provider import TaskFrameProvider +from cvat.apps.engine.model_utils import bulk_create from cvat.apps.engine.models import ( DimensionType, Image, From f3e8bc6bea4f9582891ef43e6e79f17b62c7760b Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Tue, 4 Feb 2025 12:21:54 +0200 Subject: [PATCH 4/4] Update calls --- cvat/apps/quality_control/quality_reports.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index 39e9390238f5..3ad3c8e97359 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -2540,7 +2540,7 @@ def _save_reports(self, *, task_report: dict, job_reports: list[dict]) -> models ) 
db_job_reports.append(db_job_report) - db_job_reports = bulk_create(db_model=models.QualityReport, objects=db_job_reports) + db_job_reports = bulk_create(db_model=models.QualityReport, objs=db_job_reports) db_conflicts = [] db_report_iter = itertools.chain([db_task_report], db_job_reports) @@ -2555,7 +2555,7 @@ def _save_reports(self, *, task_report: dict, job_reports: list[dict]) -> models ) db_conflicts.append(db_conflict) - db_conflicts = bulk_create(db_model=models.AnnotationConflict, objects=db_conflicts) + db_conflicts = bulk_create(db_model=models.AnnotationConflict, objs=db_conflicts) db_ann_ids = [] db_conflicts_iter = iter(db_conflicts) @@ -2571,7 +2571,7 @@ def _save_reports(self, *, task_report: dict, job_reports: list[dict]) -> models ) db_ann_ids.append(db_ann_id) - db_ann_ids = bulk_create(db_model=models.AnnotationId, objects=db_ann_ids) + db_ann_ids = bulk_create(db_model=models.AnnotationId, objs=db_ann_ids) return db_task_report
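
The helper these patches converge on, `bulk_create()` in `cvat/apps/engine/model_utils.py`, wraps Django's `Model.objects.bulk_create()` so that every call site picks up `settings.DEFAULT_DB_BULK_CREATE_BATCH_SIZE` (configurable via `CVAT_DEFAULT_DB_BULK_CREATE_BATCH_SIZE`, defaulting to 1000) instead of hard-coding its own batch size, and so that only keyword arguments the caller actually supplied are forwarded: the remaining parameters default to `...` and are dropped, leaving Django's own defaults in effect. The sketch below illustrates that sentinel pattern in isolation; it is not CVAT code — `forwarded_kwargs` and `DEFAULT_BATCH_SIZE` are hypothetical stand-ins for the wrapper's argument handling and for the new setting:

    # Minimal, Django-free sketch of the Ellipsis-sentinel argument handling
    # used by the new bulk_create() wrapper (assumed names, illustration only).
    DEFAULT_BATCH_SIZE = 1000  # stand-in for settings.DEFAULT_DB_BULK_CREATE_BATCH_SIZE

    def forwarded_kwargs(batch_size=..., ignore_conflicts=..., update_conflicts=...,
                         update_fields=..., unique_fields=...):
        if batch_size is Ellipsis:
            batch_size = DEFAULT_BATCH_SIZE  # apply the configurable default
        kwargs = {"batch_size": batch_size}
        for name, value in {
            "ignore_conflicts": ignore_conflicts,
            "update_conflicts": update_conflicts,
            "update_fields": update_fields,
            "unique_fields": unique_fields,
        }.items():
            if value is not Ellipsis:  # forward only what the caller passed explicitly
                kwargs[name] = value
        return kwargs

    print(forwarded_kwargs())
    # {'batch_size': 1000}
    print(forwarded_kwargs(update_conflicts=True, update_fields=["path"], unique_fields=["id"]))
    # {'batch_size': 1000, 'update_conflicts': True, 'update_fields': ['path'], 'unique_fields': ['id']}

The `flt_param` keyword keeps its previous meaning: on PostgreSQL the created rows are returned directly, while on other backends the helper records the matching ids before the insert and re-queries with `exclude(id__in=ids)` afterwards to imitate RETURNING, so callers such as `_save_tags_to_db()` still receive saved instances with primary keys. Patch 2/4 renames the `objects=` keyword to `objs=` at the `dataset_manager/task.py` call sites, matching the new signature (and Django's own `bulk_create()` parameter name), and patch 4/4 applies the same rename in `quality_control/quality_reports.py`.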