Skip to content

Commit

Permalink
feat(dc): make expand work if some fields are optional
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein committed Nov 22, 2024
1 parent 1fe4891 commit 639f443
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 48 deletions.
36 changes: 24 additions & 12 deletions src/datachain/lib/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,43 @@ def infer_schema(chain: "DataChain", **kwargs) -> pa.Schema:
return pa.unify_schemas(schemas)


def schema_to_output(schema: pa.Schema, col_names: Optional[Sequence[str]] = None):
"""Generate UDF output schema from pyarrow schema."""
def schema_to_output(
schema: pa.Schema, col_names: Optional[Sequence[str]] = None
) -> tuple[dict[str, type], list[str]]:
"""
Generate UDF output schema from pyarrow schema.
Returns a tuple of output schema and original column names (since they may be
normalized in the output dict).
"""
signal_schema = _get_datachain_schema(schema)
if signal_schema:
return signal_schema.values, list(signal_schema.values.keys())

if col_names and (len(schema) != len(col_names)):
raise ValueError(
"Error generating output from Arrow schema - "
f"Schema has {len(schema)} columns but got {len(col_names)} column names."
)
if not col_names:
col_names = schema.names
signal_schema = _get_datachain_schema(schema)
if signal_schema:
return signal_schema.values
columns = list(normalize_col_names(col_names).keys()) # type: ignore[arg-type]
col_names = schema.names or []

normalized_col_dict = normalize_col_names(col_names)
col_names = list(normalized_col_dict.keys())

hf_schema = _get_hf_schema(schema)
if hf_schema:
return {
column: hf_type for hf_type, column in zip(hf_schema[1].values(), columns)
}
column: hf_type for hf_type, column in zip(hf_schema[1].values(), col_names)
}, list(normalized_col_dict.values())

output = {}
for field, column in zip(schema, columns):
dtype = arrow_type_mapper(field.type, column) # type: ignore[assignment]
for field, column in zip(schema, col_names):
dtype = arrow_type_mapper(field.type, column)
if field.nullable and not ModelStore.is_pydantic(dtype):
dtype = Optional[dtype] # type: ignore[assignment]
output[column] = dtype
return output

return output, list(normalized_col_dict.values())


def arrow_type_mapper(col_type: pa.DataType, column: str = "") -> type: # noqa: PLR0911
Expand Down
34 changes: 25 additions & 9 deletions src/datachain/lib/data_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections.abc import Sequence
from datetime import datetime
from typing import ClassVar, Union, get_args, get_origin
from typing import ClassVar, Optional, Union, get_args, get_origin

from pydantic import BaseModel, Field, create_model
from pydantic import AliasChoices, BaseModel, Field, create_model

from datachain.lib.model_store import ModelStore
from datachain.lib.utils import normalize_col_names
Expand Down Expand Up @@ -60,17 +60,33 @@ def is_chain_type(t: type) -> bool:
return False


def dict_to_data_model(name: str, data_dict: dict[str, DataType]) -> type[BaseModel]:
# Gets a map of a normalized_name -> original_name
columns = normalize_col_names(list(data_dict.keys()))
# We reverse if for convenience to original_name -> normalized_name
columns = {v: k for k, v in columns.items()}
def dict_to_data_model(
name: str,
data_dict: dict[str, DataType],
original_names: Optional[list[str]] = None,
) -> type[BaseModel]:
if not original_names:
# Gets a map of a normalized_name -> original_name
columns = normalize_col_names(list(data_dict))
data_dict = dict(zip(columns.keys(), data_dict.values()))
original_names = list(columns.values())

fields = {
columns[name]: (anno, Field(alias=name)) for name, anno in data_dict.items()
name: (
anno,
Field(
validation_alias=AliasChoices(name, original_names[idx] or name),
default=None,
),
)
for idx, (name, anno) in enumerate(data_dict.items())
}

class _DataModelStrict(BaseModel, extra="forbid"):
pass

