diff --git a/dagshub/data_engine/client/data_client.py b/dagshub/data_engine/client/data_client.py
index 3f689830..8688edcb 100644
--- a/dagshub/data_engine/client/data_client.py
+++ b/dagshub/data_engine/client/data_client.py
@@ -12,7 +12,8 @@
 from dagshub.common.analytics import send_analytics_event
 from dagshub.common.rich_util import get_rich_progress
 from dagshub.data_engine.client.models import DatasourceResult, DatasourceType, IntegrationStatus, \
-    PreprocessingStatus, DatasetResult, MetadataFieldType, ScanOption
+    PreprocessingStatus, DatasetResult, MetadataFieldSchema, ScanOption
+from dagshub.data_engine.dtypes import MetadataFieldType
 from dagshub.data_engine.client.gql_mutations import GqlMutations
 from dagshub.data_engine.client.gql_queries import GqlQueries
 from dagshub.data_engine.model.datasource import Datasource, DatapointMetadataUpdateEntry
@@ -31,7 +33,6 @@ class DataClient:
     FULL_LIST_PAGE_SIZE = 5000
 
     def __init__(self, repo: str):
-        # TODO: add project authentication here
        self.repo = repo
         self.host = config.host
         self.client = self._init_client()
@@ -156,6 +157,18 @@ def update_metadata(self, datasource: Datasource, entries: List[DatapointMetadat
 
         return self._exec(q, params)
 
+    def update_metadata_fields(self, datasource: Datasource, metadata_field_props: List[MetadataFieldSchema]):
+        q = GqlMutations.update_metadata_field()
+
+        assert datasource.source.id is not None
+
+        params = GqlMutations.update_metadata_fields_params(
+            datasource_id=datasource.source.id,
+            metadata_field_props=[e.to_dict() for e in metadata_field_props]
+        )
+
+        return self._exec(q, params)
+
     def get_datasources(self, id: Optional[str], name: Optional[str]) -> List[DatasourceResult]:
         q = GqlQueries.datasource()
         params = GqlQueries.datasource_params(id=id, name=name)
diff --git a/dagshub/data_engine/client/gql_mutations.py b/dagshub/data_engine/client/gql_mutations.py
index 5fd6db8e..32534885 100644
--- a/dagshub/data_engine/client/gql_mutations.py
+++ b/dagshub/data_engine/client/gql_mutations.py
@@ -65,6 +65,30 @@ def update_metadata():
         ]).generate()
         return q
 
+    @staticmethod
+    @functools.lru_cache()
+    def update_metadata_field():
+        q = GqlQuery().operation(
+            "mutation",
+            name="updateMetadataFieldProps",
+            input={
+                "$datasource": "ID!",
+                "$props": "[MetadataFieldPropsInput!]!"
+            }
+        ).query(
+            "updateMetadataFieldProps",
+            input={
+                "datasource": "$datasource",
+                "props": "$props"
+            }
+        ).fields([
+            "name",
+            "valueType",
+            "multiple",
+            "tags",
+        ]).generate()
+        return q
+
     @staticmethod
     def update_metadata_params(datasource_id: Union[int, str], datapoints: List[Dict[str, Any]]):
         return {
@@ -72,6 +96,13 @@ def update_metadata_params(datasource_id: Union[int, str], datapoints: List[Dict
             "datapoints": datapoints,
         }
 
+    @staticmethod
+    def update_metadata_fields_params(datasource_id: Union[int, str], metadata_field_props: List[Dict[str, Any]]):
+        return {
+            "datasource": datasource_id,
+            "props": metadata_field_props
+        }
+
     @staticmethod
     @functools.lru_cache()
     def delete_datasource():
diff --git a/dagshub/data_engine/client/gql_queries.py b/dagshub/data_engine/client/gql_queries.py
index ac2a16d3..c633de72 100644
--- a/dagshub/data_engine/client/gql_queries.py
+++ b/dagshub/data_engine/client/gql_queries.py
@@ -28,7 +28,7 @@ def datasource() -> str:
             "rootUrl",
             "integrationStatus",
             "preprocessingStatus",
-            "metadataFields {name valueType multiple}"
+            "metadataFields {name valueType multiple tags}",
             "type",
         ]).generate()
         return q
@@ -97,7 +97,7 @@ def dataset() -> str:
             "id",
             "name",
             "datasource {id name rootUrl integrationStatus preprocessingStatus "
-            "metadataFields {name valueType multiple} type}",
+            "metadataFields {name valueType multiple tags} type}",
             "datasetQuery",
         ]).generate()
         return q
diff --git a/dagshub/data_engine/client/models.py b/dagshub/data_engine/client/models.py
index 8c9237bb..2aa602f3 100644
--- a/dagshub/data_engine/client/models.py
+++ b/dagshub/data_engine/client/models.py
@@ -4,6 +4,7 @@ from typing import Any, List, Union, Optional
 
 from dataclasses_json import dataclass_json, config
 
+from dagshub.data_engine.dtypes import MetadataFieldType, ReservedTags
 
 logger = logging.getLogger(__name__)
 
@@ -40,14 +41,6 @@ class DatasourceType(enum.Enum):
     CUSTOM = "CUSTOM"
 
 
-class MetadataFieldType(enum.Enum):
-    BOOLEAN = "BOOLEAN"
-    INTEGER = "INTEGER"
-    FLOAT = "FLOAT"
-    STRING = "STRING"
-    BLOB = "BLOB"
-
-
 class ScanOption(str, enum.Enum):
     FORCE_REGENERATE_AUTO_SCAN_VALUES = "FORCE_REGENERATE_AUTO_SCAN_VALUES"
 
@@ -55,6 +48,7 @@ class ScanOption(str, enum.Enum):
 @dataclass_json
 @dataclass
 class MetadataFieldSchema:
+    # This should match the GraphQL schema: MetadataFieldProps
     name: str
     valueType: MetadataFieldType = field(
         metadata=config(
@@ -62,9 +56,16 @@ class ScanOption(str, enum.Enum):
         )
     )
     multiple: bool
+    tags: Optional[List[str]]
 
     def __repr__(self):
-        return f"{self.name} ({self.valueType.value})"
+        res = f"{self.name} ({self.valueType.value})"
+        if self.tags is not None and len(self.tags) > 0:
+            res += f" with tags: {self.tags}"
+        return res
+
+    def is_annotation(self) -> bool:
+        return ReservedTags.ANNOTATION.value in self.tags if self.tags else False
 
 
 @dataclass
diff --git a/dagshub/data_engine/dtypes.py b/dagshub/data_engine/dtypes.py
new file mode 100644
index 00000000..77c035dc
--- /dev/null
+++ b/dagshub/data_engine/dtypes.py
@@ -0,0 +1,58 @@
+import enum
+from abc import ABCMeta
+from typing import List, Optional
+
+
+class ReservedTags(enum.Enum):
+    ANNOTATION = "annotation"
+
+
+# These are the base primitives that the data engine database is capable of storing
+class MetadataFieldType(enum.Enum):
+    BOOLEAN = "BOOLEAN"
+    INTEGER = "INTEGER"
+    FLOAT = "FLOAT"
+    STRING = "STRING"
+    BLOB = "BLOB"
+
+
+# Inheritors of this ABC define custom types.
+# They are backed by a primitive type, but may also carry additional tags that describe specialized behavior.
+class DagshubDataType(metaclass=ABCMeta):
+    """
+    Attributes:
+        backing_field_type: primitive type in the data engine database
+        custom_tags: additional tags applied to this type
+    """
+    backing_field_type: Optional[MetadataFieldType] = None
+    custom_tags: Optional[List[str]] = None
+
+
+class Int(DagshubDataType):
+    backing_field_type = MetadataFieldType.INTEGER
+
+
+class String(DagshubDataType):
+    backing_field_type = MetadataFieldType.STRING
+
+
+class Blob(DagshubDataType):
+    backing_field_type = MetadataFieldType.BLOB
+
+
+class Float(DagshubDataType):
+    backing_field_type = MetadataFieldType.FLOAT
+
+
+class Bool(DagshubDataType):
+    backing_field_type = MetadataFieldType.BOOLEAN
+
+
+class LabelStudioAnnotation(DagshubDataType):
+    backing_field_type = MetadataFieldType.BLOB
+    custom_tags = [ReservedTags.ANNOTATION.value]
+
+
+class Voxel51Annotation(DagshubDataType):
+    backing_field_type = MetadataFieldType.BLOB
+    custom_tags = [ReservedTags.ANNOTATION.value]
diff --git a/dagshub/data_engine/model/datapoint.py b/dagshub/data_engine/model/datapoint.py
index 38157553..4fad5feb 100644
--- a/dagshub/data_engine/model/datapoint.py
+++ b/dagshub/data_engine/model/datapoint.py
@@ -5,7 +5,7 @@
 from dagshub.common.download import download_files
 from dagshub.common.helpers import http_request
-from dagshub.data_engine.client.models import MetadataFieldType
+from dagshub.data_engine.dtypes import MetadataFieldType
 
 if TYPE_CHECKING:
     from dagshub.data_engine.model.datasource import Datasource
 
diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py
index d11d0694..61fd3fc4 100644
--- a/dagshub/data_engine/model/datasource.py
+++ b/dagshub/data_engine/model/datasource.py
@@ -21,12 +21,15 @@
 from dagshub.common.helpers import prompt_user, http_request, log_message
 from dagshub.common.rich_util import get_rich_progress
 from dagshub.common.util import lazy_load, multi_urljoin
-from dagshub.data_engine.client.models import PreprocessingStatus, MetadataFieldType, MetadataFieldSchema, \
+from dagshub.data_engine.client.models import PreprocessingStatus, MetadataFieldSchema, \
     ScanOption, autogenerated_columns
+from dagshub.data_engine.dtypes import MetadataFieldType
 from dagshub.data_engine.model.datapoint import Datapoint
 from dagshub.data_engine.model.errors import WrongOperatorError, WrongOrderError, DatasetFieldComparisonError, \
     FieldNotFoundError
-from dagshub.data_engine.model.query import DatasourceQuery, _metadataTypeLookup, _metadataTypeLookupReverse
+from dagshub.data_engine.model.metadata_field_builder import MetadataFieldBuilder
+from dagshub.data_engine.model.query import DatasourceQuery
+from dagshub.data_engine.model.schema_util import metadataTypeLookup, metadataTypeLookupReverse
 
 if TYPE_CHECKING:
     from dagshub.data_engine.model.query_result import QueryResult
@@ -84,8 +87,7 @@ def get_query(self):
 
     @property
     def annotation_fields(self) -> List[str]:
-        # TODO: once the annotation type is implemented, expose those columns here
-        return ["annotation"]
+        return [f.name for f in self.fields if f.is_annotation()]
 
     def serialize_gql_query_input(self):
         return {
@@ -123,6 +125,27 @@ def _check_preprocess(self):
                 f"Datasource {self.source.name} is currently in the progress of rescanning. "
                 f"Values might change if you requery later")
 
+    def metadata_field(self, field_name: str) -> MetadataFieldBuilder:
+        """
+        Returns a builder for a metadata field.
+        The builder can be used to change the properties of an existing field or to create a new field.
+
+        Example of creating a new annotation field:
+            ds.metadata_field("annotation").set_type(dtypes.LabelStudioAnnotation).apply()
+        Note: new fields must have their type set with .set_type() before any other properties are changed
+
+        Example of marking an existing field as an annotation field:
+            ds.metadata_field("existing-field").set_annotation().apply()
+        """
+        return MetadataFieldBuilder(self, field_name)
+
+    def apply_field_changes(self, field_builders: List[MetadataFieldBuilder]):
+        """
+        Applies one or more metadata field builders constructed with the metadata_field() function
+        """
+        self.source.client.update_metadata_fields(self, [builder.schema for builder in field_builders])
+        self.source.get_from_dagshub()
+
     def metadata_context(self) -> ContextManager["MetadataContextManager"]:
         """
         Returns a metadata context, that you can upload metadata through via update_metadata
@@ -204,7 +227,7 @@ def _df_to_metadata(self, df: "pandas.DataFrame", path_column: Optional[Union[st
                 for sub_val in val:
                     value_type = field_value_types.get(key)
                     if value_type is None:
-                        value_type = _metadataTypeLookup[type(sub_val)]
+                        value_type = metadataTypeLookup[type(sub_val)]
                         field_value_types[key] = value_type
                     # Don't override bytes if they're not bytes - probably just undownloaded values
                     if value_type == MetadataFieldType.BLOB and type(sub_val) is not bytes:
@@ -224,7 +247,7 @@ def _df_to_metadata(self, df: "pandas.DataFrame", path_column: Optional[Union[st
             else:
                 value_type = field_value_types.get(key)
                 if value_type is None:
-                    value_type = _metadataTypeLookup[type(val)]
+                    value_type = metadataTypeLookup[type(val)]
                     field_value_types[key] = value_type
                 # Don't override bytes if they're not bytes - probably just undownloaded values
                 if value_type == MetadataFieldType.BLOB and type(val) is not bytes:
@@ -545,7 +568,7 @@ def contains(self, item: str):
 
     def is_null(self):
         field = self._get_filtering_field()
-        value_type = _metadataTypeLookupReverse[field.valueType.value]
+        value_type = metadataTypeLookupReverse[field.valueType.value]
         return self.add_query_op("isnull", value_type())
 
     def is_not_null(self):
@@ -625,7 +648,7 @@ def update_metadata(self, datapoints: Union[List[str], str], metadata: Dict[str,
 
                     value_type = field_value_types.get(k)
                     if value_type is None:
-                        value_type = _metadataTypeLookup[type(sub_val)]
+                        value_type = metadataTypeLookup[type(sub_val)]
                         field_value_types[k] = value_type
                     # Don't override bytes if they're not bytes - probably just undownloaded values
                     if value_type == MetadataFieldType.BLOB and type(sub_val) is not bytes:
@@ -646,7 +669,7 @@
 
                 value_type = field_value_types.get(k)
                 if value_type is None:
-                    value_type = _metadataTypeLookup[type(v)]
+                    value_type = metadataTypeLookup[type(v)]
                     field_value_types[k] = value_type
                 # Don't override bytes if they're not bytes - probably just undownloaded values
                 if value_type == MetadataFieldType.BLOB and type(v) is not bytes:
@@ -654,6 +677,7 @@
 
                 if type(v) is bytes:
                     v = self.wrap_bytes(v)
+
                 self._metadata_entries.append(DatapointMetadataUpdateEntry(
                     url=dp,
                     key=k,
diff --git a/dagshub/data_engine/model/metadata_field_builder.py b/dagshub/data_engine/model/metadata_field_builder.py
new file mode 100644
index 00000000..0326d173
--- /dev/null
+++ b/dagshub/data_engine/model/metadata_field_builder.py
@@ -0,0 +1,113 @@
+import dataclasses
+import logging
+from typing import TYPE_CHECKING, Type, Union, List
+
+from dagshub.data_engine.client.models import MetadataFieldSchema
+from dagshub.data_engine.dtypes import DagshubDataType, MetadataFieldType, ReservedTags
+from dagshub.data_engine.model.schema_util import metadataTypeLookup
+
+if TYPE_CHECKING:
+    from dagshub.data_engine.model.datasource import Datasource
+
+logger = logging.getLogger(__name__)
+
+
+class MetadataFieldBuilder:
+    """
+    Builder class for changing the properties of a metadata field in a datasource.
+    It can also be used to create a new empty field with a predefined schema.
+    """
+
+    def __init__(self, datasource: "Datasource", field_name: str):
+        self.datasource = datasource
+
+        self._field_name = field_name
+
+        preexisting_schema = next(filter(lambda f: f.name == field_name, datasource.fields), None)
+
+        # Make a copy of the dataclass, so we don't change the base schema
+        if preexisting_schema is not None:
+            preexisting_schema = dataclasses.replace(preexisting_schema)
+            if preexisting_schema.tags is not None:
+                preexisting_schema.tags = preexisting_schema.tags.copy()
+
+        self._schema = preexisting_schema
+        self.already_exists = self._schema is not None
+
+    @property
+    def schema(self) -> MetadataFieldSchema:
+        if self._schema is None:
+            raise RuntimeError(f"Field {self._field_name} is a new field. "
+                               "Call set_type() on it before setting any other properties")
+        return self._schema
+
+    def set_type(self, t: Union[Type, DagshubDataType]) -> "MetadataFieldBuilder":
+        """
+        Set the type of the field.
+        The type can either be a Python primitive supported by the Data Engine (str, bool, int, float, bytes),
+        or a DagshubDataType inheritor (found in dagshub.data_engine.dtypes).
+        The DataType inheritors can define additional tags on top of the basic backing type.
+        """
+        backing_type = self._get_backing_type(t)
+
+        if self._schema is None:
+            self._schema = MetadataFieldSchema(
+                name=self._field_name,
+                valueType=backing_type,
+                multiple=False,
+                tags=[]
+            )
+            if issubclass(t, DagshubDataType) and t.custom_tags is not None:
+                self._schema.tags = t.custom_tags.copy()
+        else:
+            if backing_type != self._schema.valueType:
+                raise ValueError("Can't change the type of an already existing field "
+                                 f"(changing from {self._schema.valueType.value} to {backing_type.value})")
+            if issubclass(t, DagshubDataType) and t.custom_tags is not None:
+                self._add_tags(t.custom_tags)
+
+        return self
+
+    def set_annotation(self, is_annotation: bool = True) -> "MetadataFieldBuilder":
+        """
+        Mark or unmark the field as an annotation field
+        """
+        self._set_or_unset(ReservedTags.ANNOTATION.value, is_annotation)
+        return self
+
+    def _set_or_unset(self, tag, is_set):
+        if is_set:
+            self._add_tags([tag])
+        else:
+            self._remove_tag(tag)
+
+    def _add_tags(self, tags: List[str]):
+        if self.schema.tags is None:
+            self.schema.tags = []
+        self.schema.tags.extend(tags)
+
+    def _remove_tag(self, tag: str):
+        if self.schema.tags is None:
+            return
+        try:
+            self.schema.tags.remove(tag)
+        except ValueError:
+            logger.warning(f"Tag {tag} doesn't exist on the field, nothing to delete")
+
+    @staticmethod
+    def _get_backing_type(t: Union[Type, DagshubDataType]) -> MetadataFieldType:
+        if issubclass(t, DagshubDataType):
+            return t.backing_field_type
+
+        if isinstance(t, type):
+            if t not in metadataTypeLookup.keys():
+                raise ValueError(f"Primitive type {t} is not supported")
+            return metadataTypeLookup[t]
+
+        raise ValueError(f"{t} of type ({type(t)}) is not a valid primitive type or DagshubDataType")
+
+    def apply(self):
+        """
+        Apply the pending changes to the metadata field
+        """
+        self.datasource.apply_field_changes([self])
diff --git a/dagshub/data_engine/model/query.py b/dagshub/data_engine/model/query.py
index 541c1aa7..129f8c5b 100644
--- a/dagshub/data_engine/model/query.py
+++ b/dagshub/data_engine/model/query.py
@@ -1,25 +1,17 @@
 import enum
 import logging
-from typing import TYPE_CHECKING, Optional, Union, Dict, Type
+from typing import TYPE_CHECKING, Optional, Union, Dict
 
 from treelib import Tree, Node
 
-from dagshub.data_engine.client.models import MetadataFieldType
 from dagshub.data_engine.model.errors import WrongOperatorError
+from dagshub.data_engine.model.schema_util import metadataTypeLookup, metadataTypeLookupReverse
 
 if TYPE_CHECKING:
     from dagshub.data_engine.model.datasource import Datasource
 
 logger = logging.getLogger(__name__)
 
-_metadataTypeLookup = {
-    int: MetadataFieldType.INTEGER,
-    bool: MetadataFieldType.BOOLEAN,
-    float: MetadataFieldType.FLOAT,
-    str: MetadataFieldType.STRING,
-    bytes: MetadataFieldType.BLOB,
-}
-
 
 def bytes_deserializer(val: str) -> bytes:
     if val.startswith('b"') or val.startswith("b'"):
@@ -33,10 +25,6 @@
     bytes: bytes_deserializer,
 }
 
-_metadataTypeLookupReverse: Dict[str, Type] = {}
-for k, v in _metadataTypeLookup.items():
-    _metadataTypeLookupReverse[v.value] = k
-
 
 class FieldFilterOperand(enum.Enum):
     EQUAL = "EQUAL"
@@ -143,7 +131,7 @@ def _serialize_node(node: Node, tree: Tree) -> Dict:
             raise WrongOperatorError(f"Operator {operand} is not supported")
         key = node.data["field"]
         value = node.data["value"]
-        value_type = _metadataTypeLookup[type(value)].value
+        value_type = metadataTypeLookup[type(value)].value
         if type(value) is bytes:
             # TODO: this will need to probably be changed when we allow actual binary field comparisons
            value = value.decode("utf-8")
@@ -151,11 +139,11 @@
             value = str(value)
         if value_type is None:
             raise RuntimeError(f"Value type {value_type} is not supported for querying.\r\n"
-                               f"Supported types: {list(_metadataTypeLookup.keys())}")
+                               f"Supported types: {list(metadataTypeLookup.keys())}")
         return {
             "filter": {
                 "key": key,
-                "value": value,
+                "value": str(value),
                 "valueType": value_type,
                 "comparator": query_op.value,
             }
@@ -188,7 +176,7 @@ def _deserialize_node(node_dict: Dict, tree: Tree, parent_node=None) -> None:
     if op_type == "filter":
         comparator = fieldFilterOperandMapReverseMap[val["comparator"]]
         key = val["key"]
-        value_type = _metadataTypeLookupReverse[val["valueType"]]
+        value_type = metadataTypeLookupReverse[val["valueType"]]
         converter = _metadataTypeCustomConverters.get(value_type, lambda x: value_type(x))
         value = converter(val["value"])
         node = Node(tag=comparator, data={"field": key, "value": value})
diff --git a/dagshub/data_engine/model/schema_util.py b/dagshub/data_engine/model/schema_util.py
new file mode 100644
index 00000000..ea20e525
--- /dev/null
+++ b/dagshub/data_engine/model/schema_util.py
@@ -0,0 +1,15 @@
+from typing import Type, Dict
+
+from dagshub.data_engine.dtypes import MetadataFieldType
+
+metadataTypeLookup = {
+    int: MetadataFieldType.INTEGER,
+    bool: MetadataFieldType.BOOLEAN,
+    float: MetadataFieldType.FLOAT,
+    str: MetadataFieldType.STRING,
+    bytes: MetadataFieldType.BLOB,
+}
+
+metadataTypeLookupReverse: Dict[str, Type] = {}
+for k, v in metadataTypeLookup.items():
+    metadataTypeLookupReverse[v.value] = k
diff --git a/test_dir/1.csv b/test_dir/1.csv
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/data_engine/test_datasource.py b/tests/data_engine/test_datasource.py
index 77f9ff5e..e65dffbe 100644
--- a/tests/data_engine/test_datasource.py
+++ b/tests/data_engine/test_datasource.py
@@ -1,7 +1,7 @@
 import pandas as pd
 import pytest
 
-from dagshub.data_engine.client.models import MetadataFieldType
+from dagshub.data_engine.dtypes import MetadataFieldType
 from dagshub.data_engine.model.datasource import Datasource, DatapointMetadataUpdateEntry, MetadataContextManager
 
 
diff --git a/tests/data_engine/test_metadata_field_update.py b/tests/data_engine/test_metadata_field_update.py
new file mode 100644
index 00000000..8c9ae0f6
--- /dev/null
+++ b/tests/data_engine/test_metadata_field_update.py
@@ -0,0 +1,32 @@
+import pytest
+
+from dagshub.data_engine import dtypes
+from dagshub.data_engine.dtypes import ReservedTags, MetadataFieldType
+from tests.data_engine.util import add_int_fields
+
+
+def test_have_to_set_type_on_new_field(ds):
+    fb = ds.metadata_field("new_field")
+    with pytest.raises(RuntimeError):
+        fb.set_annotation()
+
+
+def test_can_set_annotation_on_existing_field(ds):
+    add_int_fields(ds, "field1")
+    fb = ds.metadata_field("field1")
+    fb.set_annotation()
+    assert ReservedTags.ANNOTATION.value in fb.schema.tags
+
+
+def test_builder_doesnt_change_original_field_schema(ds):
+    add_int_fields(ds, "field1")
+    fb = ds.metadata_field("field1").set_annotation()
+
+    assert ReservedTags.ANNOTATION.value in fb.schema.tags
+    assert ReservedTags.ANNOTATION.value not in ds.fields[0].tags
+
+
+def test_annotation_type_sets_annotation_tag(ds):
+    fb = ds.metadata_field("new_field").set_type(dtypes.LabelStudioAnnotation)
+    assert fb.schema.valueType == MetadataFieldType.BLOB
+    assert ReservedTags.ANNOTATION.value in fb.schema.tags
diff --git a/tests/data_engine/util.py b/tests/data_engine/util.py
index 7a30c5b2..c88e7fe7 100644
--- a/tests/data_engine/util.py
+++ b/tests/data_engine/util.py
@@ -1,4 +1,7 @@
-from dagshub.data_engine.client.models import MetadataFieldType, MetadataFieldSchema
+from typing import List, Optional
+
+from dagshub.data_engine.client.models import MetadataFieldSchema
+from dagshub.data_engine.dtypes import MetadataFieldType
 from dagshub.data_engine.model.datasource import Datasource
 
 
@@ -27,6 +30,9 @@ def add_boolean_fields(ds: Datasource, *names: str):
         add_metadata_field(ds, name, MetadataFieldType.BOOLEAN)
 
 
-def add_metadata_field(ds: Datasource, name: str, value_type: MetadataFieldType, is_multiple: bool = False):
-    field = MetadataFieldSchema(name, value_type, is_multiple)
+def add_metadata_field(ds: Datasource, name: str, value_type: MetadataFieldType, is_multiple: bool = False,
+                       tags: Optional[List[str]] = None):
+    if tags is None:
+        tags = []
+    field = MetadataFieldSchema(name, value_type, is_multiple, tags)
     ds.source.metadata_fields.append(field)
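
For reviewers, a minimal usage sketch of the new field-builder API. It assumes a datasource obtained through `datasources.get_datasource` from the existing client; the repo, datasource, and field names are placeholders:

```python
from dagshub.data_engine import datasources, dtypes

# Placeholder repo/datasource names - substitute your own.
ds = datasources.get_datasource("<user>/<repo>", name="my-datasource")

# Create a new annotation field backed by a BLOB; new fields must be
# given a type via set_type() before anything else.
ds.metadata_field("annotation").set_type(dtypes.LabelStudioAnnotation).apply()

# Mark an already existing field as an annotation field.
ds.metadata_field("existing-field").set_annotation().apply()

# Batch several builders into a single updateMetadataFieldProps mutation.
ds.apply_field_changes([
    ds.metadata_field("rating").set_type(int),
    ds.metadata_field("existing-field").set_annotation(),
])
```

`apply()` is shorthand for `apply_field_changes([builder])`; both paths end in `DataClient.update_metadata_fields`, which serializes each builder's schema and then refetches the datasource so the local field list reflects the server state.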
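A sketch of what goes over the wire: the mutation's `props` variable is built from `MetadataFieldSchema.to_dict()` (via dataclasses_json). The exact serialized form depends on the encoder configured on `valueType`, which this diff elides; assuming it emits the enum's string value:

```python
from dagshub.data_engine.client.models import MetadataFieldSchema
from dagshub.data_engine.dtypes import MetadataFieldType

schema = MetadataFieldSchema(
    name="annotation",
    valueType=MetadataFieldType.BLOB,
    multiple=False,
    tags=["annotation"],
)

# update_metadata_fields sends [schema.to_dict()] as the `props` input
# of the updateMetadataFieldProps mutation.
print(schema.to_dict())
# Expected shape: {'name': 'annotation', 'valueType': 'BLOB', 'multiple': False, 'tags': ['annotation']}
```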