Merge pull request #348 from DagsHub/classify-metadata-fields
Add enrichment tag annotation
kbolashev authored Oct 1, 2023
2 parents 111a301 + 8b394b7 commit d41eda0
Showing 14 changed files with 326 additions and 45 deletions.
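In short: metadata fields can now carry reserved tags (currently just "annotation"), wrapped in typed dtypes, and datasources expose a builder API for setting them. A minimal usage sketch pieced together from the docstrings and types in this diff — the `datasources.get_datasource` entrypoint is assumed here and is not part of this change:

    from dagshub.data_engine import datasources, dtypes

    ds = datasources.get_datasource("<user>/<repo>", name="my-datasource")  # assumed entrypoint

    # Create a new annotation field; new fields need set_type() before anything else
    ds.metadata_field("annotation").set_type(dtypes.LabelStudioAnnotation).apply()

    # Mark an existing field as an annotation field
    ds.metadata_field("existing-field").set_annotation().apply()

    # Annotation-tagged fields are now discoverable on the datasource
    print(ds.annotation_fields)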
18 changes: 16 additions & 2 deletions dagshub/data_engine/client/data_client.py
@@ -12,7 +12,9 @@
 from dagshub.common.analytics import send_analytics_event
 from dagshub.common.rich_util import get_rich_progress
 from dagshub.data_engine.client.models import DatasourceResult, DatasourceType, IntegrationStatus, \
-    PreprocessingStatus, DatasetResult, MetadataFieldType, ScanOption
+    PreprocessingStatus, DatasetResult, MetadataFieldSchema
+from dagshub.data_engine.dtypes import MetadataFieldType
+from dagshub.data_engine.client.models import ScanOption
 from dagshub.data_engine.client.gql_mutations import GqlMutations
 from dagshub.data_engine.client.gql_queries import GqlQueries
 from dagshub.data_engine.model.datasource import Datasource, DatapointMetadataUpdateEntry
@@ -31,7 +33,6 @@ class DataClient:
     FULL_LIST_PAGE_SIZE = 5000
 
     def __init__(self, repo: str):
-        # TODO: add project authentication here
         self.repo = repo
         self.host = config.host
         self.client = self._init_client()
@@ -156,6 +157,19 @@ def update_metadata(self, datasource: Datasource, entries: List[DatapointMetadataUpdateEntry]):
 
         return self._exec(q, params)
 
+    def update_metadata_fields(self, datasource: Datasource, metadata_field_props: List[MetadataFieldSchema]):
+        q = GqlMutations.update_metadata_field()
+
+        assert datasource.source.id is not None
+        # assert len(entries) > 0
+
+        params = GqlMutations.update_metadata_fields_params(
+            datasource_id=datasource.source.id,
+            metadata_field_props=[e.to_dict() for e in metadata_field_props]
+        )
+
+        return self._exec(q, params)
+
     def get_datasources(self, id: Optional[str], name: Optional[str]) -> List[DatasourceResult]:
         q = GqlQueries.datasource()
         params = GqlQueries.datasource_params(id=id, name=name)
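Since `MetadataFieldSchema` is a `dataclasses_json` dataclass whose `valueType` is encoded to its enum value (see the models.py change below), the variables produced by `to_dict()` for a single annotation field should look roughly like this — an illustrative sketch, not captured output:

    # Illustrative shape of `params` (assumed values)
    {
        "datasource": "<datasource-id>",
        "props": [
            {"name": "annotation", "valueType": "BLOB", "multiple": False, "tags": ["annotation"]}
        ]
    }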
30 changes: 30 additions & 0 deletions dagshub/data_engine/client/gql_mutations.py
@@ -65,13 +65,43 @@ def update_metadata():
         ]).generate()
         return q
 
+    @staticmethod
+    @functools.lru_cache()
+    def update_metadata_field():
+        q = GqlQuery().operation(
+            "mutation",
+            name="updateMetadataFieldProps",
+            input={
+                "$datasource": "ID!",
+                "$props": "[MetadataFieldPropsInput!]!"
+            }
+        ).query(
+            "updateMetadataFieldProps",
+            input={
+                "datasource": "$datasource",
+                "props": "$props"
+            }
+        ).fields([
+            "name",
+            "valueType",
+            "multiple",
+            "tags",
+        ]).generate()
+        return q
+
     @staticmethod
     def update_metadata_params(datasource_id: Union[int, str], datapoints: List[Dict[str, Any]]):
         return {
             "datasource": datasource_id,
             "datapoints": datapoints,
         }
 
+    def update_metadata_fields_params(datasource_id: Union[int, str], metadata_field_props: List[Dict[str, Any]]):
+        return {
+            "datasource": datasource_id,
+            "props": metadata_field_props
+        }
+
     @staticmethod
     @functools.lru_cache()
     def delete_datasource():
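For orientation, the builder above should emit a mutation document along these lines — a sketch, since the exact string depends on `GqlQuery.generate()`:

    q = GqlMutations.update_metadata_field()
    # Approximate generated document (assumption — whitespace may differ):
    # mutation updateMetadataFieldProps($datasource: ID!, $props: [MetadataFieldPropsInput!]!) {
    #   updateMetadataFieldProps(datasource: $datasource, props: $props) { name valueType multiple tags }
    # }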
4 changes: 2 additions & 2 deletions dagshub/data_engine/client/gql_queries.py
@@ -28,7 +28,7 @@ def datasource() -> str:
             "rootUrl",
             "integrationStatus",
             "preprocessingStatus",
-            "metadataFields {name valueType multiple}"
+            "metadataFields {name valueType multiple tags}"
             "type",
         ]).generate()
         return q

@@ -97,7 +97,7 @@ def dataset() -> str:
             "id",
             "name",
             "datasource {id name rootUrl integrationStatus preprocessingStatus "
-            "metadataFields {name valueType multiple} type}",
+            "metadataFields {name valueType multiple tags} type}",
             "datasetQuery",
         ]).generate()
         return q
19 changes: 10 additions & 9 deletions dagshub/data_engine/client/models.py
@@ -4,6 +4,7 @@
 from typing import Any, List, Union, Optional
 
 from dataclasses_json import dataclass_json, config
+from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags
 
 logger = logging.getLogger(__name__)

@@ -40,31 +41,31 @@ class DatasourceType(enum.Enum):
     CUSTOM = "CUSTOM"
 
 
-class MetadataFieldType(enum.Enum):
-    BOOLEAN = "BOOLEAN"
-    INTEGER = "INTEGER"
-    FLOAT = "FLOAT"
-    STRING = "STRING"
-    BLOB = "BLOB"
-
-
 class ScanOption(str, enum.Enum):
     FORCE_REGENERATE_AUTO_SCAN_VALUES = "FORCE_REGENERATE_AUTO_SCAN_VALUES"
 
 
 @dataclass_json
 @dataclass
 class MetadataFieldSchema:
     # This should match the GraphQL schema: MetadataFieldProps
     name: str
     valueType: MetadataFieldType = field(
         metadata=config(
             encoder=lambda val: val.value
         )
     )
     multiple: bool
+    tags: Optional[List[str]]
 
     def __repr__(self):
-        return f"{self.name} ({self.valueType.value})"
+        res = f"{self.name} ({self.valueType.value})"
+        if self.tags is not None and len(self.tags) > 0:
+            res += f" with tags: {self.tags}"
+        return res
+
+    def is_annotation(self):
+        return ReservedTags.ANNOTATION.value in self.tags if self.tags else False
 
 
 @dataclass
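To illustrate the new schema shape, a small sketch of constructing and inspecting a tagged field (illustrative values, not from the PR; behavior traced from the code above):

    schema = MetadataFieldSchema(
        name="annotation",
        valueType=MetadataFieldType.BLOB,
        multiple=False,
        tags=[ReservedTags.ANNOTATION.value],
    )
    print(repr(schema))            # annotation (BLOB) with tags: ['annotation']
    print(schema.is_annotation())  # True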
58 changes: 58 additions & 0 deletions dagshub/data_engine/dtypes.py
@@ -0,0 +1,58 @@
+import enum
+from abc import ABCMeta
+from typing import List
+
+
+class ReservedTags(enum.Enum):
+    ANNOTATION = "annotation"
+
+
+# These are the base primitives that the data engine database is capable of storing
+class MetadataFieldType(enum.Enum):
+    BOOLEAN = "BOOLEAN"
+    INTEGER = "INTEGER"
+    FLOAT = "FLOAT"
+    STRING = "STRING"
+    BLOB = "BLOB"
+
+
+# Inheritors of this ABC define custom types
+# They are backed by a primitive type, but they also may have additional tags, describing specialized behavior
+class DagshubDataType(metaclass=ABCMeta):
+    """
+    Attributes:
+        backing_field_type: primitive type in the data engine database
+        custom_tags: additional tags applied to this type
+    """
+    backing_field_type: MetadataFieldType = None
+    custom_tags: List[str] = None
+
+
+class Int(DagshubDataType):
+    backing_field_type = MetadataFieldType.INTEGER
+
+
+class String(DagshubDataType):
+    backing_field_type = MetadataFieldType.STRING
+
+
+class Blob(DagshubDataType):
+    backing_field_type = MetadataFieldType.BLOB
+
+
+class Float(DagshubDataType):
+    backing_field_type = MetadataFieldType.FLOAT
+
+
+class Bool(DagshubDataType):
+    backing_field_type = MetadataFieldType.BOOLEAN
+
+
+class LabelStudioAnnotation(DagshubDataType):
+    backing_field_type = MetadataFieldType.BLOB
+    custom_tags = [ReservedTags.ANNOTATION.value]
+
+
+class Voxel51Annotation(DagshubDataType):
+    backing_field_type = MetadataFieldType.BLOB
+    custom_tags = [ReservedTags.ANNOTATION.value]
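The type classes are pure declarations: each subclass just pairs a backing primitive with optional reserved tags, which the field builder elsewhere in this PR presumably translates into a `MetadataFieldSchema`. A quick sketch of the mapping, grounded in the file above:

    assert Int.backing_field_type == MetadataFieldType.INTEGER
    assert Int.custom_tags is None
    assert LabelStudioAnnotation.backing_field_type == MetadataFieldType.BLOB
    assert LabelStudioAnnotation.custom_tags == [ReservedTags.ANNOTATION.value]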
2 changes: 1 addition & 1 deletion dagshub/data_engine/model/datapoint.py
@@ -5,7 +5,7 @@
 
 from dagshub.common.download import download_files
 from dagshub.common.helpers import http_request
-from dagshub.data_engine.client.models import MetadataFieldType
+from dagshub.data_engine.dtypes import MetadataFieldType
 
 if TYPE_CHECKING:
     from dagshub.data_engine.model.datasource import Datasource
42 changes: 33 additions & 9 deletions dagshub/data_engine/model/datasource.py
@@ -21,12 +21,15 @@
 from dagshub.common.helpers import prompt_user, http_request, log_message
 from dagshub.common.rich_util import get_rich_progress
 from dagshub.common.util import lazy_load, multi_urljoin
-from dagshub.data_engine.client.models import PreprocessingStatus, MetadataFieldType, MetadataFieldSchema, \
+from dagshub.data_engine.client.models import PreprocessingStatus, MetadataFieldSchema, \
     ScanOption, autogenerated_columns
+from dagshub.data_engine.dtypes import MetadataFieldType
 from dagshub.data_engine.model.datapoint import Datapoint
 from dagshub.data_engine.model.errors import WrongOperatorError, WrongOrderError, DatasetFieldComparisonError, \
     FieldNotFoundError
-from dagshub.data_engine.model.query import DatasourceQuery, _metadataTypeLookup, _metadataTypeLookupReverse
+from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder
+from dagshub.data_engine.model.query import DatasourceQuery
+from dagshub.data_engine.model.schema_util import metadataTypeLookup, metadataTypeLookupReverse
 
 if TYPE_CHECKING:
     from dagshub.data_engine.model.query_result import QueryResult

@@ -84,8 +87,7 @@ def get_query(self):
 
     @property
     def annotation_fields(self) -> List[str]:
-        # TODO: once the annotation type is implemented, expose those columns here
-        return ["annotation"]
+        return [f.name for f in self.fields if f.is_annotation()]
 
     def serialize_gql_query_input(self):
         return {
@@ -123,6 +125,27 @@ def _check_preprocess(self):
             f"Datasource {self.source.name} is currently in the progress of rescanning. "
             f"Values might change if you requery later")
 
+    def metadata_field(self, field_name: str) -> MetadataFieldBuilder:
+        """
+        Returns a builder for a metadata field.
+        The builder can be used to change properties of a field or create a new field altogether
+        Example of creating a new annotation field:
+            ds.metadata_field("annotation").set_type(dtypes.LabelStudioAnnotation).apply()
+        NOTE: New fields have to have their type defined using .set_type() before doing anything else
+        Example of marking an existing field as an annotation field:
+            ds.metadata_field("existing-field").set_annotation().apply()
+        """
+        return MetadataFieldBuilder(self, field_name)
+
+    def apply_field_changes(self, field_builders: List[MetadataFieldBuilder]):
+        """
+        Applies one or multiple metadata field builders that can be constructed using the metadata_field() function
+        """
+        self.source.client.update_metadata_fields(self, [builder.schema for builder in field_builders])
+        self.source.get_from_dagshub()
+
     def metadata_context(self) -> ContextManager["MetadataContextManager"]:
         """
         Returns a metadata context, that you can upload metadata through via update_metadata
@@ -204,7 +227,7 @@ def _df_to_metadata(self, df: "pandas.DataFrame", path_column: Optional[Union[st
             for sub_val in val:
                 value_type = field_value_types.get(key)
                 if value_type is None:
-                    value_type = _metadataTypeLookup[type(sub_val)]
+                    value_type = metadataTypeLookup[type(sub_val)]
                     field_value_types[key] = value_type
                 # Don't override bytes if they're not bytes - probably just undownloaded values
                 if value_type == MetadataFieldType.BLOB and type(sub_val) is not bytes:
@@ -224,7 +247,7 @@ def _df_to_metadata(self, df: "pandas.DataFrame", path_column: Optional[Union[st
             else:
                 value_type = field_value_types.get(key)
                 if value_type is None:
-                    value_type = _metadataTypeLookup[type(val)]
+                    value_type = metadataTypeLookup[type(val)]
                     field_value_types[key] = value_type
                 # Don't override bytes if they're not bytes - probably just undownloaded values
                 if value_type == MetadataFieldType.BLOB and type(val) is not bytes:
@@ -545,7 +568,7 @@ def contains(self, item: str):
 
     def is_null(self):
         field = self._get_filtering_field()
-        value_type = _metadataTypeLookupReverse[field.valueType.value]
+        value_type = metadataTypeLookupReverse[field.valueType.value]
         return self.add_query_op("isnull", value_type())
 
     def is_not_null(self):
@@ -625,7 +648,7 @@ def update_metadata(self, datapoints: Union[List[str], str], metadata: Dict[str, Any]):
 
                 value_type = field_value_types.get(k)
                 if value_type is None:
-                    value_type = _metadataTypeLookup[type(sub_val)]
+                    value_type = metadataTypeLookup[type(sub_val)]
                     field_value_types[k] = value_type
                 # Don't override bytes if they're not bytes - probably just undownloaded values
                 if value_type == MetadataFieldType.BLOB and type(sub_val) is not bytes:
@@ -646,14 +669,15 @@ def update_metadata(self, datapoints: Union[List[str], str], metadata: Dict[str, Any]):
 
                 value_type = field_value_types.get(k)
                 if value_type is None:
-                    value_type = _metadataTypeLookup[type(v)]
+                    value_type = metadataTypeLookup[type(v)]
                     field_value_types[k] = value_type
                 # Don't override bytes if they're not bytes - probably just undownloaded values
                 if value_type == MetadataFieldType.BLOB and type(v) is not bytes:
                     continue
 
                 if type(v) is bytes:
                     v = self.wrap_bytes(v)
+
                 self._metadata_entries.append(DatapointMetadataUpdateEntry(
                     url=dp,
                     key=k,
(Diffs for the remaining 7 of the 14 changed files are not shown here.)