feat: add printSchema support for duckdb and postgres (#29)
eakmanrq authored May 25, 2024
1 parent e7c7dd4 · commit 80591aa
Showing 14 changed files with 240 additions and 16 deletions.
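
For context, a minimal sketch of what the new feature enables, assuming an in-memory DuckDB-backed session (the column names, data, and printed types here are illustrative, not taken from the commit):

from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()  # defaults to an in-memory DuckDB database
df = session.createDataFrame([(1, "Jack")], ["id", "name"])

# Prints a PySpark-style schema tree, roughly:
# root
#  |-- id: bigint (nullable = true)
#  |-- name: text (nullable = true)
df.printSchema()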
1 change: 1 addition & 0 deletions docs/duckdb.md
@@ -171,6 +171,7 @@ df_store = session.createDataFrame(
* [na](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.na.html)
* [orderBy](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.orderBy.html)
* [persist](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.persist.html)
* [printSchema](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.printSchema.html)
* [replace](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.replace.html)
* [select](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.select.html)
* [show](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.show.html)
1 change: 1 addition & 0 deletions docs/postgres.md
@@ -174,6 +174,7 @@ df_store = session.createDataFrame(
* [na](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.na.html)
* [orderBy](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.orderBy.html)
* [persist](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.persist.html)
* [printSchema](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.printSchema.html)
* [replace](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.replace.html)
* [select](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.select.html)
* [show](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.show.html)
7 changes: 6 additions & 1 deletion sqlframe/base/catalog.py
@@ -26,6 +26,9 @@
class _BaseCatalog(t.Generic[SESSION, DF]):
"""User-facing catalog API, accessible through `SparkSession.catalog`."""

TEMP_CATALOG_FILTER: t.Optional[exp.Expression] = None
TEMP_SCHEMA_FILTER: t.Optional[exp.Expression] = None

def __init__(self, sparkSession: SESSION, schema: t.Optional[MappingSchema] = None) -> None:
"""Create a new Catalog that wraps the underlying JVM object."""
self.session = sparkSession
@@ -569,7 +572,9 @@ def listTables(
"""
raise NotImplementedError

def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
def listColumns(
self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
) -> t.List[Column]:
"""Returns a t.List of columns for the given table/view in the specified database.
.. versionadded:: 2.0.0
20 changes: 11 additions & 9 deletions sqlframe/base/mixins/catalog_mixins.py
@@ -315,7 +315,9 @@ def listTables(

class ListColumnsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
@normalize(["tableName", "dbName"])
def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
def listColumns(
self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
) -> t.List[Column]:
"""Returns a t.List of columns for the given table/view in the specified database.
.. versionadded:: 2.0.0
@@ -385,12 +387,6 @@ def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[
"catalog",
exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
)
# if self.QUALIFY_INFO_SCHEMA_WITH_DATABASE:
# if not table.db:
# raise ValueError("dbName must be specified when listing columns from INFORMATION_SCHEMA")
# source_table = f"{table.db}.INFORMATION_SCHEMA.COLUMNS"
# else:
# source_table = "INFORMATION_SCHEMA.COLUMNS"
source_table = self._get_info_schema_table("columns", database=table.db)
select = (
exp.select(
@@ -402,9 +398,15 @@ def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[
.where(exp.column("table_name").eq(table.name))
)
if table.db:
select = select.where(exp.column("table_schema").eq(table.db))
schema_filter: exp.Expression = exp.column("table_schema").eq(table.db)
if include_temp and self.TEMP_SCHEMA_FILTER:
schema_filter = exp.Or(this=schema_filter, expression=self.TEMP_SCHEMA_FILTER)
select = select.where(schema_filter)
if table.catalog:
select = select.where(exp.column("table_catalog").eq(table.catalog))
catalog_filter: exp.Expression = exp.column("table_catalog").eq(table.catalog)
if include_temp and self.TEMP_CATALOG_FILTER:
catalog_filter = exp.Or(this=catalog_filter, expression=self.TEMP_CATALOG_FILTER)
select = select.where(catalog_filter)
results = self.session._fetch_rows(select)
return [
Column(
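
To make the filter composition above concrete, here is a small standalone sqlglot sketch of the query shape listColumns builds when include_temp is set; the table, schema, and temp-filter values are illustrative:

from sqlglot import exp

schema_filter: exp.Expression = exp.column("table_schema").eq("public")
temp_filter = exp.column("table_schema").like("pg_temp_%")  # Postgres-style temp-schema filter

# include_temp=True ORs the engine's temp filter into the schema predicate
schema_filter = exp.Or(this=schema_filter, expression=temp_filter)
select = (
    exp.select("column_name", "data_type")
    .from_("information_schema.columns")
    .where(exp.column("table_name").eq("employee"))
    .where(schema_filter)
)
# Roughly: SELECT column_name, data_type FROM information_schema.columns
# WHERE table_name = 'employee' AND (table_schema = 'public' OR table_schema LIKE 'pg_temp_%')
print(select.sql(dialect="postgres"))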
63 changes: 63 additions & 0 deletions sqlframe/base/mixins/dataframe_mixins.py
@@ -0,0 +1,63 @@
import typing as t

from sqlglot import exp

from sqlframe.base.catalog import Column
from sqlframe.base.dataframe import (
GROUP_DATA,
NA,
SESSION,
STAT,
WRITER,
_BaseDataFrame,
)


class PrintSchemaFromTempObjectsMixin(
_BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]
):
def _get_columns_from_temp_object(self) -> t.List[Column]:
table = exp.to_table(self.session._random_id)
self.session._execute(
exp.Create(
this=table,
kind="VIEW",
replace=True,
properties=exp.Properties(expressions=[exp.TemporaryProperty()]),
expression=self.expression,
)
)
return self.session.catalog.listColumns(
table.sql(dialect=self.session.input_dialect), include_temp=True
)

def printSchema(self, level: t.Optional[int] = None) -> None:
def print_schema(
column_name: str, column_type: exp.DataType, nullable: bool, current_level: int
):
if level and current_level >= level:
return
if current_level > 0:
print(" | " * current_level, end="")
print(
f" |-- {column_name}: {column_type.sql(self.session.output_dialect).lower()} (nullable = {str(nullable).lower()})"
)
if column_type.this == exp.DataType.Type.STRUCT:
for column_def in column_type.expressions:
print_schema(column_def.name, column_def.args["kind"], True, current_level + 1)
if column_type.this == exp.DataType.Type.ARRAY:
for data_type in column_type.expressions:
print_schema("element", data_type, True, current_level + 1)
if column_type.this == exp.DataType.Type.MAP:
print_schema("key", column_type.expressions[0], True, current_level + 1)
print_schema("value", column_type.expressions[1], True, current_level + 1)

columns = self._get_columns_from_temp_object()
print("root")
for column in columns:
print_schema(
column.name,
exp.DataType.build(column.dataType, dialect=self.session.output_dialect),
column.nullable,
0,
)
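
A usage sketch of the mixin (assuming a DuckDB session; the printed type strings are illustrative): the DataFrame's plan is materialized as a temporary view, listColumns(..., include_temp=True) reads the column types back, and the optional level argument caps how deep nested fields are expanded.

from sqlframe.base.types import Row
from sqlframe.duckdb import DuckDBSession

session = DuckDBSession()
df = session.createDataFrame([(1, Row(a=1, b=2))], ["id", "struct_col"])

# level=1 prints only top-level columns; the struct's fields are not expanded:
# root
#  |-- id: bigint (nullable = true)
#  |-- struct_col: struct(a bigint, b bigint) (nullable = true)
df.printSchema(level=1)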
4 changes: 3 additions & 1 deletion sqlframe/bigquery/catalog.py
@@ -46,7 +46,9 @@ def currentDatabase(self) -> str:
return to_schema(self.session.default_dataset).db

@normalize(["tableName", "dbName"])
def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
def listColumns(
self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
) -> t.List[Column]:
"""Returns a t.List of columns for the given table/view in the specified database.
.. versionadded:: 2.0.0
2 changes: 2 additions & 0 deletions sqlframe/duckdb/catalog.py
@@ -36,6 +36,8 @@ class DuckDBCatalog(
ListColumnsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
_BaseCatalog["DuckDBSession", "DuckDBDataFrame"],
):
TEMP_CATALOG_FILTER = exp.column("table_catalog").eq("temp")

def listFunctions(
self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
) -> t.List[Function]:
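
For reference on the new TEMP_CATALOG_FILTER: DuckDB registers temporary objects under the temp catalog, so a plain table_catalog = <current catalog> predicate would miss the temporary view that printSchema creates. A quick standalone check with the duckdb package (assumed installed):

import duckdb

con = duckdb.connect()  # in-memory database
con.execute("CREATE TEMP VIEW v AS SELECT 1 AS x")
# The temporary view should surface with table_catalog = 'temp':
print(
    con.execute(
        "SELECT table_catalog, table_schema, table_name "
        "FROM information_schema.tables WHERE table_name = 'v'"
    ).fetchall()
)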
4 changes: 3 additions & 1 deletion sqlframe/duckdb/dataframe.py
@@ -9,6 +9,7 @@
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
from sqlframe.base.mixins.dataframe_mixins import PrintSchemaFromTempObjectsMixin
from sqlframe.duckdb.group import DuckDBGroupedData

if sys.version_info >= (3, 11):
@@ -34,13 +35,14 @@ class DuckDBDataFrameStatFunctions(_BaseDataFrameStatFunctions["DuckDBDataFrame"


class DuckDBDataFrame(
PrintSchemaFromTempObjectsMixin,
_BaseDataFrame[
"DuckDBSession",
"DuckDBDataFrameWriter",
"DuckDBDataFrameNaFunctions",
"DuckDBDataFrameStatFunctions",
"DuckDBGroupedData",
]
],
):
_na = DuckDBDataFrameNaFunctions
_stat = DuckDBDataFrameStatFunctions
1 change: 1 addition & 0 deletions sqlframe/postgres/catalog.py
@@ -34,6 +34,7 @@ class PostgresCatalog(
_BaseCatalog["PostgresSession", "PostgresDataFrame"],
):
CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.column("current_catalog")
TEMP_SCHEMA_FILTER = exp.column("table_schema").like("pg_temp_%")

def listFunctions(
self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
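
The Postgres counterpart: each session's temporary objects live in a per-backend schema named pg_temp_<N>, which is why the filter is a LIKE pattern rather than an equality check. A standalone sketch with psycopg2 (the DSN is a placeholder):

import psycopg2

conn = psycopg2.connect("dbname=postgres")  # placeholder connection string
cur = conn.cursor()
cur.execute("CREATE TEMP TABLE t (x int)")
# Reports a schema like 'pg_temp_3' (the number varies per backend):
cur.execute("SELECT table_schema FROM information_schema.tables WHERE table_name = 't'")
print(cur.fetchall())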
4 changes: 3 additions & 1 deletion sqlframe/postgres/dataframe.py
@@ -9,6 +9,7 @@
_BaseDataFrameNaFunctions,
_BaseDataFrameStatFunctions,
)
from sqlframe.base.mixins.dataframe_mixins import PrintSchemaFromTempObjectsMixin
from sqlframe.postgres.group import PostgresGroupedData

if sys.version_info >= (3, 11):
@@ -33,13 +34,14 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr


class PostgresDataFrame(
PrintSchemaFromTempObjectsMixin,
_BaseDataFrame[
"PostgresSession",
"PostgresDataFrameWriter",
"PostgresDataFrameNaFunctions",
"PostgresDataFrameStatFunctions",
"PostgresGroupedData",
]
],
):
_na = PostgresDataFrameNaFunctions
_stat = PostgresDataFrameStatFunctions
4 changes: 3 additions & 1 deletion sqlframe/spark/catalog.py
@@ -468,7 +468,9 @@ def listTables(
)
return [Table(*x) for x in self._spark_catalog.listTables(dbName, pattern)]

def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
def listColumns(
self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
) -> t.List[Column]:
"""Returns a t.List of columns for the given table/view in the specified database.
.. versionadded:: 2.0.0
79 changes: 79 additions & 0 deletions tests/integration/engines/duck/test_duckdb_dataframe.py
@@ -0,0 +1,79 @@
import datetime

from sqlframe.base.types import Row
from sqlframe.duckdb import DuckDBDataFrame, DuckDBSession

pytest_plugins = ["tests.integration.fixtures"]


def test_print_schema_basic(duckdb_employee: DuckDBDataFrame, capsys):
duckdb_employee.printSchema()
captured = capsys.readouterr()
assert (
captured.out.strip()
== """
root
|-- employee_id: int (nullable = true)
|-- fname: text (nullable = true)
|-- lname: text (nullable = true)
|-- age: int (nullable = true)
|-- store_id: int (nullable = true)""".strip()
)


def test_print_schema_nested(duckdb_session: DuckDBSession, capsys):
df = duckdb_session.createDataFrame(
[
(
1,
2.0,
"foo",
{"a": 1},
[Row(a=1, b=2)],
[1, 2, 3],
Row(a=1),
datetime.date(2022, 1, 1),
datetime.datetime(2022, 1, 1, 0, 0, 0),
datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc),
True,
)
],
[
"bigint_col",
"double_col",
"string_col",
"map<string,bigint>_col",
"array<struct<a:bigint,b:bigint>>",
"array<bigint>_col",
"struct<a:bigint>_col",
"date_col",
"timestamp_col",
"timestamptz_col",
"boolean_col",
],
)
df.printSchema()
captured = capsys.readouterr()
assert (
captured.out.strip()
== """
root
|-- bigint_col: bigint (nullable = true)
|-- double_col: double (nullable = true)
|-- string_col: text (nullable = true)
|-- map<string,bigint>_col: map(text, bigint) (nullable = true)
| |-- key: text (nullable = true)
| |-- value: bigint (nullable = true)
|-- array<struct<a:bigint,b:bigint>>: struct(a bigint, b bigint)[] (nullable = true)
| |-- element: struct(a bigint, b bigint) (nullable = true)
| | |-- a: bigint (nullable = true)
| | |-- b: bigint (nullable = true)
|-- array<bigint>_col: bigint[] (nullable = true)
| |-- element: bigint (nullable = true)
|-- struct<a:bigint>_col: struct(a bigint) (nullable = true)
| |-- a: bigint (nullable = true)
|-- date_col: date (nullable = true)
|-- timestamp_col: timestamp (nullable = true)
|-- timestamptz_col: timestamptz (nullable = true)
|-- boolean_col: boolean (nullable = true)""".strip()
)
64 changes: 64 additions & 0 deletions tests/integration/engines/postgres/test_postgres_dataframe.py
@@ -0,0 +1,64 @@
import datetime

from sqlframe.postgres import PostgresDataFrame, PostgresSession

pytest_plugins = ["tests.integration.fixtures"]


def test_print_schema_basic(postgres_employee: PostgresDataFrame, capsys):
postgres_employee.printSchema()
captured = capsys.readouterr()
assert (
captured.out.strip()
== """
root
|-- employee_id: int (nullable = true)
|-- fname: text (nullable = true)
|-- lname: text (nullable = true)
|-- age: int (nullable = true)
|-- store_id: int (nullable = true)""".strip()
)


def test_print_schema_nested(postgres_session: PostgresSession, capsys):
df = postgres_session.createDataFrame(
[
(
1,
2.0,
"foo",
[1, 2, 3],
datetime.date(2022, 1, 1),
datetime.datetime(2022, 1, 1, 0, 0, 0),
datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc),
True,
)
],
[
"bigint_col",
"double_col",
"string_col",
"array<bigint>_col",
"date_col",
"timestamp_col",
"timestamptz_col",
"boolean_col",
],
)
df.printSchema()
captured = capsys.readouterr()
    # Postgres's information_schema reports an array column as just "array" (no element type), so no nested element line is printed
assert (
captured.out.strip()
== """
root
|-- bigint_col: bigint (nullable = true)
|-- double_col: double precision (nullable = true)
|-- string_col: text (nullable = true)
|-- array<bigint>_col: array (nullable = true)
|-- date_col: date (nullable = true)
|-- timestamp_col: timestamp (nullable = true)
|-- timestamptz_col: timestamptz (nullable = true)
|-- boolean_col: boolean (nullable = true)""".strip()
)
2 changes: 0 additions & 2 deletions tests/integration/engines/test_int_functions.py
@@ -6,7 +6,6 @@
from collections import Counter

import pytest
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession as PySparkSession
from sqlglot import exp

@@ -175,7 +174,6 @@ def test_col(get_session_and_func, arg):
)
def test_typeof(get_session_and_func, get_types, arg, expected):
session, typeof = get_session_and_func("typeof")
types = get_types(session)
# If we just pass a struct in for values then Spark will automatically explode the struct into columns
# it won't do this though if there is another column so that is why we include an ignore column
df = session.createDataFrame([(1, arg)], schema=["ignore_col", "col"])
