diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index 4491dbe4..5e04c349 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -642,6 +642,59 @@ def jmespath_to_name(s: str): } return chain.gen(**signal_dict) # type: ignore[misc, arg-type] + def explode( + self, + col: str, + model_name: Optional[str] = None, + object_name: Optional[str] = None, + ) -> "DataChain": + """Explodes a column containing JSON objects (dict or str DataChain type) into + individual columns based on the schema of the JSON. Schema is inferred from + the first row of the column. + + Args: + col: the name of the column containing JSON to be exploded. + model_name: optional generated model name. By default generates the name + automatically. + object_name: optional generated object column name. By default generates the + name automatically. + + Returns: + DataChain: A new DataChain instance with the new set of columns. + """ + import json + + import pyarrow as pa + + from datachain.lib.arrow import schema_to_output + + json_value = next(self.limit(1).collect(col)) + json_dict = ( + json.loads(json_value) if isinstance(json_value, str) else json_value + ) + + if not isinstance(json_dict, dict): + raise TypeError(f"Column {col} should be a string or dict type with JSON") + + schema = pa.Table.from_pylist([json_dict]).schema + output = schema_to_output(schema, None) + + if not model_name: + model_name = f"{col.title()}ExplodedModel" + + model = dict_to_data_model(model_name, output) + + def json_to_model(json_value: Union[str, dict]): + json_dict = ( + json.loads(json_value) if isinstance(json_value, str) else json_value + ) + return model.model_validate(json_dict) + + if not object_name: + object_name = f"{col}_expl" + + return self.map(json_to_model, params=col, output={object_name: model}) + @classmethod def datasets( cls, diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index ecae75dd..f07d69dc 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -1313,6 +1313,48 @@ def test_to_csv_features_nested(tmp_dir, test_session): ] +@pytest.mark.parametrize("column_type", (str, dict)) +@pytest.mark.parametrize("object_name", (None, "test_object_name")) +@pytest.mark.parametrize("model_name", (None, "TestModelNameExploded")) +def test_explode(tmp_dir, test_session, column_type, object_name, model_name): + df = pd.DataFrame(DF_DATA) + path = tmp_dir / "test.json" + df.to_json(path, orient="records", lines=True) + + dc = ( + DataChain.from_storage(path.as_uri(), session=test_session) + .gen( + content=lambda file: (ln for ln in file.read_text().split("\n") if ln), + output=column_type, + ) + .explode("content", object_name=object_name, model_name=model_name) + ) + + object_name = object_name or "content_expl" + model_name = model_name or "ContentExplodedModel" + + assert set( + dc.collect( + f"{object_name}.first_name", f"{object_name}.age", f"{object_name}.city" + ) + ) == { + ("Alice", 25, "New York"), + ("Bob", 30, "Los Angeles"), + ("Charlie", 35, "Chicago"), + ("David", 40, "Houston"), + ("Eva", 45, "Phoenix"), + } + + assert next(dc.limit(1).collect(object_name)).__class__.__name__ == model_name + + +def test_explode_raises_on_wrong_column_type(test_session): + dc = DataChain.from_values(f1=features, session=test_session) + + with pytest.raises(TypeError): + dc.explode("f1.count") + + # These deprecation warnings occur in the datamodel-code-generator package. @pytest.mark.filterwarnings("ignore::pydantic.warnings.PydanticDeprecatedSince20") def test_to_from_json(tmp_dir, test_session):