Skip to content

Commit

Permalink
[SPARK-45918][PS] Optimize MultiIndex.symmetric_difference
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Optimize `MultiIndex.symmetric_difference`

### Why are the changes needed?
currently, the `XOR` operation `a.union(b).subtract(a.intersect(b))` is not optimum:

```
        >>> midx1 = pd.MultiIndex([['lama', 'cow', 'falcon'],
        ...                        ['speed', 'weight', 'length']],
        ...                       [[0, 0, 0, 1, 1, 1, 2, 2, 2],
        ...                        [0, 0, 0, 0, 1, 2, 0, 1, 2]])
        >>> midx2 = pd.MultiIndex([['pandas-on-Spark', 'cow', 'falcon'],
        ...                        ['speed', 'weight', 'length']],
        ...                       [[0, 0, 0, 1, 1, 1, 2, 2, 2],
        ...                        [0, 0, 0, 0, 1, 2, 0, 1, 2]])
        >>> s1 = ps.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3],
        ...                index=midx1)
        >>> s2 = ps.Series([45, 200, 1.2, 30, 250, 1.5, 320, 1, 0.3],
        ...              index=midx2)
        >>> s1.index.symmetric_difference(s2.index)._internal.spark_frame.explain("extended")

```

before this PR:
```
== Optimized Logical Plan ==
Aggregate [__index_level_0__#0, __index_level_1__#1], [__index_level_0__#0, __index_level_1__#1, monotonically_increasing_id() AS __natural_order__#161L]
+- Union false, false
   :- Join LeftAnti, ((__index_level_0__#0 <=> __index_level_0__#145) AND (__index_level_1__#1 <=> __index_level_1__#146))
   :  :- Project [__index_level_0__#0, __index_level_1__#1]
   :  :  +- LogicalRDD [__index_level_0__#0, __index_level_1__#1, 0#2], false
   :  +- Aggregate [__index_level_0__#145, __index_level_1__#146], [__index_level_0__#145, __index_level_1__#146]
   :     +- Join LeftSemi, ((__index_level_0__#145 <=> __index_level_0__#149) AND (__index_level_1__#146 <=> __index_level_1__#150))
   :        :- Project [__index_level_0__#145, __index_level_1__#146]
   :        :  +- LogicalRDD [__index_level_0__#145, __index_level_1__#146, 0#147], false
   :        +- Project [__index_level_0__#149, __index_level_1__#150]
   :           +- LogicalRDD [__index_level_0__#149, __index_level_1__#150, 0#151], false
   +- Join LeftAnti, ((__index_level_0__#11 <=> __index_level_0__#145) AND (__index_level_1__#12 <=> __index_level_1__#146))
      :- Project [__index_level_0__#11, __index_level_1__#12]
      :  +- LogicalRDD [__index_level_0__#11, __index_level_1__#12, 0#13], false
      +- Aggregate [__index_level_0__#145, __index_level_1__#146], [__index_level_0__#145, __index_level_1__#146]
         +- Join LeftSemi, ((__index_level_0__#145 <=> __index_level_0__#149) AND (__index_level_1__#146 <=> __index_level_1__#150))
            :- Project [__index_level_0__#145, __index_level_1__#146]
            :  +- LogicalRDD [__index_level_0__#145, __index_level_1__#146, 0#147], false
            +- Project [__index_level_0__#149, __index_level_1__#150]
               +- LogicalRDD [__index_level_0__#149, __index_level_1__#150, 0#151], false

```

after this PR:
```
== Optimized Logical Plan ==
Project [__index_level_0__#0, __index_level_1__#1, monotonically_increasing_id() AS __natural_order__#64L]
+- Filter ((isnotnull(__multi_index_min_tag__#46) AND isnotnull(__multi_index_max_tag__#47)) AND (__multi_index_min_tag__#46 = __multi_index_max_tag__#47))
   +- Aggregate [__index_level_0__#0, __index_level_1__#1], [__index_level_0__#0, __index_level_1__#1, min(__multi_index_tag__#30) AS __multi_index_min_tag__#46, max(__multi_index_tag__#30) AS __multi_index_max_tag__#47]
      +- Union false, false
         :- Project [__index_level_0__#0, __index_level_1__#1, 0 AS __multi_index_tag__#30]
         :  +- LogicalRDD [__index_level_0__#0, __index_level_1__#1, 0#2], false
         +- Project [__index_level_0__#11, __index_level_1__#12, 1 AS __multi_index_tag__#34]
            +- LogicalRDD [__index_level_0__#11, __index_level_1__#12, 0#13], false
```

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#43795 from zhengruifeng/ps_multi_index_opt.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Nov 15, 2023
1 parent 121b6b8 commit b2de91c
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion python/pyspark/pandas/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,18 @@ def symmetric_difference( # type: ignore[override]
sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns)
sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns)

sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other))
tmp_tag_col = verify_temp_column_name(sdf_self, "__multi_index_tag__")
tmp_max_col = verify_temp_column_name(sdf_self, "__multi_index_max_tag__")
tmp_min_col = verify_temp_column_name(sdf_self, "__multi_index_min_tag__")

sdf_symdiff = (
sdf_self.withColumn(tmp_tag_col, F.lit(0))
.union(sdf_other.withColumn(tmp_tag_col, F.lit(1)))
.groupBy(*self._internal.index_spark_column_names)
.agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col))
.where(F.col(tmp_min_col) == F.col(tmp_max_col))
.select(*self._internal.index_spark_column_names)
)

if sort:
sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names)
Expand Down

0 comments on commit b2de91c

Please sign in to comment.