From 80591aaa795b81243c79f36aeaff20c21b13c305 Mon Sep 17 00:00:00 2001
From: Ryan Eakman <6326532+eakmanrq@users.noreply.github.com>
Date: Fri, 24 May 2024 19:40:08 -0700
Subject: [PATCH] feat: add printSchema support for duckdb and postgres (#29)

---
 docs/duckdb.md                                   |  1 +
 docs/postgres.md                                 |  1 +
 sqlframe/base/catalog.py                         |  7 +-
 sqlframe/base/mixins/catalog_mixins.py           | 20 ++---
 sqlframe/base/mixins/dataframe_mixins.py         | 63 +++++++++++++++
 sqlframe/bigquery/catalog.py                     |  4 +-
 sqlframe/duckdb/catalog.py                       |  2 +
 sqlframe/duckdb/dataframe.py                     |  4 +-
 sqlframe/postgres/catalog.py                     |  1 +
 sqlframe/postgres/dataframe.py                   |  4 +-
 sqlframe/spark/catalog.py                        |  4 +-
 .../engines/duck/test_duckdb_dataframe.py        | 79 +++++++++++++++++++
 .../postgres/test_postgres_dataframe.py          | 64 +++++++++++++++
 tests/integration/engines/test_int_functions.py  |  2 -
 14 files changed, 240 insertions(+), 16 deletions(-)
 create mode 100644 sqlframe/base/mixins/dataframe_mixins.py
 create mode 100644 tests/integration/engines/duck/test_duckdb_dataframe.py
 create mode 100644 tests/integration/engines/postgres/test_postgres_dataframe.py

diff --git a/docs/duckdb.md b/docs/duckdb.md
index 4d33c7b..28a15eb 100644
--- a/docs/duckdb.md
+++ b/docs/duckdb.md
@@ -171,6 +171,7 @@ df_store = session.createDataFrame(
 * [na](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.na.html)
 * [orderBy](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.orderBy.html)
 * [persist](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.persist.html)
+* [printSchema](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.printSchema.html)
 * [replace](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.replace.html)
 * [select](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.select.html)
 * [show](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.show.html)
diff --git a/docs/postgres.md b/docs/postgres.md
index 8a460dc..0b6c28a 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -174,6 +174,7 @@ df_store = session.createDataFrame(
 * [na](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.na.html)
 * [orderBy](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.orderBy.html)
 * [persist](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.persist.html)
+* [printSchema](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.printSchema.html)
 * [replace](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.replace.html)
 * [select](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.select.html)
 * [show](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.show.html)
diff --git a/sqlframe/base/catalog.py b/sqlframe/base/catalog.py
index 80fa45b..e9b9cf0 100644
--- a/sqlframe/base/catalog.py
+++ b/sqlframe/base/catalog.py
@@ -26,6 +26,9 @@ class _BaseCatalog(t.Generic[SESSION, DF]):
     """User-facing catalog API, accessible through `SparkSession.catalog`."""
 
+    TEMP_CATALOG_FILTER: t.Optional[exp.Expression] = None
+    TEMP_SCHEMA_FILTER: t.Optional[exp.Expression] = None
+
     def __init__(self, sparkSession: SESSION, schema: t.Optional[MappingSchema] = None) -> None:
         """Create a new Catalog that wraps the underlying JVM object."""
         self.session = sparkSession
@@ -569,7 +572,9 @@ def listTables(
         """
         raise NotImplementedError
 
-    def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
+    def listColumns(
+        self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
+    ) -> t.List[Column]:
         """Returns a t.List of columns for the given table/view in the specified database.
 
         .. versionadded:: 2.0.0
diff --git a/sqlframe/base/mixins/catalog_mixins.py b/sqlframe/base/mixins/catalog_mixins.py
index bedb146..bb1c361 100644
--- a/sqlframe/base/mixins/catalog_mixins.py
+++ b/sqlframe/base/mixins/catalog_mixins.py
@@ -315,7 +315,9 @@ def listTables(
 
 class ListColumnsFromInfoSchemaMixin(_BaseInfoSchemaMixin, t.Generic[SESSION, DF]):
     @normalize(["tableName", "dbName"])
-    def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
+    def listColumns(
+        self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
+    ) -> t.List[Column]:
         """Returns a t.List of columns for the given table/view in the specified database.
 
         .. versionadded:: 2.0.0
@@ -385,12 +387,6 @@ def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[
                 "catalog",
                 exp.parse_identifier(self.currentCatalog(), dialect=self.session.input_dialect),
             )
-        # if self.QUALIFY_INFO_SCHEMA_WITH_DATABASE:
-        #     if not table.db:
-        #         raise ValueError("dbName must be specified when listing columns from INFORMATION_SCHEMA")
-        #     source_table = f"{table.db}.INFORMATION_SCHEMA.COLUMNS"
-        # else:
-        #     source_table = "INFORMATION_SCHEMA.COLUMNS"
         source_table = self._get_info_schema_table("columns", database=table.db)
         select = (
             exp.select(
@@ -402,9 +398,15 @@ def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[
             .where(exp.column("table_name").eq(table.name))
         )
         if table.db:
-            select = select.where(exp.column("table_schema").eq(table.db))
+            schema_filter: exp.Expression = exp.column("table_schema").eq(table.db)
+            if include_temp and self.TEMP_SCHEMA_FILTER:
+                schema_filter = exp.Or(this=schema_filter, expression=self.TEMP_SCHEMA_FILTER)
+            select = select.where(schema_filter)
         if table.catalog:
-            select = select.where(exp.column("table_catalog").eq(table.catalog))
+            catalog_filter: exp.Expression = exp.column("table_catalog").eq(table.catalog)
+            if include_temp and self.TEMP_CATALOG_FILTER:
+                catalog_filter = exp.Or(this=catalog_filter, expression=self.TEMP_CATALOG_FILTER)
+            select = select.where(catalog_filter)
         results = self.session._fetch_rows(select)
         return [
             Column(
diff --git a/sqlframe/base/mixins/dataframe_mixins.py b/sqlframe/base/mixins/dataframe_mixins.py
new file mode 100644
index 0000000..3b782e1
--- /dev/null
+++ b/sqlframe/base/mixins/dataframe_mixins.py
@@ -0,0 +1,63 @@
+import typing as t
+
+from sqlglot import exp
+
+from sqlframe.base.catalog import Column
+from sqlframe.base.dataframe import (
+    GROUP_DATA,
+    NA,
+    SESSION,
+    STAT,
+    WRITER,
+    _BaseDataFrame,
+)
+
+
+class PrintSchemaFromTempObjectsMixin(
+    _BaseDataFrame, t.Generic[SESSION, WRITER, NA, STAT, GROUP_DATA]
+):
+    def _get_columns_from_temp_object(self) -> t.List[Column]:
+        table = exp.to_table(self.session._random_id)
+        self.session._execute(
+            exp.Create(
+                this=table,
+                kind="VIEW",
+                replace=True,
+                properties=exp.Properties(expressions=[exp.TemporaryProperty()]),
+                expression=self.expression,
+            )
+        )
+        return self.session.catalog.listColumns(
+            table.sql(dialect=self.session.input_dialect), include_temp=True
+        )
+
+    def printSchema(self, level: t.Optional[int] = None) -> None:
+        def print_schema(
+            column_name: str, column_type: exp.DataType, nullable: bool, current_level: int
+        ):
+            if level and current_level >= level:
+                return
+            if current_level > 0:
+                print(" | " * current_level, end="")
+            print(
+                f" |-- {column_name}: {column_type.sql(self.session.output_dialect).lower()} (nullable = {str(nullable).lower()})"
+            )
+            if column_type.this == exp.DataType.Type.STRUCT:
+                for column_def in column_type.expressions:
+                    print_schema(column_def.name, column_def.args["kind"], True, current_level + 1)
+            if column_type.this == exp.DataType.Type.ARRAY:
+                for data_type in column_type.expressions:
+                    print_schema("element", data_type, True, current_level + 1)
+            if column_type.this == exp.DataType.Type.MAP:
+                print_schema("key", column_type.expressions[0], True, current_level + 1)
+                print_schema("value", column_type.expressions[1], True, current_level + 1)
+
+        columns = self._get_columns_from_temp_object()
+        print("root")
+        for column in columns:
+            print_schema(
+                column.name,
+                exp.DataType.build(column.dataType, dialect=self.session.output_dialect),
+                column.nullable,
+                0,
+            )
diff --git a/sqlframe/bigquery/catalog.py b/sqlframe/bigquery/catalog.py
index 8a9e30f..e996ad8 100644
--- a/sqlframe/bigquery/catalog.py
+++ b/sqlframe/bigquery/catalog.py
@@ -46,7 +46,9 @@ def currentDatabase(self) -> str:
         return to_schema(self.session.default_dataset).db
 
     @normalize(["tableName", "dbName"])
-    def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
+    def listColumns(
+        self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
+    ) -> t.List[Column]:
         """Returns a t.List of columns for the given table/view in the specified database.
 
         .. versionadded:: 2.0.0
diff --git a/sqlframe/duckdb/catalog.py b/sqlframe/duckdb/catalog.py
index 6fa3bc1..3a012cf 100644
--- a/sqlframe/duckdb/catalog.py
+++ b/sqlframe/duckdb/catalog.py
@@ -36,6 +36,8 @@ class DuckDBCatalog(
     ListColumnsFromInfoSchemaMixin["DuckDBSession", "DuckDBDataFrame"],
     _BaseCatalog["DuckDBSession", "DuckDBDataFrame"],
 ):
+    TEMP_CATALOG_FILTER = exp.column("table_catalog").eq("temp")
+
     def listFunctions(
         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
     ) -> t.List[Function]:
diff --git a/sqlframe/duckdb/dataframe.py b/sqlframe/duckdb/dataframe.py
index 8253f8a..a2d9193 100644
--- a/sqlframe/duckdb/dataframe.py
+++ b/sqlframe/duckdb/dataframe.py
@@ -9,6 +9,7 @@
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import PrintSchemaFromTempObjectsMixin
 from sqlframe.duckdb.group import DuckDBGroupedData
 
 if sys.version_info >= (3, 11):
@@ -34,13 +35,14 @@ class DuckDBDataFrameStatFunctions(_BaseDataFrameStatFunctions["DuckDBDataFrame"
 
 
 class DuckDBDataFrame(
+    PrintSchemaFromTempObjectsMixin,
     _BaseDataFrame[
         "DuckDBSession",
         "DuckDBDataFrameWriter",
         "DuckDBDataFrameNaFunctions",
         "DuckDBDataFrameStatFunctions",
         "DuckDBGroupedData",
-    ]
+    ],
 ):
     _na = DuckDBDataFrameNaFunctions
     _stat = DuckDBDataFrameStatFunctions
diff --git a/sqlframe/postgres/catalog.py b/sqlframe/postgres/catalog.py
index e599626..1ce8d8c 100644
--- a/sqlframe/postgres/catalog.py
+++ b/sqlframe/postgres/catalog.py
@@ -34,6 +34,7 @@ class PostgresCatalog(
     _BaseCatalog["PostgresSession", "PostgresDataFrame"],
 ):
     CURRENT_CATALOG_EXPRESSION: exp.Expression = exp.column("current_catalog")
+    TEMP_SCHEMA_FILTER = exp.column("table_schema").like("pg_temp_%")
 
     def listFunctions(
         self, dbName: t.Optional[str] = None, pattern: t.Optional[str] = None
diff --git a/sqlframe/postgres/dataframe.py b/sqlframe/postgres/dataframe.py
index 00e652d..a31058f 100644
--- a/sqlframe/postgres/dataframe.py
+++ b/sqlframe/postgres/dataframe.py
@@ -9,6 +9,7 @@
     _BaseDataFrameNaFunctions,
     _BaseDataFrameStatFunctions,
 )
+from sqlframe.base.mixins.dataframe_mixins import PrintSchemaFromTempObjectsMixin
 from sqlframe.postgres.group import PostgresGroupedData
 
 if sys.version_info >= (3, 11):
@@ -33,13 +34,14 @@ class PostgresDataFrameStatFunctions(_BaseDataFrameStatFunctions["PostgresDataFr
 
 
 class PostgresDataFrame(
+    PrintSchemaFromTempObjectsMixin,
     _BaseDataFrame[
         "PostgresSession",
         "PostgresDataFrameWriter",
         "PostgresDataFrameNaFunctions",
         "PostgresDataFrameStatFunctions",
         "PostgresGroupedData",
-    ]
+    ],
 ):
     _na = PostgresDataFrameNaFunctions
     _stat = PostgresDataFrameStatFunctions
diff --git a/sqlframe/spark/catalog.py b/sqlframe/spark/catalog.py
index 98b6d9c..6674e68 100644
--- a/sqlframe/spark/catalog.py
+++ b/sqlframe/spark/catalog.py
@@ -468,7 +468,9 @@ def listTables(
         )
         return [Table(*x) for x in self._spark_catalog.listTables(dbName, pattern)]
 
-    def listColumns(self, tableName: str, dbName: t.Optional[str] = None) -> t.List[Column]:
+    def listColumns(
+        self, tableName: str, dbName: t.Optional[str] = None, include_temp: bool = False
+    ) -> t.List[Column]:
         """Returns a t.List of columns for the given table/view in the specified database.
 
         .. versionadded:: 2.0.0
diff --git a/tests/integration/engines/duck/test_duckdb_dataframe.py b/tests/integration/engines/duck/test_duckdb_dataframe.py
new file mode 100644
index 0000000..2e789b8
--- /dev/null
+++ b/tests/integration/engines/duck/test_duckdb_dataframe.py
@@ -0,0 +1,79 @@
+import datetime
+
+from sqlframe.base.types import Row
+from sqlframe.duckdb import DuckDBDataFrame, DuckDBSession
+
+pytest_plugins = ["tests.integration.fixtures"]
+
+
+def test_print_schema_basic(duckdb_employee: DuckDBDataFrame, capsys):
+    duckdb_employee.printSchema()
+    captured = capsys.readouterr()
+    assert (
+        captured.out.strip()
+        == """
+root
+ |-- employee_id: int (nullable = true)
+ |-- fname: text (nullable = true)
+ |-- lname: text (nullable = true)
+ |-- age: int (nullable = true)
+ |-- store_id: int (nullable = true)""".strip()
+    )
+
+
+def test_print_schema_nested(duckdb_session: DuckDBSession, capsys):
+    df = duckdb_session.createDataFrame(
+        [
+            (
+                1,
+                2.0,
+                "foo",
+                {"a": 1},
+                [Row(a=1, b=2)],
+                [1, 2, 3],
+                Row(a=1),
+                datetime.date(2022, 1, 1),
+                datetime.datetime(2022, 1, 1, 0, 0, 0),
+                datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc),
+                True,
+            )
+        ],
+        [
+            "bigint_col",
+            "double_col",
+            "string_col",
+            "map_col",
+            "array<struct<a: bigint, b: bigint>>",
+            "array_col",
+            "struct_col",
+            "date_col",
+            "timestamp_col",
+            "timestamptz_col",
+            "boolean_col",
+        ],
+    )
+    df.printSchema()
+    captured = capsys.readouterr()
+    assert (
+        captured.out.strip()
+        == """
+root
+ |-- bigint_col: bigint (nullable = true)
+ |-- double_col: double (nullable = true)
+ |-- string_col: text (nullable = true)
+ |-- map_col: map(text, bigint) (nullable = true)
+ | |-- key: text (nullable = true)
+ | |-- value: bigint (nullable = true)
+ |-- array<struct<a: bigint, b: bigint>>: struct(a bigint, b bigint)[] (nullable = true)
+ | |-- element: struct(a bigint, b bigint) (nullable = true)
+ | | |-- a: bigint (nullable = true)
+ | | |-- b: bigint (nullable = true)
+ |-- array_col: bigint[] (nullable = true)
+ | |-- element: bigint (nullable = true)
+ |-- struct_col: struct(a bigint) (nullable = true)
+ | |-- a: bigint (nullable = true)
+ |-- date_col: date (nullable = true)
+ |-- timestamp_col: timestamp (nullable = true)
+ |-- timestamptz_col: timestamptz (nullable = true)
+ |-- boolean_col: boolean (nullable = true)""".strip()
+    )
diff --git a/tests/integration/engines/postgres/test_postgres_dataframe.py b/tests/integration/engines/postgres/test_postgres_dataframe.py
new file mode 100644
index 0000000..88973fc
--- /dev/null
+++ b/tests/integration/engines/postgres/test_postgres_dataframe.py
@@ -0,0 +1,64 @@
+import datetime
+
+from sqlframe.base.types import Row
+from sqlframe.postgres import PostgresDataFrame, PostgresSession
+
+pytest_plugins = ["tests.integration.fixtures"]
+
+
+def test_print_schema_basic(postgres_employee: PostgresDataFrame, capsys):
+    postgres_employee.printSchema()
+    captured = capsys.readouterr()
+    assert (
+        captured.out.strip()
+        == """
+root
+ |-- employee_id: int (nullable = true)
+ |-- fname: text (nullable = true)
+ |-- lname: text (nullable = true)
+ |-- age: int (nullable = true)
+ |-- store_id: int (nullable = true)""".strip()
+    )
+
+
+def test_print_schema_nested(postgres_session: PostgresSession, capsys):
+    df = postgres_session.createDataFrame(
+        [
+            (
+                1,
+                2.0,
+                "foo",
+                [1, 2, 3],
+                datetime.date(2022, 1, 1),
+                datetime.datetime(2022, 1, 1, 0, 0, 0),
+                datetime.datetime(2022, 1, 1, 0, 0, 0, tzinfo=datetime.timezone.utc),
+                True,
+            )
+        ],
+        [
+            "bigint_col",
+            "double_col",
+            "string_col",
+            "array_col",
"date_col", + "timestamp_col", + "timestamptz_col", + "boolean_col", + ], + ) + df.printSchema() + captured = capsys.readouterr() + # array does not include type + assert ( + captured.out.strip() + == """ +root + |-- bigint_col: bigint (nullable = true) + |-- double_col: double precision (nullable = true) + |-- string_col: text (nullable = true) + |-- array_col: array (nullable = true) + |-- date_col: date (nullable = true) + |-- timestamp_col: timestamp (nullable = true) + |-- timestamptz_col: timestamptz (nullable = true) + |-- boolean_col: boolean (nullable = true)""".strip() + ) diff --git a/tests/integration/engines/test_int_functions.py b/tests/integration/engines/test_int_functions.py index 2219f58..50fd400 100644 --- a/tests/integration/engines/test_int_functions.py +++ b/tests/integration/engines/test_int_functions.py @@ -6,7 +6,6 @@ from collections import Counter import pytest -from pyspark.sql import DataFrame from pyspark.sql import SparkSession as PySparkSession from sqlglot import exp @@ -175,7 +174,6 @@ def test_col(get_session_and_func, arg): ) def test_typeof(get_session_and_func, get_types, arg, expected): session, typeof = get_session_and_func("typeof") - types = get_types(session) # If we just pass a struct in for values then Spark will automatically explode the struct into columns # it won't do this though if there is another column so that is why we include an ignore column df = session.createDataFrame([(1, arg)], schema=["ignore_col", "col"])