Skip to content

Commit

Permalink
fix: resolve self-join issue
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq committed Nov 29, 2024
1 parent 36c621b commit 89cf488
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
21 changes: 20 additions & 1 deletion sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,25 @@ def join(
if on is None:
logger.warning("Got no value for on. This appears to change the join to a cross join.")
how = "cross"
on_cols = self._ensure_list_of_columns(on)
# If we are joining in another dataframe that comes from the same branch, we need to treat the other dataframe
# as a new branch so we properly differentiate the two dataframes in the final SQL output.
if self.branch_id == other_df.branch_id:
new_branch_id = self.session._random_branch_id
replacement_mapping = {
exp.to_identifier(other_df.branch_id): exp.to_identifier(new_branch_id)
}
other_df.branch_id = new_branch_id
other_df.expression = other_df.expression.transform(
replace_id_value, replacement_mapping
).assert_is(exp.Select)
for col in on_cols:
# We only want to update one side of the join if EQ so we find all EQs and update the right side
for eq in col.expression.find_all(exp.EQ):
eq.set(
"expression", eq.expression.transform(replace_id_value, replacement_mapping)
)

other_df = other_df._convert_leaf_to_cte()
join_expression = self._add_ctes_to_expression(self.expression, other_df.expression.ctes)
# We will determine actual "join on" expression later so we don't provide it at first
Expand All @@ -874,7 +893,7 @@ def join(
)
self_columns = self._get_outer_select_columns(join_expression)
other_columns = self._get_outer_select_columns(other_df.expression)
join_columns = self._ensure_and_normalize_cols(on)
join_columns = self._ensure_and_normalize_cols(on_cols)
# Determines the join clause and select columns to be used based on what type of columns were provided for
# the join. The columns returned change based on how the on expression is provided.
if how != "cross":
Expand Down
25 changes: 25 additions & 0 deletions tests/integration/test_int_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2247,3 +2247,28 @@ def test_chaining_joins_with_selects(
)

compare_frames(df, dfs, compare_schema=False)


# https://github.com/eakmanrq/sqlframe/issues/210
def test_self_join(
    pyspark_employee: PySparkDataFrame,
    get_df: t.Callable[[str], _BaseDataFrame],
    compare_frames: t.Callable,
):
    """Check an inner self-join against a filtered copy of the same dataframe.

    The same pipeline is built twice: once on the ``pyspark_employee`` fixture
    and once on the dataframe returned by ``get_df("employee")``. In both, the
    employee dataframe is joined to its own ``age > 40`` subset on
    ``employee_id``. The two results are handed to ``compare_frames`` with
    schema comparison disabled.
    """
    # Reference side: the join built on the PySpark fixture.
    spark_over_40 = pyspark_employee.where(F.col("age") > 40)
    spark_condition = pyspark_employee["employee_id"] == spark_over_40["employee_id"]
    expected = pyspark_employee.join(spark_over_40, spark_condition, how="inner")

    # Side under test: both join inputs derive from the same source dataframe,
    # which is the self-join scenario this test targets.
    base_employee = get_df("employee")
    base_over_40 = base_employee.where(SF.col("age") > 40)
    base_condition = base_employee["employee_id"] == base_over_40["employee_id"]
    actual = base_employee.join(base_over_40, base_condition, how="inner")

    compare_frames(expected, actual, compare_schema=False)

0 comments on commit 89cf488

Please sign in to comment.