diff --git a/CHANGES.rst b/CHANGES.rst
index a8407f8f..73cc9da1 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -6,6 +6,9 @@ Version 3.9.0 (UNRELEASED)
 ==========================
 * Significant performance improvements for shuffle operations in
   :func:`~kartothek.io.dask.dataframe.update_dataset_from_ddf`
+* Allow calling :func:`~kartothek.io.dask.dataframe.update_dataset_from_ddf`
+  without `partition_on` when `shuffle=True`
+
 
 Version 3.8.2 (2020-04-09)
 ==========================
diff --git a/kartothek/io/dask/_update.py b/kartothek/io/dask/_update.py
index 3cebb91b..8f9fcb47 100644
--- a/kartothek/io/dask/_update.py
+++ b/kartothek/io/dask/_update.py
@@ -180,7 +180,7 @@ def update_dask_partitions_shuffle(
     store_factory: StoreFactoryType,
     df_serializer: DataFrameSerializer,
     dataset_uuid: str,
-    num_buckets: Optional[int],
+    num_buckets: int,
     sort_partitions_by: Optional[str],
     bucket_by: List[str],
 ) -> da.Array:
@@ -236,11 +236,14 @@ def update_dask_partitions_shuffle(
         return ddf
 
     group_cols = partition_on.copy()
-    if num_buckets is not None:
-        meta = ddf._meta
-        meta[_KTK_HASH_BUCKET] = np.uint64(0)
-        ddf = ddf.map_partitions(_hash_bucket, bucket_by, num_buckets, meta=meta)
-        group_cols.append(_KTK_HASH_BUCKET)
+
+    if num_buckets is None:
+        raise ValueError("``num_buckets`` must not be None when shuffling data.")
+
+    meta = ddf._meta
+    meta[_KTK_HASH_BUCKET] = np.uint64(0)
+    ddf = ddf.map_partitions(_hash_bucket, bucket_by, num_buckets, meta=meta)
+    group_cols.append(_KTK_HASH_BUCKET)
 
     packed_meta = ddf._meta[group_cols]
     packed_meta[_PAYLOAD_COL] = b""
diff --git a/kartothek/io/dask/dataframe.py b/kartothek/io/dask/dataframe.py
index 5c17bb5a..5d98c24d 100644
--- a/kartothek/io/dask/dataframe.py
+++ b/kartothek/io/dask/dataframe.py
@@ -39,6 +39,7 @@ def read_dataset_as_ddf(
     predicates=None,
     factory=None,
     dask_index_on=None,
+    dispatch_by=None,
 ):
     """
     Retrieve a single table from a dataset as partition-individual :class:`~dask.dataframe.DataFrame` instance.
@@ -50,7 +51,8 @@ def read_dataset_as_ddf(
     Parameters
     ----------
     dask_index_on: str
-        Reconstruct (and set) a dask index on the provided index column.
+        Reconstruct (and set) a dask index on the provided index column. Cannot be used
+        in conjunction with `dispatch_by`.
 
         For details on performance, see also `dispatch_by`
     """
@@ -58,6 +60,13 @@ def read_dataset_as_ddf(
         raise TypeError(
             f"The paramter `dask_index_on` must be a string but got {type(dask_index_on)}"
         )
+
+    if dask_index_on is not None and dispatch_by is not None and len(dispatch_by) > 0:
+        raise ValueError(
+            "`read_dataset_as_ddf` got parameters `dask_index_on` and `dispatch_by`. "
+            "Note that `dispatch_by` can only be used if `dask_index_on` is None."
+        )
+
     ds_factory = _ensure_factory(
         dataset_uuid=dataset_uuid,
         store=store,
@@ -84,7 +93,7 @@ def read_dataset_as_ddf(
         label_filter=label_filter,
         dates_as_object=dates_as_object,
         predicates=predicates,
-        dispatch_by=dask_index_on,
+        dispatch_by=dask_index_on if dask_index_on else dispatch_by,
     )
     if dask_index_on:
         divisions = ds_factory.indices[dask_index_on].observed_values()
@@ -239,10 +248,6 @@ def update_dataset_from_ddf(
         ds_factory=factory,
     )
 
-    if shuffle and not partition_on:
-        raise ValueError(
-            "If ``shuffle`` is requested, at least one ``partition_on`` column needs to be provided."
-        )
     if ds_factory is not None:
         check_single_table_dataset(ds_factory, table)
 
@@ -260,7 +265,7 @@ def update_dataset_from_ddf(
     else:
         secondary_indices = _ensure_compatible_indices(ds_factory, secondary_indices)
 
-    if shuffle and partition_on:
+    if shuffle:
         mps = update_dask_partitions_shuffle(
             ddf=ddf,
             table=table,
diff --git a/reference-data/arrow-compat/batch_generate_references.sh b/reference-data/arrow-compat/batch_generate_references.sh
new file mode 100755
index 00000000..a7621dff
--- /dev/null
+++ b/reference-data/arrow-compat/batch_generate_references.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+# Note: this assumes you have kartothek installed in your current environment and you are using conda
+
+PYARROW_VERSIONS="0.14.1 0.15.0 0.16.0"
+
+for pyarrow_version in $PYARROW_VERSIONS; do
+    echo $pyarrow_version
+    conda install -y pyarrow==$pyarrow_version
+    ./generate_reference.py || { echo "Failed for version $pyarrow_version"; exit 1; }
+done
diff --git a/tests/io/dask/dataframe/test_update.py b/tests/io/dask/dataframe/test_update.py
index 18c8cba4..37d5f0f3 100644
--- a/tests/io/dask/dataframe/test_update.py
+++ b/tests/io/dask/dataframe/test_update.py
@@ -86,6 +86,53 @@ def _return_none():
     return None
 
 
+@pytest.mark.parametrize("bucket_by", [None, "range"])
+def test_update_shuffle_no_partition_on(store_factory, bucket_by):
+    df = pd.DataFrame(
+        {
+            "range": np.arange(10),
+            "range_duplicated": np.repeat(np.arange(2), 5),
+            "random": np.random.randint(0, 100, 10),
+        }
+    )
+    ddf = dd.from_pandas(df, npartitions=10)
+
+    with pytest.raises(
+        ValueError, match="``num_buckets`` must not be None when shuffling data."
+    ):
+        update_dataset_from_ddf(
+            ddf,
+            store_factory,
+            dataset_uuid="output_dataset_uuid",
+            table="table",
+            shuffle=True,
+            num_buckets=None,
+            bucket_by=bucket_by,
+        ).compute()
+
+    res_default = update_dataset_from_ddf(
+        ddf,
+        store_factory,
+        dataset_uuid="output_dataset_uuid_default",
+        table="table",
+        shuffle=True,
+        bucket_by=bucket_by,
+    ).compute()
+    assert len(res_default.partitions) == 1
+
+    res = update_dataset_from_ddf(
+        ddf,
+        store_factory,
+        dataset_uuid="output_dataset_uuid",
+        table="table",
+        shuffle=True,
+        num_buckets=2,
+        bucket_by=bucket_by,
+    ).compute()
+
+    assert len(res.partitions) == 2
+
+
 @pytest.mark.parametrize("unique_primaries", [1, 4])
 @pytest.mark.parametrize("unique_secondaries", [1, 3])
 @pytest.mark.parametrize("num_buckets", [1, 5])
diff --git a/tests/serialization/test_arrow_compat.py b/tests/serialization/test_arrow_compat.py
index a322538b..16918bde 100644
--- a/tests/serialization/test_arrow_compat.py
+++ b/tests/serialization/test_arrow_compat.py
@@ -34,10 +34,8 @@ def test_arrow_compat(arrow_version, reference_store, mocker):
     Test if reading/writing across the supported arrow versions is actually
     compatible
 
-    Generate new reference files with::
-
-        import pyarrow as pa
-        ParquetSerializer().store(reference_store, pa.__version__, orig)
+    Generate new reference files by going to the `reference-data/arrow-compat` directory and
+    executing `generate_reference.py` or `batch_generate_references.sh`.
     """
 
     uuid_hook = mocker.patch("kartothek.core.uuid._uuid_hook_object")
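
The new test pins down the write-side behaviour this patch enables. As a purely illustrative sketch (not part of the patch, mirroring the test above), `update_dataset_from_ddf` can now be called with `shuffle=True` and no `partition_on`; the rows are then hashed on `bucket_by` into `num_buckets` output partitions. The in-memory `storefact` store factory and the dataset uuid below are assumptions of the example, any kartothek-compatible store factory works::

    from functools import partial

    import dask.dataframe as dd
    import numpy as np
    import pandas as pd
    import storefact

    from kartothek.io.dask.dataframe import update_dataset_from_ddf

    # Placeholder store factory; replace with whatever factory your setup uses.
    store_factory = partial(storefact.get_store_from_url, "hmemory://")

    df = pd.DataFrame(
        {
            "range": np.arange(10),
            "random": np.random.randint(0, 100, 10),
        }
    )
    ddf = dd.from_pandas(df, npartitions=10)

    # shuffle=True no longer requires partition_on; the data is bucketed by
    # hashing the bucket_by column(s) into num_buckets output partitions.
    dm = update_dataset_from_ddf(
        ddf,
        store_factory,
        dataset_uuid="sketch_dataset_uuid",
        table="table",
        shuffle=True,
        num_buckets=2,
        bucket_by="range",
    ).compute()
    assert len(dm.partitions) == 2

Omitting `num_buckets` falls back to the default of a single bucket (the `res_default` case in the test), while passing `num_buckets=None` explicitly now raises the `ValueError` introduced in `_update.py`.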
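On the read side, `read_dataset_as_ddf` now exposes `dispatch_by` and rejects combining it with `dask_index_on`. The following sketch is likewise illustrative only: the column name, dataset uuid and store factory are placeholders, and it assumes the usual `dataset_uuid`/`store`/`table` keyword arguments of `read_dataset_as_ddf` as well as the `secondary_indices` argument of `update_dataset_from_ddf` to create an indexed column to read by::

    from functools import partial

    import dask.dataframe as dd
    import pandas as pd
    import storefact

    from kartothek.io.dask.dataframe import (
        read_dataset_as_ddf,
        update_dataset_from_ddf,
    )

    # Placeholder in-memory store factory.
    store_factory = partial(storefact.get_store_from_url, "hmemory://")

    # Write a tiny dataset with a secondary index on "location" so that both
    # read options below have an indexed column to work with.
    ddf = dd.from_pandas(
        pd.DataFrame({"location": [0, 0, 1, 1], "payload": [1, 2, 3, 4]}),
        npartitions=2,
    )
    update_dataset_from_ddf(
        ddf,
        store_factory,
        dataset_uuid="read_sketch_uuid",
        table="table",
        secondary_indices=["location"],
    ).compute()

    # Either reconstruct a dask index from the kartothek index column ...
    ddf_indexed = read_dataset_as_ddf(
        dataset_uuid="read_sketch_uuid",
        store=store_factory,
        table="table",
        dask_index_on="location",
    )

    # ... or dispatch dask partitions by the indexed column ...
    ddf_dispatched = read_dataset_as_ddf(
        dataset_uuid="read_sketch_uuid",
        store=store_factory,
        table="table",
        dispatch_by=["location"],
    )

    # ... but not both: passing dask_index_on together with a non-empty
    # dispatch_by raises the ValueError added in this patch.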