Skip to content

Commit

Permalink
Merge pull request #56 from cal-itp/params-magic
Browse files Browse the repository at this point in the history
probably very hacky way to add papermill parameters at runtime
  • Loading branch information
atvaccaro authored Apr 14, 2022
2 parents 2edce07 + 65b13a5 commit c23c08d
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length=88
max-line-length=120
ignore=E203, W503
2 changes: 1 addition & 1 deletion calitp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# flake8: noqa

__version__ = "0.0.15"
__version__ = "0.0.16"

from .sql import get_table, write_table, query_sql, to_snakecase, get_engine
from .storage import save_to_gcfs, read_gcfs
4 changes: 1 addition & 3 deletions calitp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def is_development():
# if a person can write data, then they need to set AIRFLOW_ENV
if is_pipeline():
if "AIRFLOW_ENV" not in os.environ:
raise KeyError(
"Pipeline admin must set AIRFLOW_ENV env variable explicitly"
)
raise KeyError("Pipeline admin must set AIRFLOW_ENV env variable explicitly")

env = os.environ["AIRFLOW_ENV"]

Expand Down
15 changes: 12 additions & 3 deletions calitp/magics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from IPython.core.magic import register_cell_magic
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring
from IPython.display import Markdown, display
Expand All @@ -13,9 +14,7 @@
help="Print the code to markdown, in addition to running",
)
@argument("-o", "--output", type=str, help="A variable name to save the result as")
@argument(
"-q", "--quiet", action="store_true", help="Whether to hide the result printout"
)
@argument("-q", "--quiet", action="store_true", help="Whether to hide the result printout")
@register_cell_magic
def sql(line, cell):
# %%sql -m
Expand All @@ -35,3 +34,13 @@ def sql(line, cell):

if not args.quiet:
return res


@register_cell_magic
def capture_parameters(line, cell):
shell = get_ipython()
shell.run_cell(cell, silent=True)
# We assume the last line is a tuple
tup = [s.strip() for s in cell.strip().split("\n")[-1].split(",")]

print(json.dumps({identifier: shell.user_ns[identifier] for identifier in tup if identifier}))
16 changes: 4 additions & 12 deletions calitp/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@


class CreateTableAs(Executable, ClauseElement):
def __init__(
self, name, select, replace=False, if_not_exists=False, partition_by=None
):
def __init__(self, name, select, replace=False, if_not_exists=False, partition_by=None):
self.name = name
self.select = select
self.replace = replace
Expand All @@ -34,9 +32,7 @@ def visit_insert_from_select(element, compiler, **kw):
if_not_exists = " IF NOT EXISTS" if element.if_not_exists else ""

# TODO: visit partition by clause
partition_by = (
f" PARTITION BY {element.partition_by}" if element.partition_by else ""
)
partition_by = f" PARTITION BY {element.partition_by}" if element.partition_by else ""

return f"""
CREATE{or_replace} TABLE{if_not_exists} {name}
Expand Down Expand Up @@ -113,9 +109,7 @@ def write_table(
@require_pipeline("write_table")
def _write_table_df(sql_stmt, table_name, engine=None, replace=True):
if_exists = "replace" if replace else "fail"
return sql_stmt.to_gbq(
format_table_name(table_name), project_id=get_project_id(), if_exists=if_exists
)
return sql_stmt.to_gbq(format_table_name(table_name), project_id=get_project_id(), if_exists=if_exists)


def query_sql(fname, write_as=None, replace=False, dry_run=False, as_df=True):
Expand Down Expand Up @@ -176,9 +170,7 @@ def sql_patch_comments(table_name, field_comments, table_comments=None, bq_clien
if bq_client is None:
from google.cloud import bigquery

bq_client = bigquery.Client(
project=get_project_id(), location=CALITP_BQ_LOCATION
)
bq_client = bigquery.Client(project=get_project_id(), location=CALITP_BQ_LOCATION)

tbl = bq_client.get_table(table_name)
old_schema = tbl.schema
Expand Down
4 changes: 1 addition & 3 deletions calitp/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ def get_fs(gcs_project="", **kwargs):
if is_cloud():
return gcsfs.GCSFileSystem(project=gcs_project, token="cloud", **kwargs)
else:
return gcsfs.GCSFileSystem(
project=gcs_project, token="google_default", **kwargs
)
return gcsfs.GCSFileSystem(project=gcs_project, token="google_default", **kwargs)


@require_pipeline("save_to_gcfs")
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.black]
line-length = 120
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
_version_re = re.compile(r"__version__\s+=\s+(.*)")

with open("calitp/__init__.py", "rb") as f:
version = str(
ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1))
)
version = str(ast.literal_eval(_version_re.search(f.read().decode("utf-8")).group(1)))

setup(
name="calitp",
Expand Down

0 comments on commit c23c08d

Please sign in to comment.