Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor bulk_create calls #9047

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
from django.utils import timezone

from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.dataset_manager.util import add_prefetch_fields
from cvat.apps.engine import models
from cvat.apps.engine.frame_provider import FrameOutputType, FrameQuality, TaskFrameProvider
from cvat.apps.engine.lazy_list import LazyList
from cvat.apps.engine.model_utils import add_prefetch_fields
from cvat.apps.engine.models import (
AttributeSpec,
AttributeType,
Expand Down
3 changes: 2 additions & 1 deletion cvat/apps/dataset_manager/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from cvat.apps.dataset_manager.util import TmpDirManager
from cvat.apps.engine import models
from cvat.apps.engine.log import DatasetLogManager
from cvat.apps.engine.model_utils import bulk_create
from cvat.apps.engine.rq_job_handler import RQJobMetaField
from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer
from cvat.apps.engine.task import _create_thread as create_task
Expand Down Expand Up @@ -128,7 +129,7 @@ def add_labels(self, labels: list[models.Label], attributes: list[tuple[str, mod
label, = filter(lambda l: l.name == label_name, labels)
attribute.label = label
if attributes:
models.AttributeSpec.objects.bulk_create([a[1] for a in attributes])
bulk_create(models.AttributeSpec, [a[1] for a in attributes])

def init_from_db(self):
self.reset()
Expand Down
29 changes: 10 additions & 19 deletions cvat/apps/dataset_manager/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,10 @@
TaskData,
)
from cvat.apps.dataset_manager.formats.registry import make_exporter, make_importer
from cvat.apps.dataset_manager.util import (
TmpDirManager,
add_prefetch_fields,
bulk_create,
faster_deepcopy,
get_cached,
)
from cvat.apps.dataset_manager.util import TmpDirManager, faster_deepcopy
from cvat.apps.engine import models, serializers
from cvat.apps.engine.log import DatasetLogManager
from cvat.apps.engine.model_utils import add_prefetch_fields, bulk_create, get_cached
from cvat.apps.engine.plugins import plugin_decorator
from cvat.apps.engine.utils import take_by
from cvat.apps.events.handlers import handle_annotations_change
Expand Down Expand Up @@ -286,7 +281,7 @@ def create_tracks(tracks, parent_track=None):

db_tracks = bulk_create(
db_model=models.LabeledTrack,
objects=db_tracks,
objs=db_tracks,
flt_param={"job_id": self.db_job.id}
)

Expand All @@ -295,16 +290,15 @@ def create_tracks(tracks, parent_track=None):

bulk_create(
db_model=models.LabeledTrackAttributeVal,
objects=db_track_attr_vals,
flt_param={}
objs=db_track_attr_vals,
)

for db_shape in db_shapes:
db_shape.track_id = db_tracks[db_shape.track_id].id

db_shapes = bulk_create(
db_model=models.TrackedShape,
objects=db_shapes,
objs=db_shapes,
flt_param={"track__job_id": self.db_job.id}
)

Expand All @@ -313,8 +307,7 @@ def create_tracks(tracks, parent_track=None):

bulk_create(
db_model=models.TrackedShapeAttributeVal,
objects=db_shape_attr_vals,
flt_param={}
objs=db_shape_attr_vals,
)

shape_idx = 0
Expand Down Expand Up @@ -357,7 +350,7 @@ def create_shapes(shapes, parent_shape=None):

db_shapes = bulk_create(
db_model=models.LabeledShape,
objects=db_shapes,
objs=db_shapes,
flt_param={"job_id": self.db_job.id}
)

Expand All @@ -366,8 +359,7 @@ def create_shapes(shapes, parent_shape=None):

bulk_create(
db_model=models.LabeledShapeAttributeVal,
objects=db_attr_vals,
flt_param={}
objs=db_attr_vals,
)

for shape, db_shape in zip(shapes, db_shapes):
Expand Down Expand Up @@ -401,7 +393,7 @@ def _save_tags_to_db(self, tags):

db_tags = bulk_create(
db_model=models.LabeledImage,
objects=db_tags,
objs=db_tags,
flt_param={"job_id": self.db_job.id}
)

Expand All @@ -410,8 +402,7 @@ def _save_tags_to_db(self, tags):

bulk_create(
db_model=models.LabeledImageAttributeVal,
objects=db_attr_vals,
flt_param={}
objs=db_attr_vals
)

for tag, db_tag in zip(tags, db_tags):
Expand Down
61 changes: 2 additions & 59 deletions cvat/apps/dataset_manager/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,19 @@
import re
import tempfile
import zipfile
from collections.abc import Generator, Iterable, Sequence
from collections.abc import Generator
from contextlib import contextmanager
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from threading import Lock
from typing import Any, TypeVar
from typing import Any

import attrs
import django_rq
from datumaro.util import to_snake_case
from datumaro.util.os_util import make_file_name
from django.conf import settings
from django.db import models
from pottery import Redlock


Expand All @@ -38,62 +37,6 @@ def make_zip_archive(src_path, dst_path):
archive.write(path, osp.relpath(path, src_path))


_ModelT = TypeVar("_ModelT", bound=models.Model)

def bulk_create(
db_model: type[_ModelT],
objects: Iterable[_ModelT],
*,
flt_param: dict[str, Any] | None = None,
batch_size: int | None = 10000
) -> list[_ModelT]:
if objects:
if flt_param:
if "postgresql" in settings.DATABASES["default"]["ENGINE"]:
return db_model.objects.bulk_create(objects, batch_size=batch_size)
else:
ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True))
db_model.objects.bulk_create(objects, batch_size=batch_size)

return list(db_model.objects.exclude(id__in=ids).filter(**flt_param))
else:
return db_model.objects.bulk_create(objects, batch_size=batch_size)

return []


def is_prefetched(queryset: models.QuerySet, field: str) -> bool:
return field in queryset._prefetch_related_lookups


def add_prefetch_fields(queryset: models.QuerySet, fields: Sequence[str]) -> models.QuerySet:
for field in fields:
if not is_prefetched(queryset, field):
queryset = queryset.prefetch_related(field)

return queryset


def get_cached(queryset: models.QuerySet, pk: int) -> models.Model:
"""
Like regular queryset.get(), but checks for the cached values first
instead of just making a request.
"""

# Read more about caching insights:
# https://www.mattduck.com/2021-01-django-orm-result-cache.html
# The field is initialized on accessing the query results, eg. on iteration
if getattr(queryset, '_result_cache'):
result = next((obj for obj in queryset if obj.pk == pk), None)
else:
result = None

if result is None:
result = queryset.get(id=pk)

return result


def faster_deepcopy(v):
"A slightly optimized version of the default deepcopy, can be used as a drop-in replacement."
# Default deepcopy is very slow, here we do shallow copy for primitive types and containers
Expand Down
2 changes: 2 additions & 0 deletions cvat/apps/engine/default_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,5 @@
MAX_CONSENSUS_REPLICAS = int(os.getenv("CVAT_MAX_CONSENSUS_REPLICAS", 11))
if MAX_CONSENSUS_REPLICAS < 1:
raise ImproperlyConfigured(f"MAX_CONSENSUS_REPLICAS must be >= 1, got {MAX_CONSENSUS_REPLICAS}")

DEFAULT_DB_BULK_CREATE_BATCH_SIZE = int(os.getenv("CVAT_DEFAULT_DB_BULK_CREATE_BATCH_SIZE", 1000))
89 changes: 88 additions & 1 deletion cvat/apps/engine/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

from __future__ import annotations

from typing import TypeVar, Union
from collections.abc import Iterable
from typing import Any, Sequence, TypeVar, Union

from django.conf import settings
from django.db import models

_T = TypeVar("_T")

Expand All @@ -18,3 +22,86 @@ class Undefined:
The reverse side of one-to-one relationship.
May be undefined in the object, should be accessed via getattr().
"""


_ModelT = TypeVar("_ModelT", bound=models.Model)


def bulk_create(
db_model: type[_ModelT],
objs: Iterable[_ModelT],
batch_size: int | None = ...,
ignore_conflicts: bool = ...,
update_conflicts: bool | None = ...,
update_fields: Sequence[str] | None = ...,
unique_fields: Sequence[str] | None = ...,
*,
flt_param: dict[str, Any] | None = None,
) -> list[_ModelT]:
"Like Django's Model.objects.bulk_create(), but applies the default batch size"

if batch_size is Ellipsis:
batch_size = settings.DEFAULT_DB_BULK_CREATE_BATCH_SIZE

kwargs = {}
for k, v in {
"ignore_conflicts": ignore_conflicts,
"update_conflicts": update_conflicts,
"update_fields": update_fields,
"unique_fields": unique_fields,
}.items():
if v is not Ellipsis:
kwargs[k] = v

if not objs:
return []

flt_param = flt_param or {}

if flt_param:
if "postgresql" in settings.DATABASES["default"]["ENGINE"]:
return db_model.objects.bulk_create(objs, batch_size=batch_size, **kwargs)
else:
# imitate RETURNING
ids = list(db_model.objects.filter(**flt_param).values_list("id", flat=True))
db_model.objects.bulk_create(objs, batch_size=batch_size, **kwargs)

return list(db_model.objects.exclude(id__in=ids).filter(**flt_param))
else:
return db_model.objects.bulk_create(objs, batch_size=batch_size, **kwargs)


def is_prefetched(queryset: models.QuerySet, field: str) -> bool:
"Checks if a field is being prefetched in the queryset"
return field in queryset._prefetch_related_lookups


_QuerysetT = TypeVar("_QuerysetT", bound=models.QuerySet)


def add_prefetch_fields(queryset: _QuerysetT, fields: Sequence[str]) -> _QuerysetT:
for field in fields:
if not is_prefetched(queryset, field):
queryset = queryset.prefetch_related(field)

return queryset


def get_cached(queryset: _QuerysetT, pk: int) -> _ModelT:
"""
Like regular queryset.get(), but checks for the cached values first
instead of just making a request.
"""

# Read more about caching insights:
# https://www.mattduck.com/2021-01-django-orm-result-cache.html
# The field is initialized on accessing the query results, eg. on iteration
if getattr(queryset, "_result_cache"):
result = next((obj for obj in queryset if obj.pk == pk), None)
else:
result = None

if result is None:
result = queryset.get(id=pk)

return result
Loading
Loading