return create_model(
name,
__base__=(DataModel,), # type: ignore[call-overload]
__base__=_DataModelStrict,
**fields,
) # type: ignore[call-overload]
25 changes: 16 additions & 9 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ def explode(
col: str,
model_name: Optional[str] = None,
object_name: Optional[str] = None,
schema_sample_size: int = 1,
) -> "DataChain":
"""Explodes a column containing JSON objects (dict or str DataChain type) into
individual columns based on the schema of the JSON. Schema is inferred from
Expand All @@ -658,6 +659,9 @@ def explode(
automatically.
object_name: optional generated object column name. By default generates the
name automatically.
schema_sample_size: the number of rows to use for inferring the schema of
the JSON (in case some fields are optional and it's not enough to
analyze a single row).
Returns:
DataChain: A new DataChain instance with the new set of columns.
Expand All @@ -668,21 +672,22 @@ def explode(

from datachain.lib.arrow import schema_to_output

json_value = next(self.limit(1).collect(col))
json_dict = (
json_values = list(self.limit(schema_sample_size).collect(col))
json_dicts = [
json.loads(json_value) if isinstance(json_value, str) else json_value
)
for json_value in json_values
]

if not isinstance(json_dict, dict):
if any(not isinstance(json_dict, dict) for json_dict in json_dicts):
raise TypeError(f"Column {col} should be a string or dict type with JSON")

schema = pa.Table.from_pylist([json_dict]).schema
output = schema_to_output(schema, None)
schema = pa.Table.from_pylist(json_dicts).schema
output, original_names = schema_to_output(schema, None)

if not model_name:
model_name = f"{col.title()}ExplodedModel"

model = dict_to_data_model(model_name, output)
model = dict_to_data_model(model_name, output, original_names)

def json_to_model(json_value: Union[str, dict]):
json_dict = (
Expand Down Expand Up @@ -775,7 +780,7 @@ def print_json_schema( # type: ignore[override]
```py
uri = "gs://datachain-demo/coco2017/annotations_captions/"
chain = DataChain.from_storage(uri)
chain = chain.show_json_schema()
chain = chain.print_json_schema()
chain.save()
```
"""
Expand Down Expand Up @@ -1829,13 +1834,14 @@ def parse_tabular(
if col_names or not output:
try:
schema = infer_schema(self, **kwargs)
output = schema_to_output(schema, col_names)
output, _ = schema_to_output(schema, col_names)
except ValueError as e:
raise DatasetPrepareError(self.name, e) from e

if isinstance(output, dict):
model_name = model_name or object_name or ""
model = dict_to_data_model(model_name, output)
output = model
else:
model = output # type: ignore[assignment]

Expand All @@ -1846,6 +1852,7 @@ def parse_tabular(
name: info.annotation # type: ignore[misc]
for name, info in output.model_fields.items()
}

if source:
output = {"source": ArrowRow} | output # type: ignore[assignment,operator]
return self.gen(
Expand Down
1 change: 1 addition & 0 deletions src/datachain/lib/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, col_name, msg):


def normalize_col_names(col_names: Sequence[str]) -> dict[str, str]:
"""Returns normalized_name -> original_name dict."""
gen_col_counter = 0
new_col_names = {}
org_col_names = set(col_names)
Expand Down
40 changes: 34 additions & 6 deletions tests/unit/lib/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ def test_arrow_generator_output_schema(tmp_path, catalog):
stream = File(path=pq_path.as_posix(), source="file://")
stream._set_stream(catalog, caching_enabled=False)

output_schema = dict_to_data_model("", schema_to_output(table.schema))
output, original_names = schema_to_output(table.schema)
output_schema = dict_to_data_model("", output, original_names)
func = ArrowGenerator(output_schema=output_schema)
objs = list(func.process(stream))

Expand All @@ -97,7 +98,9 @@ def test_arrow_generator_hf(tmp_path, catalog):
stream = File(path=pq_path.as_posix(), source="file:///")
stream._set_stream(catalog, caching_enabled=False)

output_schema = dict_to_data_model("", schema_to_output(ds._data.schema, ["col"]))
output, original_names = schema_to_output(ds._data.schema, ["col"])

output_schema = dict_to_data_model("", output, original_names)
func = ArrowGenerator(output_schema=output_schema)
for obj in func.process(stream):
assert isinstance(obj[1].col, HFClassLabel)
Expand Down Expand Up @@ -154,7 +157,11 @@ def test_schema_to_output():
("strict_int", pa.int32(), False),
]
)
assert schema_to_output(schema) == {

output, original_names = schema_to_output(schema)

assert original_names == ["some_int", "some_string", "strict_int"]
assert output == {
"some_int": Optional[int],
"some_string": Optional[str],
"strict_int": int,
Expand All @@ -174,7 +181,20 @@ def test_parquet_convert_column_names():
("trailing__underscores__", pa.int32()),
]
)
assert list(schema_to_output(schema)) == [

output, original_names = schema_to_output(schema)

assert original_names == [
"UpperCaseCol",
"dot.notation.col",
"with-dashes",
"with spaces",
"with-multiple--dashes",
"with__underscores",
"__leading__underscores",
"trailing__underscores__",
]
assert list(output) == [
"uppercasecol",
"dot_notation_col",
"with_dashes",
Expand All @@ -193,13 +213,21 @@ def test_parquet_missing_column_names():
("", pa.int32()),
]
)
assert list(schema_to_output(schema)) == ["c0", "c1"]

