From e3cc281d5b8494ef325b90fd60999a4706219f27 Mon Sep 17 00:00:00 2001 From: Afroz Alam Date: Fri, 13 Dec 2024 17:06:20 -0800 Subject: [PATCH] use udf --- src/snowflake/snowpark/catalog.py | 104 ++++++++++++++++++++---------- 1 file changed, 69 insertions(+), 35 deletions(-) diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py index 654c8bbf8b6..d0b9c226556 100644 --- a/src/snowflake/snowpark/catalog.py +++ b/src/snowflake/snowpark/catalog.py @@ -18,6 +18,7 @@ import snowflake.snowpark from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type +from snowflake.snowpark.functions import lit, parse_json from snowflake.snowpark.types import DataType @@ -30,6 +31,7 @@ class Catalog: def __init__(self, session: "snowflake.snowpark.session.Session") -> None: self._session = session self._root = Root(session) + self._python_regex_udf = None def _parse_database( self, @@ -93,6 +95,46 @@ def _parse_function_or_procedure( arg_types_str = ", ".join(arg.datatype for arg in fn.arguments) return f"{fn.name}({arg_types_str})" + def _initialize_regex_udf(self) -> None: + with self._session._lock: + if self._python_regex_udf is not None: + return + + def python_regex_filter(pattern: str, input: str) -> bool: + return bool(re.match(pattern, input)) + + self._python_regex_udf = self._session.udf.register(python_regex_filter) + + def _list_objects( + self, + *, + object_name: str, + object_class, + database: Optional[Union[str, Database]], + schema: Optional[Union[str, Schema]], + pattern: Optional[str], + ): + db_name = self._parse_database(database) + schema_name = self._parse_schema(schema) + + df = self._session.sql( + f"SHOW AS RESOURCE {object_name} IN {db_name}.{schema_name} -- catalog api" + ) + if pattern: + # initialize udf + self._initialize_regex_udf() + + # The result of SHOW AS RESOURCE query is a json string which contains + # key 'name' to store the name of the object. We parse json for the returned + # result and apply the filter on name. + df = df.filter( + self._python_regex_udf( + lit(pattern), parse_json('"As Resource"')["name"] + ) + ) + + return list(map(lambda row: object_class.from_json(row[0]), df.collect())) + # List methods def list_databases( self, @@ -144,14 +186,13 @@ def list_tables( schema: schema name or ``Schema`` object. Defaults to None. pattern: the pattern of name to match. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - - iter = self._root.databases[db_name].schemas[schema_name].tables.iter() - if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - - return list(iter) + return self._list_objects( + object_name="TABLES", + object_class=Table, + database=database, + schema=schema, + pattern=pattern, + ) def list_views( self, @@ -168,14 +209,13 @@ def list_views( schema: schema name or ``Schema`` object. Defaults to None. pattern: the pattern of name to match. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - - iter = self._root.databases[db_name].schemas[schema_name].views.iter() - if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - - return list(iter) + return self._list_objects( + object_name="VIEWS", + object_class=View, + database=database, + schema=schema, + pattern=pattern, + ) def list_columns( self, @@ -212,14 +252,13 @@ def list_procedures( schema: schema name or ``Schema`` object. Defaults to None. pattern: the pattern of name to match. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - - iter = self._root.databases[db_name].schemas[schema_name].procedures.iter() - if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - - return list(iter) + return self._list_objects( + object_name="PROCEDURES", + object_class=Procedure, + database=database, + schema=schema, + pattern=pattern, + ) def list_user_defined_functions( self, @@ -235,18 +274,13 @@ def list_user_defined_functions( schema: schema name or ``Schema`` object. Defaults to None. pattern: the pattern of name to match. Defaults to None. """ - db_name = self._parse_database(database) - schema_name = self._parse_schema(schema) - - iter = ( - self._root.databases[db_name] - .schemas[schema_name] - .user_defined_functions.iter() + return self._list_objects( + object_name="USER FUNCTIONS", + object_class=UserDefinedFunction, + database=database, + schema=schema, + pattern=pattern, ) - if pattern: - iter = filter(lambda x: re.match(pattern, x.name), iter) - - return list(iter) # get methods def get_current_database(self) -> Database: