Skip to content

Commit

Permalink
feat(dc): add explode function (#581)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Nov 16, 2024
1 parent 059241a commit 1fe4891
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/datachain/lib/dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,59 @@ def jmespath_to_name(s: str):
}
return chain.gen(**signal_dict) # type: ignore[misc, arg-type]

def explode(
    self,
    col: str,
    model_name: Optional[str] = None,
    object_name: Optional[str] = None,
) -> "DataChain":
    """Explode a column of JSON objects into a single typed object column.

    The column may hold either dicts or strings containing JSON. The schema is
    inferred from the first row of the column only, so later rows are expected
    to follow the same structure; each row is validated against the generated
    model when the chain is evaluated.

    Args:
        col: the name of the column containing JSON to be exploded.
        model_name: optional generated model name. By default generates the
            name automatically as `f"{col.title()}ExplodedModel"`.
        object_name: optional generated object column name. By default
            generates the name automatically as `f"{col}_expl"`.

    Returns:
        DataChain: A new DataChain instance with the new set of columns.

    Raises:
        ValueError: if the chain has no rows, so no schema can be inferred.
        TypeError: if the first value of `col` is neither a dict nor a string
            that decodes to a JSON object.
    """
    import json

    # A private sentinel (rather than None) distinguishes "no rows at all" from
    # "first row holds a null value"; the latter must still raise TypeError.
    no_rows = object()
    json_value = next(self.limit(1).collect(col), no_rows)
    if json_value is no_rows:
        # A bare next() would leak StopIteration here, which is cryptic for
        # callers and turns into RuntimeError inside generators (PEP 479).
        raise ValueError(
            f"Cannot explode column {col}: the chain has no rows to infer "
            "the JSON schema from"
        )

    json_dict = (
        json.loads(json_value) if isinstance(json_value, str) else json_value
    )

    if not isinstance(json_dict, dict):
        raise TypeError(f"Column {col} should be a string or dict type with JSON")

    # Third-party imports are deferred until the input has been validated so
    # that bad input fails fast without loading them.
    import pyarrow as pa

    from datachain.lib.arrow import schema_to_output

    # Let pyarrow infer the field types from the single sample row, then
    # convert that schema into the output spec used to build the data model.
    schema = pa.Table.from_pylist([json_dict]).schema
    output = schema_to_output(schema, None)

    if not model_name:
        model_name = f"{col.title()}ExplodedModel"

    model = dict_to_data_model(model_name, output)

    def json_to_model(json_value: Union[str, dict]):
        # Mirrors the sample handling above: decode strings, pass dicts through.
        json_dict = (
            json.loads(json_value) if isinstance(json_value, str) else json_value
        )
        return model.model_validate(json_dict)

    if not object_name:
        object_name = f"{col}_expl"

    return self.map(json_to_model, params=col, output={object_name: model})

@classmethod
def datasets(
cls,
Expand Down
42 changes: 42 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,48 @@ def test_to_csv_features_nested(tmp_dir, test_session):
]


@pytest.mark.parametrize("column_type", (str, dict))
@pytest.mark.parametrize("object_name", (None, "test_object_name"))
@pytest.mark.parametrize("model_name", (None, "TestModelNameExploded"))
def test_explode(tmp_dir, test_session, column_type, object_name, model_name):
    """Check `explode` for str/dict columns with default and custom names."""
    # Write the sample frame as JSON lines: one JSON object per line.
    json_path = tmp_dir / "test.json"
    pd.DataFrame(DF_DATA).to_json(json_path, orient="records", lines=True)

    # Emit every non-empty line of the file as a `content` row, then explode it.
    source = DataChain.from_storage(json_path.as_uri(), session=test_session)
    lines = source.gen(
        content=lambda file: (row for row in file.read_text().split("\n") if row),
        output=column_type,
    )
    exploded = lines.explode(
        "content", object_name=object_name, model_name=model_name
    )

    # Defaults that `explode` derives from the column name "content".
    expected_object = object_name if object_name else "content_expl"
    expected_model = model_name if model_name else "ContentExplodedModel"

    expected_rows = {
        ("Alice", 25, "New York"),
        ("Bob", 30, "Los Angeles"),
        ("Charlie", 35, "Chicago"),
        ("David", 40, "Houston"),
        ("Eva", 45, "Phoenix"),
    }
    columns = [
        f"{expected_object}.first_name",
        f"{expected_object}.age",
        f"{expected_object}.city",
    ]
    assert set(exploded.collect(*columns)) == expected_rows

    # The generated object must be an instance of the (possibly custom) model.
    first_object = next(exploded.limit(1).collect(expected_object))
    assert first_object.__class__.__name__ == expected_model


def test_explode_raises_on_wrong_column_type(test_session):
    """`explode` must reject the `f1.count` column with a TypeError."""
    chain = DataChain.from_values(f1=features, session=test_session)

    with pytest.raises(TypeError):
        chain.explode("f1.count")


# These deprecation warnings occur in the datamodel-code-generator package.
@pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20")
def test_to_from_json(tmp_dir, test_session):
Expand Down

0 comments on commit 1fe4891

Please sign in to comment.