output, original_names = schema_to_output(schema)

assert original_names == ["", ""]
assert list(output) == ["c0", "c1"]


def test_parquet_override_column_names():
schema = pa.schema([("some_int", pa.int32()), ("some_string", pa.string())])
col_names = ["n1", "n2"]
assert schema_to_output(schema, col_names) == {

output, original_names = schema_to_output(schema, col_names)

assert original_names == ["n1", "n2"]
assert output == {
"n1": Optional[int],
"n2": Optional[str],
}
Expand Down
33 changes: 21 additions & 12 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,16 @@
}

DF_DATA_NESTED_NOT_NORMALIZED = {
"nAmE": [
"nA-mE": [
{"first-SELECT": "Ivan"},
{"first-SELECT": "Alice", "l--as@t": "Smith"},
{"l--as@t": "Jones", "first-SELECT": "Bob"},
{"first-SELECT": "Charlie", "l--as@t": "Brown"},
{"first-SELECT": "David", "l--as@t": "White"},
{"first-SELECT": "Eva", "l--as@t": "Black"},
],
"AgE": [25, 30, 35, 40, 45],
"citY": ["New York", "Los Angeles", "Chicago", "Houston", "Phoenix"],
"AgE": [41, 25, 30, 35, 40, 45],
"citY": ["San Francisco", "New York", "Los Angeles", None, "Houston", "Phoenix"],
}

DF_OTHER_DATA = {
Expand Down Expand Up @@ -1011,13 +1012,13 @@ def test_parse_nested_json(tmp_dir, test_session):
)
# Field names are normalized, values are preserved
# E.g. nAmE -> name, l--as@t -> l_as_t, etc
df1 = dc.select("name", "age", "city").to_pandas()
df1 = dc.select("na_me", "age", "city").to_pandas()

assert df1["name"]["first_select"].to_list() == [
d["first-SELECT"] for d in df["nAmE"].to_list()
assert df1["na_me"]["first_select"].to_list() == [
d["first-SELECT"] for d in df["nA-mE"].to_list()
]
assert df1["name"]["l_as_t"].to_list() == [
d["l--as@t"] for d in df["nAmE"].to_list()
assert df1["na_me"]["l_as_t"].to_list() == [
d.get("l--as@t") for d in df["nA-mE"].to_list()
]


Expand Down Expand Up @@ -1317,7 +1318,7 @@ def test_to_csv_features_nested(tmp_dir, test_session):
@pytest.mark.parametrize("object_name", (None, "test_object_name"))
@pytest.mark.parametrize("model_name", (None, "TestModelNameExploded"))
def test_explode(tmp_dir, test_session, column_type, object_name, model_name):
df = pd.DataFrame(DF_DATA)
df = pd.DataFrame(DF_DATA_NESTED_NOT_NORMALIZED)
path = tmp_dir / "test.json"
df.to_json(path, orient="records", lines=True)

Expand All @@ -1327,22 +1328,30 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name):
content=lambda file: (ln for ln in file.read_text().split("\n") if ln),
output=column_type,
)
.explode("content", object_name=object_name, model_name=model_name)
.explode(
"content",
object_name=object_name,
model_name=model_name,
schema_sample_size=2,
)
)

object_name = object_name or "content_expl"
model_name = model_name or "ContentExplodedModel"

assert set(
dc.collect(
f"{object_name}.first_name", f"{object_name}.age", f"{object_name}.city"
f"{object_name}.na_me.first_select",
f"{object_name}.age",
f"{object_name}.city",
)
) == {
("Alice", 25, "New York"),
("Bob", 30, "Los Angeles"),
("Charlie", 35, "Chicago"),
("Charlie", 35, None),
("David", 40, "Houston"),
("Eva", 45, "Phoenix"),
("Ivan", 41, "San Francisco"),
}

assert next(dc.limit(1).collect(object_name)).__class__.__name__ == model_name
Expand Down

0 comments on commit 639f443

Please sign in to comment.