diff --git a/CHANGES.rst b/CHANGES.rst
index ab499ff8..b2e4f9b3 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -12,6 +12,7 @@ Improvements
 
 Bug fixes
 ^^^^^^^^^
 
+* GH248 Fix an issue causing a ValueError to be raised when using `dask_index_on` on non-integer columns
 * GH255 Fix an issue causing the python interpreter to shut down when reading an empty
   file (see also https://issues.apache.org/jira/browse/ARROW-8142)
diff --git a/asv_bench/benchmarks/index.py b/asv_bench/benchmarks/index.py
index 44b1eafd..7ee7b92c 100644
--- a/asv_bench/benchmarks/index.py
+++ b/asv_bench/benchmarks/index.py
@@ -80,6 +80,9 @@ def time_as_series_partitions_as_index(
     ):
         self.ktk_index.as_flat_series(partitions_as_index=True)
 
+    def time_observed_values(self, number_values, number_partitions, arrow_type):
+        self.ktk_index.observed_values()
+
 
 class SerializeIndex(IndexBase):
     timeout = 180
diff --git a/kartothek/core/index.py b/kartothek/core/index.py
index bdb68472..a1b6f34a 100644
--- a/kartothek/core/index.py
+++ b/kartothek/core/index.py
@@ -132,15 +132,13 @@ def __repr__(self) -> str:
             class_=type(self).__name__, attrs=", ".join(repr_str)
         )
 
-    def observed_values(self) -> np.array:
+    def observed_values(self, date_as_object=True) -> np.array:
         """
         Return an array of all observed values
         """
-        return np.fromiter(
-            (self.normalize_value(self.dtype, x) for x in self.index_dct.keys()),
-            count=len(self.index_dct),
-            dtype=self.dtype.to_pandas_dtype(),
-        )
+        keys = np.array(list(self.index_dct.keys()))
+        labeled_array = pa.array(keys, type=self.dtype)
+        return np.array(labeled_array.to_pandas(date_as_object=date_as_object))
 
     @staticmethod
     def normalize_value(dtype: pa.DataType, value: Any) -> Any:
diff --git a/tests/core/test_index.py b/tests/core/test_index.py
index 75f48df8..c11cc1b7 100644
--- a/tests/core/test_index.py
+++ b/tests/core/test_index.py
@@ -14,6 +14,7 @@
 from hypothesis import assume, given
 from pandas.testing import assert_series_equal
 
+from kartothek.core._compat import ARROW_LARGER_EQ_0150
 from kartothek.core.index import ExplicitSecondaryIndex, IndexBase, merge_indices
 from kartothek.core.testing import get_numpy_array_strategy
 
@@ -472,6 +473,50 @@ def test_index_raises_null_dtype():
     assert str(exc.value) == "Indices w/ null/NA type are not supported"
 
 
+@pytest.mark.parametrize(
+    "dtype,value",
+    [
+        (pa.bool_(), True),
+        (pa.int64(), 1),
+        (pa.float64(), 1.1),
+        (pa.binary(), b"x"),
+        (pa.string(), "x"),
+        (pa.timestamp("ns"), pd.Timestamp("2018-01-01").to_datetime64()),
+        (pa.date32(), datetime.date(2018, 1, 1)),
+        pytest.param(
+            pa.timestamp("ns", tz=pytz.timezone("Europe/Berlin")),
+            pd.Timestamp("2018-01-01", tzinfo=pytz.timezone("Europe/Berlin")),
+            marks=pytest.mark.xfail(
+                not ARROW_LARGER_EQ_0150,
+                reason="Timezone roundtrips not supported in older versions",
+            ),
+        ),
+    ],
+)
+def test_observed_values_plain(dtype, value):
+    ind = ExplicitSecondaryIndex(
+        column="col", dtype=dtype, index_dct={value: ["part_label"]}
+    )
+    observed = ind.observed_values()
+    assert len(observed) == 1
+    assert list(observed) == [value]
+
+
+@pytest.mark.parametrize("date_as_object", [None, True, False])
+def test_observed_values_date_as_object(date_as_object):
+    value = datetime.date(2020, 1, 1)
+    ind = ExplicitSecondaryIndex(
+        column="col", dtype=pa.date32(), index_dct={value: ["part_label"]}
+    )
+    observed = ind.observed_values(date_as_object=date_as_object)
+    if date_as_object:
+        expected = value
+    else:
+        expected = pd.Timestamp(value).to_datetime64()
+    assert len(observed) == 1
+    assert observed[0] == expected
+
+
 @pytest.mark.parametrize(
     "dtype,value,expected",
     [
diff --git a/tests/io/dask/dataframe/test_read.py b/tests/io/dask/dataframe/test_read.py
index 752d7462..542b12a7 100644
--- a/tests/io/dask/dataframe/test_read.py
+++ b/tests/io/dask/dataframe/test_read.py
@@ -10,6 +10,7 @@
 from pandas import testing as pdt
 from pandas.testing import assert_frame_equal
 
+from kartothek.core.testing import get_dataframe_not_nested
 from kartothek.io.dask.dataframe import read_dataset_as_ddf
 from kartothek.io.eager import store_dataframes_as_dataset
 from kartothek.io.testing.read import *  # noqa
 
@@ -169,6 +170,34 @@ def test_reconstruct_dask_index(store_factory, index_type, monkeypatch):
     assert_frame_equal(ddf_expected_simple.compute(), ddf.compute())
 
 
+@pytest.fixture()
+def setup_reconstruct_dask_index_types(store_factory, df_not_nested):
+    indices = list(df_not_nested.columns)
+    indices.remove("null")
+    return store_dataframes_as_dataset(
+        store=store_factory,
+        dataset_uuid="dataset_uuid",
+        dfs=[df_not_nested],
+        secondary_indices=indices,
+    )
+
+
+@pytest.mark.parametrize("col", get_dataframe_not_nested().columns)
+def test_reconstruct_dask_index_types(
+    store_factory, setup_reconstruct_dask_index_types, col
+):
+    if col == "null":
+        pytest.xfail(reason="Cannot index null column")
+    ddf = read_dataset_as_ddf(
+        dataset_uuid=setup_reconstruct_dask_index_types.uuid,
+        store=store_factory,
+        table="table",
+        dask_index_on=col,
+    )
+    assert ddf.known_divisions
+    assert ddf.index.name == col
+
+
 def test_reconstruct_dask_index_sorting(store_factory, monkeypatch):
     # Make sure we're not shuffling anything
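
For context, a minimal sketch of the `observed_values` behavior introduced above, derived from the tests in this diff; the `column` name and `index_dct` contents are illustrative, not part of any dataset:

    import datetime

    import pyarrow as pa

    from kartothek.core.index import ExplicitSecondaryIndex

    # Illustrative index over a date column; index_dct maps each observed
    # value to the partition labels containing it (values mirror the tests).
    ind = ExplicitSecondaryIndex(
        column="col",
        dtype=pa.date32(),
        index_dct={datetime.date(2020, 1, 1): ["part_label"]},
    )

    # New default (date_as_object=True): dates stay datetime.date objects,
    # so the round trip through pyarrow preserves the original values.
    ind.observed_values()
    # -> object array containing datetime.date(2020, 1, 1)

    # date_as_object=False converts via pandas to numpy datetime64 values,
    # which is presumably what the dask index reconstruction relies on.
    ind.observed_values(date_as_object=False)
    # -> datetime64[ns] array containing 2020-01-01T00:00:00

Routing the keys through pa.array/to_pandas instead of np.fromiter also sidesteps the ValueError from GH248, since pyarrow handles the dtype conversion for non-integer columns.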