Skip to content

Commit

Permalink
fix: handle double column selects with dataframe lookup (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq authored May 24, 2024
1 parent 7b220a8 commit e7c7dd4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion sqlframe/base/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ def withColumn(self, colName: str, col: Column) -> Self:
expression = self.expression.copy()
expression.expressions[existing_col_index] = col.alias(col_name).expression
return self.copy(expression=expression)
return self.copy().select(col.alias(col_name), append=True)
return self.select.__wrapped__(self, col.alias(col_name), append=True) # type: ignore

@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str) -> Self:
Expand Down
1 change: 0 additions & 1 deletion sqlframe/base/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def normalize(session: SESSION, expression_context: exp.Select, expr: t.List[NOR
expr = ensure_list(expr)
expressions = _ensure_expressions(expr)
for expression in expressions:
# normalize_identifiers(expression, session.input_dialect)
identifiers = expression.find_all(exp.Identifier)
for identifier in identifiers:
identifier.transform(session.input_dialect.normalize_identifier)
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/standalone/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,25 @@ def test_with_column_duplicate_alias(standalone_employee: StandaloneDataFrame):
)


# https://github.com/eakmanrq/sqlframe/issues/19
def test_with_column_dual_expression(standalone_employee: StandaloneDataFrame):
df1 = standalone_employee.withColumn("new_col1", standalone_employee.age)
df2 = df1.withColumn("new_col2", standalone_employee.store_id)
assert df2.columns == [
"employee_id",
"fname",
"lname",
"age",
"store_id",
"new_col1",
"new_col2",
]
assert (
df2.sql(pretty=False)
== "SELECT `a1`.`employee_id` AS `employee_id`, CAST(`a1`.`fname` AS STRING) AS `fname`, CAST(`a1`.`lname` AS STRING) AS `lname`, `a1`.`age` AS `age`, `a1`.`store_id` AS `store_id`, `a1`.`age` AS `new_col1`, `a1`.`store_id` AS `new_col2` FROM VALUES (1, 'Jack', 'Shephard', 37, 1), (2, 'John', 'Locke', 65, 1), (3, 'Kate', 'Austen', 37, 2), (4, 'Claire', 'Littleton', 27, 2), (5, 'Hugo', 'Reyes', 29, 100) AS `a1`(`employee_id`, `fname`, `lname`, `age`, `store_id`)"
)


def test_where_expr(standalone_employee: StandaloneDataFrame):
df = standalone_employee.where("fname = 'Jack' AND age = 37")
assert df.columns == ["employee_id", "fname", "lname", "age", "store_id"]
Expand Down

0 comments on commit e7c7dd4

Please sign in to comment.