Skip to content

Commit

Permalink
Merge pull request #253 from fjetter/bugfix/gh248
Browse files Browse the repository at this point in the history
Allow dask_index_on for all col types
  • Loading branch information
fjetter authored Mar 19, 2020
2 parents 5bc0a53 + c29bf98 commit 4d4f9e0
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Improvements

Bug fixes
^^^^^^^^^
* GH248 Fix an issue causing a ValueError to be raised when using `dask_index_on` on non-integer columns
* GH255 Fix an issue causing the python interpreter to shut down when reading an
empty file (see also https://issues.apache.org/jira/browse/ARROW-8142)

Expand Down
3 changes: 3 additions & 0 deletions asv_bench/benchmarks/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def time_as_series_partitions_as_index(
):
self.ktk_index.as_flat_series(partitions_as_index=True)

def time_observed_values(self, number_values, number_partitions, arrow_type):
    # Benchmark materializing every observed index value; the parameters
    # are injected by the asv parametrization of the enclosing benchmark
    # class (they shape self.ktk_index in setup, not this call).
    self.ktk_index.observed_values()


class SerializeIndex(IndexBase):
timeout = 180
Expand Down
10 changes: 4 additions & 6 deletions kartothek/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,13 @@ def __repr__(self) -> str:
class_=type(self).__name__, attrs=", ".join(repr_str)
)

def observed_values(self, date_as_object=True) -> np.array:
    """
    Return an array of all observed values of the index.

    Parameters
    ----------
    date_as_object: bool
        If truthy (default), date-typed values are returned as
        ``datetime.date`` objects; otherwise they are coerced to
        ``datetime64``. Passed through to ``pyarrow.Array.to_pandas``.
    """
    # Round-trip the keys through a typed pyarrow array instead of
    # ``np.fromiter`` so that non-primitive dtypes (strings, dates,
    # timestamps, ...) are converted correctly (GH248).
    keys = np.array(list(self.index_dct.keys()))
    labeled_array = pa.array(keys, type=self.dtype)
    return np.array(labeled_array.to_pandas(date_as_object=date_as_object))

@staticmethod
def normalize_value(dtype: pa.DataType, value: Any) -> Any:
Expand Down
45 changes: 45 additions & 0 deletions tests/core/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from hypothesis import assume, given
from pandas.testing import assert_series_equal

from kartothek.core._compat import ARROW_LARGER_EQ_0150
from kartothek.core.index import ExplicitSecondaryIndex, IndexBase, merge_indices
from kartothek.core.testing import get_numpy_array_strategy

Expand Down Expand Up @@ -472,6 +473,50 @@ def test_index_raises_null_dtype():
assert str(exc.value) == "Indices w/ null/NA type are not supported"


@pytest.mark.parametrize(
    "dtype,value",
    [
        (pa.bool_(), True),
        (pa.int64(), 1),
        (pa.float64(), 1.1),
        (pa.binary(), b"x"),
        (pa.string(), "x"),
        (pa.timestamp("ns"), pd.Timestamp("2018-01-01").to_datetime64()),
        (pa.date32(), datetime.date(2018, 1, 1)),
        pytest.param(
            pa.timestamp("ns", tz=pytz.timezone("Europe/Berlin")),
            pd.Timestamp("2018-01-01", tzinfo=pytz.timezone("Europe/Berlin")),
            marks=pytest.mark.xfail(
                not ARROW_LARGER_EQ_0150,
                # fixed typo: "reoundtrips" -> "roundtrips"
                reason="Timezone roundtrips not supported in older versions",
            ),
        ),
    ],
)
def test_observed_values_plain(dtype, value):
    """``observed_values`` must round-trip a single stored value for every
    supported column type (regression test for GH248)."""
    ind = ExplicitSecondaryIndex(
        column="col", dtype=dtype, index_dct={value: ["part_label"]}
    )
    observed = ind.observed_values()
    assert len(observed) == 1
    assert list(observed) == [value]


@pytest.mark.parametrize("date_as_object", [None, True, False])
def test_observed_values_date_as_object(date_as_object):
value = datetime.date(2020, 1, 1)
ind = ExplicitSecondaryIndex(
column="col", dtype=pa.date32(), index_dct={value: ["part_label"]}
)
observed = ind.observed_values(date_as_object=date_as_object)
if date_as_object:
expected = value
else:
expected = pd.Timestamp(value).to_datetime64()
assert len(observed) == 1
assert observed[0] == expected


@pytest.mark.parametrize(
"dtype,value,expected",
[
Expand Down
29 changes: 29 additions & 0 deletions tests/io/dask/dataframe/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pandas import testing as pdt
from pandas.testing import assert_frame_equal

from kartothek.core.testing import get_dataframe_not_nested
from kartothek.io.dask.dataframe import read_dataset_as_ddf
from kartothek.io.eager import store_dataframes_as_dataset
from kartothek.io.testing.read import * # noqa
Expand Down Expand Up @@ -169,6 +170,34 @@ def test_reconstruct_dask_index(store_factory, index_type, monkeypatch):
assert_frame_equal(ddf_expected_simple.compute(), ddf.compute())


@pytest.fixture()
def setup_reconstruct_dask_index_types(store_factory, df_not_nested):
    """Persist ``df_not_nested`` with a secondary index on every column
    except ``null`` (which cannot be indexed) and return the dataset."""
    index_columns = list(df_not_nested.columns)
    index_columns.remove("null")
    return store_dataframes_as_dataset(
        store=store_factory,
        dataset_uuid="dataset_uuid",
        dfs=[df_not_nested],
        secondary_indices=index_columns,
    )


@pytest.mark.parametrize("col", get_dataframe_not_nested().columns)
def test_reconstruct_dask_index_types(
store_factory, setup_reconstruct_dask_index_types, col
):
if col == "null":
pytest.xfail(reason="Cannot index null column")
ddf = read_dataset_as_ddf(
dataset_uuid=setup_reconstruct_dask_index_types.uuid,
store=store_factory,
table="table",
dask_index_on=col,
)
assert ddf.known_divisions
assert ddf.index.name == col


def test_reconstruct_dask_index_sorting(store_factory, monkeypatch):

# Make sure we're not shuffling anything
Expand Down

0 comments on commit 4d4f9e0

Please sign in to comment.