Skip to content

Commit

Permalink
Read offloaded literals (#2685)
Browse files Browse the repository at this point in the history
* [WIP] - Read offloaded literals

Signed-off-by: Eduardo Apolinario <[email protected]>

* Use LiteralOffloadedMetadata field

Signed-off-by: Eduardo Apolinario <[email protected]>

* Assert use of offloaded uri to get around typing constraint

Signed-off-by: Eduardo Apolinario <[email protected]>

* Add a bunch of unit tests

Signed-off-by: Eduardo Apolinario <[email protected]>

* Remove TODO and fix comment

Signed-off-by: Eduardo Apolinario <[email protected]>

* Simplify generation of local file to store literal

Signed-off-by: Eduardo Apolinario <[email protected]>

* Rename variable: `local_literal_file` to `literal_local_file`

Signed-off-by: Eduardo Apolinario <[email protected]>

* Fix lint errors

Signed-off-by: Eduardo Apolinario <[email protected]>

---------

Signed-off-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
  • Loading branch information
eapolinario and eapolinario authored Sep 18, 2024
1 parent 11c3a18 commit 2dcbb90
Show file tree
Hide file tree
Showing 5 changed files with 388 additions and 4 deletions.
10 changes: 9 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from flytekit.core.context_manager import FlyteContext
from flytekit.core.hash import HashMethod
from flytekit.core.type_helpers import load_type_from_tag
from flytekit.core.utils import timeit
from flytekit.core.utils import load_proto_from_file, timeit
from flytekit.exceptions import user as user_exceptions
from flytekit.interaction.string_literals import literal_map_string_repr
from flytekit.lazy_import.lazy_module import is_imported
Expand Down Expand Up @@ -1155,6 +1155,14 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T
"""
Converts a Literal value with an expected python type into a python value.
"""
# If the literal was offloaded, download the serialized Literal proto from its uri and use the decoded literal in place of lv
if lv.offloaded_metadata:
literal_local_file = ctx.file_access.get_random_local_path()
assert lv.offloaded_metadata.uri, "missing offloaded uri"
ctx.file_access.download(lv.offloaded_metadata.uri, literal_local_file)
input_proto = load_proto_from_file(literals_pb2.Literal, literal_local_file)
lv = Literal.from_flyte_idl(input_proto)

transformer = cls.get_transformer(expected_python_type)
return transformer.to_python_value(ctx, lv, expected_python_type)

Expand Down
63 changes: 61 additions & 2 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flytekit.exceptions import user as _user_exceptions
from flytekit.models import common as _common
from flytekit.models.core import types as _core_types
from flytekit.models.types import Error, StructuredDatasetType
from flytekit.models.types import Error, LiteralType, StructuredDatasetType
from flytekit.models.types import LiteralType as _LiteralType
from flytekit.models.types import OutputReference as _OutputReference
from flytekit.models.types import SchemaType as _SchemaType
Expand Down Expand Up @@ -852,6 +852,52 @@ def from_flyte_idl(cls, pb2_object):
)


class LiteralOffloadedMetadata(_common.FlyteIdlEntity):
    """
    Model wrapper around the ``LiteralOffloadedMetadata`` protobuf message.

    Holds the uri, serialized size and inferred type of a literal whose value is
    stored outside of the literal itself.
    """

    def __init__(
        self,
        uri: Optional[str] = None,
        size_bytes: Optional[int] = None,
        inferred_type: Optional[LiteralType] = None,
    ):
        """
        :param Text uri: URI of the offloaded literal
        :param int size_bytes: Size in bytes of the offloaded literal proto
        :param LiteralType inferred_type: Inferred type of the offloaded literal
        """
        self._uri = uri
        self._size_bytes = size_bytes
        self._inferred_type = inferred_type

    @property
    def uri(self) -> Optional[str]:
        """
        URI of the offloaded literal, as given to the constructor.
        """
        return self._uri

    @property
    def size_bytes(self) -> Optional[int]:
        """
        Size in bytes of the offloaded literal proto, as given to the constructor.
        """
        return self._size_bytes

    @property
    def inferred_type(self) -> Optional[LiteralType]:
        """
        Inferred type of the offloaded literal, as given to the constructor.
        """
        return self._inferred_type

    def to_flyte_idl(self):
        """
        :rtype: flyteidl.core.literals_pb2.LiteralOffloadedMetadata
        """
        # Only convert the inferred type when one is present; otherwise the field is passed as None.
        inferred_type_pb = None
        if self.inferred_type:
            inferred_type_pb = self.inferred_type.to_flyte_idl()

        return _literals_pb2.LiteralOffloadedMetadata(
            uri=self.uri,
            size_bytes=self.size_bytes,
            inferred_type=inferred_type_pb,
        )

    @classmethod
    def from_flyte_idl(cls, pb2_object):
        """
        :param flyteidl.core.literals_pb2.LiteralOffloadedMetadata pb2_object:
        :rtype: LiteralOffloadedMetadata
        """
        # inferred_type is a message field, so presence is checked explicitly with HasField.
        inferred_type = None
        if pb2_object.HasField("inferred_type"):
            inferred_type = _LiteralType.from_flyte_idl(pb2_object.inferred_type)

        return cls(
            uri=pb2_object.uri,
            size_bytes=pb2_object.size_bytes,
            inferred_type=inferred_type,
        )


