Skip to content

Commit

Permalink
[BUG] Support creation and reading of StructuredDataset with local or remote uri (#2914)
Browse files Browse the repository at this point in the history

* Manually fill sd literal and metadata

Signed-off-by: JiaWei Jiang <[email protected]>

* Use StructuredDatasetTransformerEngine to set sd literal

Signed-off-by: JiaWei Jiang <[email protected]>

* Add unit test for reading sd from uri

Signed-off-by: JiaWei Jiang <[email protected]>

* Put tasks into wf to mimic real-world use cases

Signed-off-by: JiaWei Jiang <[email protected]>

* add env

Signed-off-by: Future-Outlier <[email protected]>

* env again

Signed-off-by: Future-Outlier <[email protected]>

* Modify python ff reading logic to a flyte task

Signed-off-by: JiaWei Jiang <[email protected]>

* Use task param instead of global path const

Signed-off-by: JiaWei Jiang <[email protected]>

* Remove unit tests that need to interact with s3

Signed-off-by: JiaWei Jiang <[email protected]>

---------

Signed-off-by: JiaWei Jiang <[email protected]>
Signed-off-by: Future-Outlier <[email protected]>
Co-authored-by: Future-Outlier <[email protected]>
  • Loading branch information
JiangJiaWei1103 and Future-Outlier authored Nov 19, 2024
1 parent faee3da commit b04bc8d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
25 changes: 25 additions & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from flytekit.models import types as type_models
from flytekit.models.literals import Binary, Literal, Scalar, StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType
from flytekit.utils.asyn import loop_manager

if typing.TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -176,8 +177,32 @@ def all(self) -> DF: # type: ignore
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
ctx = FlyteContextManager.current_context()

if self.uri is not None and self.dataframe is None:
expected = TypeEngine.to_literal_type(StructuredDataset)
self._set_literal(ctx, expected)

return flyte_dataset_transformer.open_as(ctx, self.literal, self._dataframe_type, self.metadata)

def _set_literal(self, ctx: FlyteContext, expected: LiteralType) -> None:
"""
Explicitly set the StructuredDataset Literal to handle the following cases:
1. Read a dataframe from a StructuredDataset with an uri, for example:
@task
def return_sd() -> StructuredDataset:
sd = StructuredDataset(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
df = sd.open(pd.DataFrame).all()
return df
For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5954.
"""
to_literal = loop_manager.synced(flyte_dataset_transformer.async_to_literal)
self._literal_sd = to_literal(ctx, self, StructuredDataset, expected).scalar.structured_dataset
if self.metadata is None:
self._metadata = self._literal_sd.metadata

def iter(self) -> Generator[DF, None, None]:
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import tempfile
import typing
from collections import OrderedDict
from pathlib import Path

import google.cloud.bigquery
import pytest
Expand All @@ -21,6 +22,7 @@
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType
from flytekit.tools.translator import get_serializable
from flytekit.types.file import FlyteFile
from flytekit.types.structured.structured_dataset import (
PARQUET,
StructuredDataset,
Expand Down Expand Up @@ -59,6 +61,21 @@ def generate_pandas() -> pd.DataFrame:
return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


@pytest.fixture
def local_tmp_pqt_file():
    """Yield the path of a temporary parquet file holding the sample dataframe."""
    frame = generate_pandas()

    # Write the dataframe to a temp parquet file; delete=False so the path
    # outlives the context manager and can be handed to the test.
    with tempfile.NamedTemporaryFile(delete=False, mode="w+b", suffix=".parquet") as tmp:
        parquet_path = tmp.name
        frame.to_parquet(parquet_path)

    yield parquet_path

    # Teardown: remove the file if the test did not already.
    Path(parquet_path).unlink(missing_ok=True)


def test_formats_make_sense():
@task
def t1(a: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -643,3 +660,27 @@ def wf_with_input() -> pd.DataFrame:

pd.testing.assert_frame_equal(wf_no_input(), default_val)
pd.testing.assert_frame_equal(wf_with_input(), input_val)



def test_read_sd_from_local_uri(local_tmp_pqt_file):
    """A StructuredDataset built from a local uri can be opened and fully read."""

    @task
    def read_sd_from_uri(uri: str) -> pd.DataFrame:
        # Point a StructuredDataset at an existing parquet file and load it.
        sd = StructuredDataset(uri=uri, file_format="parquet")
        return sd.open(pd.DataFrame).all()

    @workflow
    def read_sd_from_local_uri(uri: str) -> pd.DataFrame:
        return read_sd_from_uri(uri=uri)

    expected = generate_pandas()

    # Reading through the workflow should reproduce the original dataframe.
    actual = read_sd_from_local_uri(uri=local_tmp_pqt_file)
    pd.testing.assert_frame_equal(expected, actual)

0 comments on commit b04bc8d

Please sign in to comment.