From 659abb96a9c344020531a6ea905d021ac1b3006b Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Fri, 29 Nov 2024 08:41:29 +0100 Subject: [PATCH] Make sure that migrations use correct dtypes --- src/ert/storage/migration/to8.py | 70 ++++++++------ .../storage/test_storage_migration.py | 96 +++++++++++++++++++ 2 files changed, 139 insertions(+), 27 deletions(-) diff --git a/src/ert/storage/migration/to8.py b/src/ert/storage/migration/to8.py index b6803e3a976..4de17545c17 100644 --- a/src/ert/storage/migration/to8.py +++ b/src/ert/storage/migration/to8.py @@ -22,7 +22,22 @@ def from_path(cls, path: Path) -> "ObservationDatasetInfo": response_key = ds.attrs["response"] response_type = "summary" if response_key == "summary" else "gen_data" - df = polars.from_pandas(ds.to_dataframe().dropna().reset_index()) + df = polars.from_pandas( + ds.to_dataframe().dropna().reset_index(), + schema_overrides={ + "report_step": polars.UInt16, + "index": polars.UInt16, + "observations": polars.Float32, + "std": polars.Float32, + } + if response_type == "gen_data" + else { + "time": polars.Datetime("ms"), # type: ignore + "observations": polars.Float32, + "std": polars.Float32, + }, + ) + df = df.with_columns(observation_key=polars.lit(observation_key)) primary_key = ( @@ -30,22 +45,12 @@ def from_path(cls, path: Path) -> "ObservationDatasetInfo": ) if response_type == "summary": df = df.rename({"name": "response_key"}) - df = df.with_columns(polars.col("time").dt.cast_time_unit("ms")) if response_type == "gen_data": df = df.with_columns( - polars.col("report_step").cast(polars.UInt16), - polars.col("index").cast(polars.UInt16), response_key=polars.lit(response_key), ) - df = df.with_columns( - [ - polars.col("std").cast(polars.Float32), - polars.col("observations").cast(polars.Float32), - ] - ) - df = df[ ["response_key", "observation_key", *primary_key, "observations", "std"] ] @@ -71,27 +76,38 @@ def _migrate_responses_from_netcdf_to_parquet(path: Path) -> None: real_dirs = [*ens.glob("realization-*")] for real_dir in real_dirs: - for ds_name in ["gen_data", "summary"]: - if (real_dir / f"{ds_name}.nc").exists(): - gen_data_ds = xr.open_dataset( - real_dir / f"{ds_name}.nc", engine="scipy" + for response_type, schema_overrides in [ + ( + "gen_data", + { + "realization": polars.UInt16, + "report_step": polars.UInt16, + "index": polars.UInt16, + "values": polars.Float32, + }, + ), + ( + "summary", + { + "realization": polars.UInt16, + "time": polars.Datetime("ms"), + "values": polars.Float32, + }, + ), + ]: + if (real_dir / f"{response_type}.nc").exists(): + xr_ds = xr.open_dataset( + real_dir / f"{response_type}.nc", + engine="scipy", ) - pandas_df = gen_data_ds.to_dataframe().dropna().reset_index() + pandas_df = xr_ds.to_dataframe().dropna().reset_index() polars_df = polars.from_pandas( pandas_df, - schema_overrides={ - "values": polars.Float32, - "realization": polars.UInt16, - }, + schema_overrides=schema_overrides, # type: ignore ) polars_df = polars_df.rename({"name": "response_key"}) - if "time" in polars_df: - polars_df = polars_df.with_columns( - polars.col("time").dt.cast_time_unit("ms") - ) - # Ensure "response_key" is the first column polars_df = polars_df.select( ["response_key"] @@ -101,9 +117,9 @@ def _migrate_responses_from_netcdf_to_parquet(path: Path) -> None: if col != "response_key" ] ) - polars_df.write_parquet(real_dir / f"{ds_name}.parquet") + polars_df.write_parquet(real_dir / f"{response_type}.parquet") - os.remove(real_dir / f"{ds_name}.nc") + os.remove(real_dir / f"{response_type}.nc") def _migrate_observations_to_grouped_parquet(path: Path) -> None: diff --git a/tests/ert/unit_tests/storage/test_storage_migration.py b/tests/ert/unit_tests/storage/test_storage_migration.py index 6c7e8b4854c..d699cba141a 100644 --- a/tests/ert/unit_tests/storage/test_storage_migration.py +++ b/tests/ert/unit_tests/storage/test_storage_migration.py @@ -5,9 +5,11 @@ from pathlib import Path import numpy as np +import polars import pytest from packaging import version +from ert.analysis import ErtAnalysisError, smoother_update from ert.config import ErtConfig from ert.storage import open_storage from ert.storage.local_storage import ( @@ -355,3 +357,97 @@ def test_that_migrate_blockfs_creates_backup_folder(tmp_path, caplog): assert ( tmp_path / "storage" / "_blockfs_backup" / "ensembles" / "ens_dummy.txt" ).exists() + + +@pytest.mark.integration_test +@pytest.mark.usefixtures("copy_shared") +@pytest.mark.parametrize( + "ert_version", + [ + "10.3.1", + "8.4.5", + "8.0.11", + "6.0.5", + "5.0.0", + ], +) +def test_that_manual_update_from_migrated_storage_works( + tmp_path, + block_storage_path, + snapshot, + monkeypatch, + ert_version, +): + shutil.copytree( + block_storage_path / f"all_data_types/storage-{ert_version}", + tmp_path / "all_data_types" / f"storage-{ert_version}", + ) + monkeypatch.chdir(tmp_path / "all_data_types") + ert_config = ErtConfig.with_plugins().from_file("config.ert") + local_storage_set_ert_config(ert_config) + # To make sure all tests run against the same snapshot + snapshot.snapshot_dir = snapshot.snapshot_dir.parent + with open_storage(f"storage-{ert_version}", "w") as storage: + experiments = list(storage.experiments) + assert len(experiments) == 1 + experiment = experiments[0] + ensembles = list(experiment.ensembles) + assert len(ensembles) == 1 + prior_ens = ensembles[0] + + assert set(experiment.observations["gen_data"].schema.items()) == { + ("index", polars.UInt16), + ("observation_key", polars.String), + ("observations", polars.Float32), + ("report_step", polars.UInt16), + ("response_key", polars.String), + ("std", polars.Float32), + } + + assert set(experiment.observations["summary"].schema.items()) == { + ("observation_key", polars.String), + ("observations", polars.Float32), + ("response_key", polars.String), + ("std", polars.Float32), + ("time", polars.Datetime(time_unit="ms")), + } + + prior_gendata = prior_ens.load_responses( + "gen_data", tuple(range(prior_ens.ensemble_size)) + ) + prior_smry = prior_ens.load_responses( + "summary", tuple(range(prior_ens.ensemble_size)) + ) + + assert set(prior_gendata.schema.items()) == { + ("response_key", polars.String), + ("index", polars.UInt16), + ("realization", polars.UInt16), + ("report_step", polars.UInt16), + ("values", polars.Float32), + } + + assert set(prior_smry.schema.items()) == { + ("response_key", polars.String), + ("time", polars.Datetime(time_unit="ms")), + ("realization", polars.UInt16), + ("values", polars.Float32), + } + + posterior_ens = storage.create_ensemble( + prior_ens.experiment_id, + ensemble_size=prior_ens.ensemble_size, + iteration=1, + name="posterior", + prior_ensemble=prior_ens, + ) + + with pytest.raises( + ErtAnalysisError, match="No active observations for update step" + ): + smoother_update( + prior_ens, + posterior_ens, + list(experiment.observation_keys), + list(ert_config.ensemble_config.parameters), + )