Skip to content

Commit

Permalink
use udf
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-aalam committed Dec 14, 2024
1 parent 68aefef commit e3cc281
Showing 1 changed file with 69 additions and 35 deletions.
104 changes: 69 additions & 35 deletions src/snowflake/snowpark/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import snowflake.snowpark
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark.functions import lit, parse_json
from snowflake.snowpark.types import DataType


Expand All @@ -30,6 +31,7 @@ class Catalog:
# Bind this catalog to a session and prepare the state its methods rely on.
def __init__(self, session: "snowflake.snowpark.session.Session") -> None:
# Session used by the catalog to issue queries and register UDFs.
self._session = session
# Root handle constructed from the same session (Root is defined outside this view).
self._root = Root(session)
# Regex-filter UDF; stays None until _initialize_regex_udf registers it on first use.
self._python_regex_udf = None

def _parse_database(
self,
Expand Down Expand Up @@ -93,6 +95,46 @@ def _parse_function_or_procedure(
arg_types_str = ", ".join(arg.datatype for arg in fn.arguments)
return f"{fn.name}({arg_types_str})"

def _initialize_regex_udf(self) -> None:
    """Register the regex-matching UDF once and cache it on this catalog.

    The registered handler takes ``(pattern, value)`` and reports whether
    ``re.match`` finds ``pattern`` at the start of ``value``. The object
    returned by ``self._session.udf.register`` is stored in
    ``self._python_regex_udf``; if that attribute is already set, this
    method returns without registering anything.
    """
    # Both the "already registered?" check and the assignment happen while
    # holding the session lock.
    with self._session._lock:
        if self._python_regex_udf is not None:
            return

        def python_regex_filter(pattern: str, value: str) -> bool:
            # A SQL NULL reaches the handler as None, and re.match raises
            # TypeError on None; treat a missing operand as "no match"
            # rather than failing the whole query.
            if pattern is None or value is None:
                return False
            return re.match(pattern, value) is not None

        self._python_regex_udf = self._session.udf.register(python_regex_filter)

def _list_objects(
    self,
    *,
    object_name: str,
    object_class,
    database: Optional[Union[str, Database]],
    schema: Optional[Union[str, Schema]],
    pattern: Optional[str],
):
    """List objects of one kind in a schema via ``SHOW AS RESOURCE``.

    Args:
        object_name: Keyword(s) placed after ``SHOW AS RESOURCE`` in the
            query text (callers in this file pass e.g. ``"TABLES"``).
        object_class: Class exposing ``from_json``; it is called with the
            first column of every collected row.
        database: Database name or ``Database`` object, resolved through
            ``_parse_database``.
        schema: Schema name or ``Schema`` object, resolved through
            ``_parse_schema``.
        pattern: Optional regular expression. When truthy, rows are kept
            only if the regex UDF accepts their ``name`` field.

    Returns:
        A list holding one ``object_class.from_json(...)`` result per row.
    """
    resolved_db = self._parse_database(database)
    resolved_schema = self._parse_schema(schema)
    show_query = (
        f"SHOW AS RESOURCE {object_name} IN {resolved_db}.{resolved_schema} -- catalog api"
    )
    resources = self._session.sql(show_query)

    if pattern:
        # The UDF is registered lazily, only when a pattern is supplied.
        self._initialize_regex_udf()

        # Each row of the SHOW AS RESOURCE result is a JSON string whose
        # 'name' key holds the object's name; parse it and filter on that.
        pattern_literal = lit(pattern)
        name_field = parse_json('"As Resource"')["name"]
        resources = resources.filter(
            self._python_regex_udf(pattern_literal, name_field)
        )

    return [object_class.from_json(record[0]) for record in resources.collect()]

# List methods
def list_databases(
self,
Expand Down Expand Up @@ -144,14 +186,13 @@ def list_tables(
schema: schema name or ``Schema`` object. Defaults to None.
pattern: the pattern of name to match. Defaults to None.
"""
db_name = self._parse_database(database)
schema_name = self._parse_schema(schema)

iter = self._root.databases[db_name].schemas[schema_name].tables.iter()
if pattern:
iter = filter(lambda x: re.match(pattern, x.name), iter)

return list(iter)
return self._list_objects(
object_name="TABLES",
object_class=Table,
database=database,
schema=schema,
pattern=pattern,
)

def list_views(
self,
Expand All @@ -168,14 +209,13 @@ def list_views(
schema: schema name or ``Schema`` object. Defaults to None.
pattern: the pattern of name to match. Defaults to None.
"""
db_name = self._parse_database(database)
schema_name = self._parse_schema(schema)

iter = self._root.databases[db_name].schemas[schema_name].views.iter()
if pattern:
iter = filter(lambda x: re.match(pattern, x.name), iter)

return list(iter)
return self._list_objects(
object_name="VIEWS",
object_class=View,
database=database,
schema=schema,
pattern=pattern,
)

def list_columns(
self,
Expand Down Expand Up @@ -212,14 +252,13 @@ def list_procedures(
schema: schema name or ``Schema`` object. Defaults to None.
pattern: the pattern of name to match. Defaults to None.
"""
db_name = self._parse_database(database)
schema_name = self._parse_schema(schema)

iter = self._root.databases[db_name].schemas[schema_name].procedures.iter()
if pattern:
iter = filter(lambda x: re.match(pattern, x.name), iter)

return list(iter)
return self._list_objects(
object_name="PROCEDURES",
object_class=Procedure,
database=database,
schema=schema,
pattern=pattern,
)

def list_user_defined_functions(
self,
Expand All @@ -235,18 +274,13 @@ def list_user_defined_functions(
schema: schema name or ``Schema`` object. Defaults to None.
pattern: the pattern of name to match. Defaults to None.
"""
db_name = self._parse_database(database)
schema_name = self._parse_schema(schema)

iter = (
self._root.databases[db_name]
.schemas[schema_name]
.user_defined_functions.iter()
return self._list_objects(
object_name="USER FUNCTIONS",
object_class=UserDefinedFunction,
database=database,
schema=schema,
pattern=pattern,
)
if pattern:
iter = filter(lambda x: re.match(pattern, x.name), iter)

return list(iter)

# get methods
def get_current_database(self) -> Database:
Expand Down

0 comments on commit e3cc281

Please sign in to comment.