class Literal(_common.FlyteIdlEntity):
def __init__(
self,
Expand All @@ -860,6 +906,7 @@ def __init__(
map: Optional[LiteralMap] = None,
hash: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
offloaded_metadata: Optional[LiteralOffloadedMetadata] = None,
):
"""
This IDL message represents a literal value in the Flyte ecosystem.
Expand All @@ -873,6 +920,7 @@ def __init__(
self._map = map
self._hash = hash
self._metadata = metadata
self._offloaded_metadata = offloaded_metadata

@property
def scalar(self):
Expand Down Expand Up @@ -925,6 +973,13 @@ def metadata(self) -> Optional[Dict[str, str]]:
"""
return self._metadata

@property
def offloaded_metadata(self) -> Optional[LiteralOffloadedMetadata]:
    """
    Metadata about the offloaded literal (uri, size in bytes and inferred type), if any.

    :rtype: Optional[LiteralOffloadedMetadata]
    """
    return self._offloaded_metadata

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.literals_pb2.Literal
Expand All @@ -935,10 +990,11 @@ def to_flyte_idl(self):
map=self.map.to_flyte_idl() if self.map is not None else None,
hash=self.hash,
metadata=self.metadata,
offloaded_metadata=self.offloaded_metadata.to_flyte_idl() if self.offloaded_metadata else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
def from_flyte_idl(cls, pb2_object: _literals_pb2.Literal):
"""
:param flyteidl.core.literals_pb2.Literal pb2_object:
:rtype: Literal
Expand All @@ -953,6 +1009,9 @@ def from_flyte_idl(cls, pb2_object):
map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None,
hash=pb2_object.hash if pb2_object.hash else None,
metadata={k: v for k, v in pb2_object.metadata.items()} if pb2_object.metadata else None,
offloaded_metadata=LiteralOffloadedMetadata.from_flyte_idl(pb2_object.offloaded_metadata)
if pb2_object.HasField("offloaded_metadata")
else None,
)

def set_metadata(self, metadata: Dict[str, str]):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.13.1",
"flyteidl>=1.13.4",
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
179 changes: 179 additions & 0 deletions tests/flytekit/unit/core/test_offloaded_literals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from dataclasses import dataclass
import typing

from mashumaro.mixins.json import DataClassJSONMixin
import pytest
from flytekit import task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.models import literals as literal_models
from flytekit.core import context_manager
from flytekit.models.types import SimpleType
from flytekit.core.type_engine import TypeEngine

@pytest.fixture
def flyte_ctx():
    """Provide the value returned by ``FlyteContext.current_context()`` to tests that request it."""
    return context_manager.FlyteContext.current_context()


def test_task_offloaded_literal_single_input(tmp_path):
    """A task given an offloaded int literal returns the result computed from the value stored on disk."""

    @task
    def t1(a: int) -> str:
        return str(a)

    # Serialize the real input literal to a file under tmp_path.
    proto_path = f"{tmp_path}/offloaded_proto.pb"
    integer_literal = literal_models.Literal(
        scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3))
    )
    with open(proto_path, "wb") as proto_file:
        proto_file.write(integer_literal.to_flyte_idl().SerializeToString())

    # The literal handed to the task carries only offloaded metadata pointing at that file.
    offloaded_metadata = literal_models.LiteralOffloadedMetadata(
        uri=proto_path,
        inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER),
    )
    task_inputs = literal_models.LiteralMap(
        literals={"a": literal_models.Literal(offloaded_metadata=offloaded_metadata)},
    )

    current_ctx = context_manager.FlyteContextManager.current_context()
    output_lm = t1.dispatch_execute(current_ctx, task_inputs)

    assert output_lm.literals["o0"].scalar.primitive.string_value == "3"


def test_task_offloaded_literal_multiple_input(tmp_path):
    """A task given two offloaded int literals returns the sum of the values stored on disk."""

    @task
    def t1(a: int, b: int) -> int:
        return a + b

    offloaded_inputs = {}
    for input_name, input_value in (("a", 3), ("b", 4)):
        # Serialize each real input literal to its own file under tmp_path.
        proto_path = f"{tmp_path}/offloaded_proto_{input_name}.pb"
        integer_literal = literal_models.Literal(
            scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=input_value))
        )
        with open(proto_path, "wb") as proto_file:
            proto_file.write(integer_literal.to_flyte_idl().SerializeToString())

        # The task input is a literal that only references the file through offloaded metadata.
        offloaded_inputs[input_name] = literal_models.Literal(
            offloaded_metadata=literal_models.LiteralOffloadedMetadata(
                uri=proto_path,
                inferred_type=literal_models.LiteralType(simple=SimpleType.INTEGER),
            )
        )

    current_ctx = context_manager.FlyteContextManager.current_context()
    output_lm = t1.dispatch_execute(
        current_ctx,
        literal_models.LiteralMap(literals=offloaded_inputs),
    )

    assert output_lm.literals["o0"].scalar.primitive.integer == 7


def test_task_offloaded_literal_single_dataclass(tmp_path, flyte_ctx):
    """A task that echoes an offloaded dataclass input outputs a literal equal to the one stored on disk."""

    @dataclass
    class DC(DataClassJSONMixin):
        x: int
        y: str
        z: typing.List[int]

    @task
    def t1(dc: DC) -> DC:
        return dc

    # Build the real input literal through the type engine.
    dc_literal_type = TypeEngine.to_literal_type(DC)
    dc_value = DC(x=3, y="hello", z=[1, 2, 3])
    original_input_literal = TypeEngine.to_literal(flyte_ctx, dc_value, DC, dc_literal_type)

    # Serialize it to a file under tmp_path.
    proto_path = f"{tmp_path}/offloaded_proto.pb"
    with open(proto_path, "wb") as proto_file:
        proto_file.write(original_input_literal.to_flyte_idl().SerializeToString())

    # The task input only references the serialized literal through offloaded metadata.
    offloaded_metadata = literal_models.LiteralOffloadedMetadata(
        uri=proto_path,
        inferred_type=dc_literal_type,
    )
    task_inputs = literal_models.LiteralMap(
        literals={"dc": literal_models.Literal(offloaded_metadata=offloaded_metadata)},
    )

    current_ctx = context_manager.FlyteContextManager.current_context()
    output_lm = t1.dispatch_execute(current_ctx, task_inputs)

    assert output_lm.literals["o0"] == original_input_literal


def test_task_offloaded_literal_list_int(tmp_path):
    """A task given an offloaded list-of-int literal returns the stringified elements stored on disk."""

    @task
    def t1(xs: typing.List[int]) -> typing.List[str]:
        return [str(a) for a in xs]

    original_input_literal = literal_models.Literal(
        collection=literal_models.LiteralCollection(
            literals=[
                literal_models.Literal(
                    scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=3))
                ),
                literal_models.Literal(
                    scalar=literal_models.Scalar(primitive=literal_models.Primitive(integer=4))
                ),
            ]
        )
    )
    expected_output_literal = literal_models.Literal(
        collection=literal_models.LiteralCollection(
            literals=[
                literal_models.Literal(
                    scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="3"))
                ),
                literal_models.Literal(
                    scalar=literal_models.Scalar(primitive=literal_models.Primitive(string_value="4"))
                ),
            ]
        )
    )

    # Serialize the real input literal to a file under tmp_path.
    with open(f"{tmp_path}/offloaded_proto.pb", "wb") as f:
        f.write(original_input_literal.to_flyte_idl().SerializeToString())

    offloaded_input_literal = literal_models.Literal(
        offloaded_metadata=literal_models.LiteralOffloadedMetadata(
            uri=f"{tmp_path}/offloaded_proto.pb",
            # collection_type must itself be a LiteralType describing the element type;
            # passing the bare SimpleType enum yields a type that cannot be serialized.
            inferred_type=literal_models.LiteralType(
                collection_type=literal_models.LiteralType(simple=SimpleType.INTEGER)
            ),
        )
    )

    ctx = context_manager.FlyteContextManager.current_context()
    output_lm = t1.dispatch_execute(
        ctx,
        literal_models.LiteralMap(
            literals={
                "xs": offloaded_input_literal,
            }
        ),
    )
    assert output_lm.literals["o0"] == expected_output_literal
Loading

0 comments on commit 2dcbb90

Please sign in to comment.