Commit

add action workflow (#1)

eakmanrq authored May 18, 2024
1 parent f83590e commit 8483236
Showing 19 changed files with 107 additions and 45 deletions.
32 changes: 32 additions & 0 deletions .github/workflows/main.workflow.yaml
@@ -0,0 +1,32 @@
name: SQLFrame
on:
  push:
    branches:
      - main
  pull_request:
    types:
      - synchronize
      - opened
jobs:
  run-tests:
    runs-on: ubuntu-latest
    env:
      PYTEST_XDIST_AUTO_NUM_WORKERS: 4
    strategy:
      matrix:
        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
    steps:
      - name: Checkout
        uses: actions/checkout@v2
      - name: Install Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: make install-dev
      - name: Run Style
        run: make style
      - name: Setup Postgres
        uses: ikalnytskyi/action-setup-postgres@v6
      - name: Run tests
        run: make local-test
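The job exports PYTEST_XDIST_AUTO_NUM_WORKERS=4, but this diff does not show how `make local-test` consumes it. A plausible wiring, purely an assumption here, is pytest-xdist's `pytest_xdist_auto_num_workers` hook in a conftest.py, which is consulted when tests run with `-n auto`:

```python
# Hypothetical conftest.py hook (not part of this commit): pytest-xdist calls it
# when "-n auto" is used, so the workflow's env var can cap the worker count.
import os


def pytest_xdist_auto_num_workers(config):
    value = os.environ.get("PYTEST_XDIST_AUTO_NUM_WORKERS")
    # Returning None lets pytest-xdist fall back to its default (one worker per CPU).
    return int(value) if value else None
```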
2 changes: 1 addition & 1 deletion Makefile
@@ -1,5 +1,5 @@
install-dev:
pip install -e ".[dev]"
pip install -e ".[dev,duckdb,postgres,redshift,bigquery,snowflake,spark]"

install-pre-commit:
pre-commit install
17 changes: 12 additions & 5 deletions setup.py
@@ -26,16 +26,13 @@
"bigquery": [
"google-cloud-bigquery[pandas]",
"google-cloud-bigquery-storage",
"pandas",
],
"dev": [
"duckdb",
"mkdocs==1.4.2",
"mkdocs-include-markdown-plugin==4.0.3",
"mkdocs-material==9.0.5",
"mkdocs-material-extensions==1.1.1",
"mypy",
"pandas",
"pymdown-extensions",
"pandas-stubs",
"psycopg",
"pyarrow",
"pyspark",
@@ -47,17 +44,27 @@
"typing_extensions",
"types-psycopg2",
],
"docs": [
"mkdocs==1.4.2",
"mkdocs-include-markdown-plugin==4.0.3",
"mkdocs-material==9.0.5",
"mkdocs-material-extensions==1.1.1",
"pymdown-extensions",
],
"duckdb": [
"duckdb",
"pandas",
],
"postgres": [
"pandas",
"psycopg2",
],
"redshift": [
"pandas",
"redshift_connector",
],
"snowflake": [
"pandas",
"snowflake-connector-python[pandas,secure-local-storage]",
],
"spark": [
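With the extras reorganized, docs tooling moves into a dedicated `docs` extra and each engine extra carries its own `pandas` requirement, so a single-engine install (e.g. `pip install "sqlframe[duckdb]"`) is self-sufficient while the Makefile's `install-dev` pulls in every engine at once. A minimal usage sketch, assuming the duckdb extra is installed; the import path follows the `sqlframe.<engine>.session` convention seen in the tests, and `createDataFrame` is assumed to mirror the PySpark API:

```python
import duckdb

from sqlframe.duckdb.session import DuckDBSession  # assumed module path

connection = duckdb.connect()             # in-memory DuckDB database
session = DuckDBSession(conn=connection)  # same constructor the duckdb_session fixture uses
df = session.createDataFrame([(1, "Jack"), (2, "Jill")], ["id", "name"])
print(df.sql(pretty=False))               # sql() is the rendering call used by the DuckDB writer
```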
4 changes: 2 additions & 2 deletions sqlframe/base/dataframe.py
@@ -662,7 +662,7 @@ def crossJoin(self, other: DF) -> Self:
| 16| Bob| 85|
+---+-----+------+
"""
return self.join.__wrapped__(self, other, how="cross")
return self.join.__wrapped__(self, other, how="cross") # type: ignore

@operation(Operation.FROM)
def join(
@@ -769,7 +769,7 @@ def join(
new_df = self.copy(expression=join_expression)
new_df.pending_join_hints.extend(self.pending_join_hints)
new_df.pending_hints.extend(other_df.pending_hints)
new_df = new_df.select.__wrapped__(new_df, *select_column_names)
new_df = new_df.select.__wrapped__(new_df, *select_column_names) # type: ignore
return new_df

@operation(Operation.ORDER_BY)
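Both changed lines call the undecorated method through `__wrapped__`, which the `operation` decorator (sqlframe/base/operations.py, further down) stores on its wrapper; the new `# type: ignore` comments are only there because mypy cannot see that attribute on a bound method. A generic illustration of the pattern, not SQLFrame's actual decorator:

```python
import functools


def tracked(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        self.calls.append(func.__name__)  # bookkeeping that external callers get
        return func(self, *args, **kwargs)

    wrapper.__wrapped__ = func  # raw function kept around for internal re-entry
    return wrapper


class Frame:
    def __init__(self):
        self.calls = []

    @tracked
    def select(self, *cols):
        return list(cols)

    @tracked
    def cross_join(self, other):
        # Internal call: reuse select's logic without recording a second operation.
        return self.select.__wrapped__(self, "left.*", "right.*")


f = Frame()
f.cross_join(Frame())
print(f.calls)  # ['cross_join'] -- the nested select call was not tracked
```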
12 changes: 6 additions & 6 deletions sqlframe/base/decorators.py
@@ -11,7 +11,7 @@
from sqlframe.base.catalog import _BaseCatalog


def normalize(normalize_kwargs: t.List[str]):
def normalize(normalize_kwargs: t.List[str]) -> t.Callable[[t.Callable], t.Callable]:
"""
Decorator used around DataFrame methods to indicate what type of operation is being performed from the
ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -23,9 +23,9 @@ def decorator(func: t.Callable):
in cases where there is overlap in names.
"""

def decorator(func: t.Callable):
def decorator(func: t.Callable) -> t.Callable:
@functools.wraps(func)
def wrapper(self: _BaseCatalog, *args, **kwargs):
def wrapper(self: _BaseCatalog, *args, **kwargs) -> _BaseCatalog:
kwargs.update(dict(zip(func.__code__.co_varnames[1:], args)))
for kwarg in normalize_kwargs:
if kwarg in kwargs:
@@ -43,9 +43,9 @@ def wrapper(self: _BaseCatalog, *args, **kwargs):
return decorator


def func_metadata(unsupported_engines: t.Optional[t.Union[str, t.List[str]]] = None):
def _metadata(func):
func.unsupported_engines = ensure_list(unsupported_engines) if unsupported_engines else []
def func_metadata(unsupported_engines: t.Optional[t.Union[str, t.List[str]]] = None) -> t.Callable:
def _metadata(func: t.Callable) -> t.Callable:
func.unsupported_engines = ensure_list(unsupported_engines) if unsupported_engines else [] # type: ignore
return func

return _metadata
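The annotations added here spell out that `normalize` is a decorator factory; the behaviour is unchanged: the wrapper folds positional arguments into keyword arguments via `dict(zip(func.__code__.co_varnames[1:], args))` so the names listed in `normalize_kwargs` can be normalized however the caller passed them. A small sketch of that folding step on a stand-in method (`getTable` and its parameter names are illustrative, not the real catalog signature):

```python
import typing as t


def fold_args(func: t.Callable, args: tuple, kwargs: dict) -> dict:
    folded = dict(kwargs)
    # co_varnames begins with the parameter names; [1:] skips `self`.
    folded.update(dict(zip(func.__code__.co_varnames[1:], args)))
    return folded


class Catalog:
    def getTable(self, dbName, tableName):
        ...


print(fold_args(Catalog.getTable, ("db1", "t1"), {}))
# {'dbName': 'db1', 'tableName': 't1'} -- now 'dbName' can be looked up by name
```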
3 changes: 3 additions & 0 deletions sqlframe/base/mixins/readwriter_mixins.py
@@ -108,6 +108,9 @@ def _write(self, path: str, mode: t.Optional[str], format: str, **options): # t
raise NotImplementedError("Append mode is not supported for parquet.")
pandas_df.to_parquet(path, **kwargs)
elif format == "json":
# Pandas versions are inconsistent on how to handle True/False index so we just remove it
# since in all versions it will not result in an index column in the output.
del kwargs["index"]
kwargs["mode"] = mode
kwargs["orient"] = "records"
pandas_df.to_json(path, lines=True, **kwargs)
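The new JSON branch deletes the `index` kwarg because pandas releases disagree on whether `to_json` accepts `index=` with `orient="records"`, and records-oriented, line-delimited output never contains an index column anyway. A small sketch of the resulting write path, assuming a plain pandas DataFrame and an illustrative output path:

```python
import pandas as pd

pandas_df = pd.DataFrame({"id": [1, 2], "name": ["Jack", "Jill"]})
# One JSON object per row, one row per line, no index column in the output.
pandas_df.to_json("people.jsonl", orient="records", lines=True)
# people.jsonl:
# {"id":1,"name":"Jack"}
# {"id":2,"name":"Jill"}
```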
14 changes: 7 additions & 7 deletions sqlframe/base/operations.py
@@ -23,7 +23,7 @@ class Operation(IntEnum):
LIMIT = 7


def operation(op: Operation):
def operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
"""
Decorator used around DataFrame methods to indicate what type of operation is being performed from the
ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -35,9 +35,9 @@ def operation(op: Operation):
in cases where there is overlap in names.
"""

def decorator(func: t.Callable):
def decorator(func: t.Callable) -> t.Callable:
@functools.wraps(func)
def wrapper(self: _BaseDataFrame, *args, **kwargs):
def wrapper(self: _BaseDataFrame, *args, **kwargs) -> _BaseDataFrame:
if self.last_op == Operation.INIT:
self = self._convert_leaf_to_cte()
self.last_op = Operation.NO_OP
@@ -47,15 +47,15 @@ def wrapper(self: _BaseDataFrame, *args, **kwargs):
self = self._convert_leaf_to_cte()
df: t.Union[_BaseDataFrame, _BaseGroupedData] = func(self, *args, **kwargs)
df.last_op = new_op # type: ignore
return df
return df # type: ignore

wrapper.__wrapped__ = func # type: ignore
return wrapper

return decorator


def group_operation(op: Operation):
def group_operation(op: Operation) -> t.Callable[[t.Callable], t.Callable]:
"""
Decorator used around DataFrame methods to indicate what type of operation is being performed from the
ordered Operation enums. This is used to determine which operations should be performed on a CTE vs.
@@ -67,9 +67,9 @@ def group_operation(op: Operation):
in cases where there is overlap in names.
"""

def decorator(func: t.Callable):
def decorator(func: t.Callable) -> t.Callable:
@functools.wraps(func)
def wrapper(self: _BaseGroupedData, *args, **kwargs):
def wrapper(self: _BaseGroupedData, *args, **kwargs) -> _BaseDataFrame:
if self._df.last_op == Operation.INIT:
self._df = self._df._convert_leaf_to_cte()
self._df.last_op = Operation.NO_OP
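These decorators now declare their callable-in/callable-out shape and the wrappers are annotated to return `_BaseDataFrame`. The middle of the `wrapper` hunk is collapsed here, but the ordering logic rests on `Operation` being an `IntEnum`: members compare as integers, which is how the decorator detects an operation arriving earlier than the previous one and first converts the current DataFrame into a CTE. A rough illustration; only `LIMIT = 7` is visible in this diff, so the other member values are assumptions:

```python
from enum import IntEnum


class Operation(IntEnum):
    INIT = -1      # assumed values, except LIMIT = 7 which appears in the diff
    NO_OP = 0
    FROM = 1
    WHERE = 2
    GROUP_BY = 3
    HAVING = 4
    SELECT = 5
    ORDER_BY = 6
    LIMIT = 7


print(Operation.WHERE < Operation.SELECT)  # True -- plain integer comparison
print(int(Operation.LIMIT))                # 7
```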
6 changes: 3 additions & 3 deletions sqlframe/base/session.py
@@ -11,9 +11,9 @@
from functools import cached_property

import sqlglot
from more_itertools import take
from sqlglot import Dialect, exp
from sqlglot.expressions import parse_identifier
from sqlglot.helper import seq_get
from sqlglot.optimizer import optimize
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
@@ -211,10 +211,10 @@ def get_default_data_type(value: t.Any) -> t.Optional[str]:
row_types.append((row_name, default_type))
return "struct<" + ", ".join(f"{k}: {v}" for (k, v) in row_types) + ">"
elif isinstance(value, dict):
sample_row = take(1, value.items())
sample_row = seq_get(list(value.items()), 0)
if not sample_row:
return None
key, value = sample_row[0]
key, value = sample_row
default_key = get_default_data_type(key)
default_value = get_default_data_type(value)
if not default_key or not default_value:
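Swapping `more_itertools.take(1, ...)` for sqlglot's `seq_get` drops the extra dependency import and simplifies the unpacking: `take` returns a one-element list, while `seq_get` returns the element itself, or `None` when the sequence is empty, hence `sample_row[0]` becoming `sample_row`. A quick sketch of the difference:

```python
from sqlglot.helper import seq_get

value = {"a": 1}
print(seq_get(list(value.items()), 0))  # ('a', 1) -- the item itself, not a list
print(seq_get([], 0))                   # None -- lets the empty-dict case return None
```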
4 changes: 3 additions & 1 deletion sqlframe/base/transforms.py
@@ -5,7 +5,9 @@
from sqlglot import expressions as exp


def replace_id_value(node, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]):
def replace_id_value(
node: exp.Expression, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]
) -> exp.Expression:
if isinstance(node, exp.Identifier) and node in replacement_mapping:
node = node.replace(replacement_mapping[node].copy())
return node
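`replace_id_value` is a per-node callback; sqlglot's `Expression.transform` forwards extra positional arguments to the callback, which is how the replacement mapping rides along with each visited identifier. A hedged usage sketch: the query and mapping are invented, and the expected output assumes sqlglot's default unquoted rendering:

```python
import typing as t

from sqlglot import exp, parse_one


def replace_id_value(
    node: exp.Expression, replacement_mapping: t.Dict[exp.Identifier, exp.Identifier]
) -> exp.Expression:
    if isinstance(node, exp.Identifier) and node in replacement_mapping:
        node = node.replace(replacement_mapping[node].copy())
    return node


mapping = {exp.to_identifier("a"): exp.to_identifier("employee_age")}
tree = parse_one("SELECT a FROM t")
print(tree.transform(replace_id_value, mapping).sql())  # expected: SELECT employee_age FROM t
```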
7 changes: 4 additions & 3 deletions sqlframe/base/util.py
@@ -9,6 +9,7 @@
from sqlglot.schema import ensure_column_mapping as sqlglot_ensure_column_mapping

if t.TYPE_CHECKING:
from pandas.core.frame import DataFrame as PandasDataFrame
from pyspark.sql.dataframe import SparkSession as PySparkSession

from sqlframe.base import types
@@ -97,7 +98,7 @@ def get_tables_from_expression_with_join(expression: exp.Select) -> t.List[exp.T
return [left_table] + other_tables


def to_csv(options: t.Dict[str, OptionalPrimitiveType], equality_char: str = "="):
def to_csv(options: t.Dict[str, OptionalPrimitiveType], equality_char: str = "=") -> str:
return ", ".join(
[f"{k}{equality_char}{v}" for k, v in (options or {}).items() if v is not None]
)
@@ -116,7 +117,7 @@ def ensure_column_mapping(schema: t.Union[str, StructType]) -> t.Dict:


# SO: https://stackoverflow.com/questions/37513355/converting-pandas-dataframe-into-spark-dataframe-error
def get_equivalent_spark_type(pandas_type):
def get_equivalent_spark_type(pandas_type) -> types.DataType:
"""
This method will retrieve the corresponding spark type given a pandas
type.
@@ -139,7 +140,7 @@ def get_equivalent_spark_type(pandas_type):
return type_map.get(str(pandas_type).lower(), types.StringType())


def pandas_to_spark_schema(pandas_df):
def pandas_to_spark_schema(pandas_df: PandasDataFrame) -> types.StructType:
"""
This method will return a spark dataframe schema given a pandas dataframe.
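`get_equivalent_spark_type` and `pandas_to_spark_schema` gain annotations (with a TYPE_CHECKING-only pandas import); unknown pandas dtypes still fall back to `StringType`. A short usage sketch; the exact field types in the printed schema depend on sqlframe's internal `type_map`, so the trailing comment is indicative only:

```python
import pandas as pd

from sqlframe.base.util import pandas_to_spark_schema

pandas_df = pd.DataFrame({"age": [16, 18], "name": ["Bob", "Amy"]})
schema = pandas_to_spark_schema(pandas_df)
print(schema)
# e.g. a StructType with an integer-typed 'age' field and a string-typed 'name'
# field (the object dtype of 'name' hits the StringType fallback).
```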
5 changes: 3 additions & 2 deletions sqlframe/bigquery/catalog.py
@@ -3,7 +3,6 @@
import fnmatch
import typing as t

from google.cloud.bigquery import StandardSqlDataType
from sqlglot import exp

from sqlframe.base.catalog import CatalogMetadata, Column, Function
@@ -16,8 +15,10 @@
from sqlframe.base.util import schema_, to_schema

if t.TYPE_CHECKING:
from sqlframe.bigquery.session import BigQuerySession # noqa
from google.cloud.bigquery import StandardSqlDataType

from sqlframe.bigquery.dataframe import BigQueryDataFrame # noqa
from sqlframe.bigquery.session import BigQuerySession # noqa


class BigQueryCatalog(
5 changes: 3 additions & 2 deletions sqlframe/bigquery/session.py
@@ -11,9 +11,10 @@
)

if t.TYPE_CHECKING:
from google.cloud import bigquery
from google.cloud.bigquery.client import Client as BigQueryClient
from google.cloud.bigquery.dbapi.connection import Connection as BigQueryConnection
else:
BigQueryClient = t.Any
BigQueryConnection = t.Any


@@ -48,7 +49,7 @@ def __init__(
self.default_dataset = default_dataset

@property
def _client(self) -> bigquery.client.Client:
def _client(self) -> BigQueryClient:
assert self._connection
return self._connection._client

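Both BigQuery modules now import the google-cloud types only under `t.TYPE_CHECKING`, with `t.Any` fallbacks at runtime, so annotating `_client` no longer requires the BigQuery client library to be installed just to import the module. A generic illustration of the pattern (not the actual sqlframe module):

```python
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
    # Only evaluated by type checkers; never imported at runtime.
    from google.cloud.bigquery.client import Client as BigQueryClient
else:
    BigQueryClient = t.Any  # runtime placeholder in case annotations are ever evaluated


class Session:
    def __init__(self, client: BigQueryClient) -> None:
        self._client = client

    @property
    def client(self) -> BigQueryClient:
        return self._client
```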
2 changes: 1 addition & 1 deletion sqlframe/duckdb/readwriter.py
@@ -87,7 +87,7 @@ def _write(self, path: str, mode: t.Optional[str], **options): # type: ignore
return
if mode == "append":
raise NotImplementedError("Append mode not supported")
options = to_csv(options, equality_char=" ")
options = to_csv(options, equality_char=" ") # type: ignore
sqls = self._df.sql(pretty=False, optimize=False, as_list=True)
for i, sql in enumerate(sqls):
if i < len(sqls) - 1:
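The DuckDB writer formats its options with `to_csv(options, equality_char=" ")`, the helper shown earlier in sqlframe/base/util.py, producing space-separated `key value` pairs instead of `key=value`; the `# type: ignore` is needed because the string result is reassigned over the dict-typed `options` parameter. What the helper produces:

```python
from sqlframe.base.util import to_csv

options = {"HEADER": True, "DELIMITER": "'|'", "COMPRESSION": None}
print(to_csv(options))                     # HEADER=True, DELIMITER='|'   (None values are skipped)
print(to_csv(options, equality_char=" "))  # HEADER True, DELIMITER '|'
```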
23 changes: 13 additions & 10 deletions tests/common_fixtures.py
@@ -19,7 +19,9 @@
from sqlframe.standalone.session import StandaloneSession

if t.TYPE_CHECKING:
from google.cloud.bigquery.dbapi.connection import Connection as BigQueryConnection
from google.cloud.bigquery.dbapi.connection import (
Connection as BigQueryConnection,
)
from redshift_connector.core import Connection as RedshiftConnection
from snowflake.connector import SnowflakeConnection

@@ -36,6 +38,7 @@ def pyspark_session(tmp_path_factory) -> PySparkSession:
.config("spark.sql.warehouse.dir", data_dir)
.config("spark.driver.extraJavaOptions", f"-Dderby.system.home={derby_dir}")
.config("spark.sql.shuffle.partitions", 1)
.config("spark.sql.session.timeZone", "America/Los_Angeles")
.master("local[1]")
.appName("Unit-tests")
.getOrCreate()
@@ -60,11 +63,11 @@ def spark_session(pyspark_session: PySparkSession) -> SparkSession:


@pytest.fixture(scope="function")
def duckdb_session(monkeypatch: pytest.MonkeyPatch) -> DuckDBSession:
import duckdb
def duckdb_session() -> DuckDBSession:
from duckdb import connect

# https://github.com/duckdb/duckdb/issues/11404
connection = duckdb.connect()
connection = connect()
connection.sql("set TimeZone = 'UTC'")
return DuckDBSession(conn=connection)

@@ -74,12 +77,12 @@ def function_scoped_postgres(postgresql_proc):
import psycopg2

janitor = DatabaseJanitor(
postgresql_proc.user,
postgresql_proc.host,
postgresql_proc.port,
postgresql_proc.dbname,
postgresql_proc.version,
postgresql_proc.password,
user=postgresql_proc.user,
host=postgresql_proc.host,
port=postgresql_proc.port,
dbname=postgresql_proc.dbname,
version=postgresql_proc.version,
password=postgresql_proc.password,
)
with janitor:
conn = psycopg2.connect(
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -1,8 +1,20 @@
from __future__ import annotations

import time

import pytest


@pytest.fixture(scope="session", autouse=True)
def set_tz():
import os

os.environ["TZ"] = "US/Pacific"
time.tzset()
yield
del os.environ["TZ"]


@pytest.fixture(scope="function", autouse=True)
def rescope_sparksession_singleton():
from sqlframe.base.session import _BaseSession
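The new session-scoped, autouse fixture pins the process time zone through the `TZ` environment variable plus `time.tzset()` (Unix-only), complementing the Spark session's `America/Los_Angeles` setting so local-time conversions in the tests are deterministic. A quick illustration of what `tzset` does:

```python
import os
import time

os.environ["TZ"] = "US/Pacific"
time.tzset()                 # Unix-only: re-reads TZ for subsequent conversions
print(time.strftime("%Z"))   # PST or PDT depending on the current date
```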
File renamed without changes.