diff --git a/CHANGELOG.md b/CHANGELOG.md index 182cc2dea17..4702178cb22 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - `divnull` - `nullifzero` - `snowflake_cortex_sentiment` +- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`. ### Snowpark pandas API Updates diff --git a/docs/source/snowpark/catalog.rst b/docs/source/snowpark/catalog.rst new file mode 100644 index 00000000000..f8291b08f39 --- /dev/null +++ b/docs/source/snowpark/catalog.rst @@ -0,0 +1,63 @@ +============= +Catalog +============= +Catalog module for Snowpark. + +.. currentmodule:: snowflake.snowpark.catalog + +.. rubric:: Catalog + +.. autosummary:: + :toctree: api/ + + Catalog.databaseExists + Catalog.database_exists + Catalog.dropDatabase + Catalog.dropSchema + Catalog.dropTable + Catalog.dropView + Catalog.drop_database + Catalog.drop_schema + Catalog.drop_table + Catalog.drop_view + Catalog.getCurrentDatabase + Catalog.getCurrentSchema + Catalog.getProcedure + Catalog.getTable + Catalog.getUserDefinedFunction + Catalog.getView + Catalog.get_current_database + Catalog.get_current_schema + Catalog.get_procedure + Catalog.get_table + Catalog.get_user_defined_function + Catalog.get_view + Catalog.listColumns + Catalog.listDatabases + Catalog.listProcedures + Catalog.listSchemas + Catalog.listTables + Catalog.listUserDefinedFunctions + Catalog.listViews + Catalog.list_columns + Catalog.list_databases + Catalog.list_procedures + Catalog.list_schemas + Catalog.list_tables + Catalog.list_user_defined_functions + Catalog.list_views + Catalog.procedureExists + Catalog.procedure_exists + Catalog.schemaExists + Catalog.schema_exists + Catalog.setCurrentDatabase + Catalog.setCurrentSchema + Catalog.set_current_database + Catalog.set_current_schema + Catalog.tableExists + Catalog.table_exists + Catalog.userDefinedFunctionExists + Catalog.user_defined_function_exists + Catalog.viewExists + Catalog.view_exists + diff --git a/docs/source/snowpark/index.rst b/docs/source/snowpark/index.rst index ad3ad563e39..ab8125d4058 100644 --- a/docs/source/snowpark/index.rst +++ b/docs/source/snowpark/index.rst @@ -9,9 +9,9 @@ Snowpark APIs column types row - functions - window - grouping + functions + window + grouping table_function table async_job @@ -21,6 +21,7 @@ Snowpark APIs udtf observability files + catalog lineage context exceptions diff --git a/docs/source/snowpark/session.rst b/docs/source/snowpark/session.rst index 21da1f76849..3e9b046a521 100644 --- a/docs/source/snowpark/session.rst +++ b/docs/source/snowpark/session.rst @@ -38,6 +38,7 @@ Snowpark Session Session.append_query_tag Session.call Session.cancel_all + Session.catalog Session.clear_imports Session.clear_packages Session.close diff --git a/recipe/meta.yaml b/recipe/meta.yaml index 467593287a3..b357eda695a 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -43,6 +43,7 @@ requirements: - protobuf >=3.20,<6 - python-dateutil - tzlocal + - snowflake.core >=1.0.0,<2 test: imports: diff --git a/setup.py b/setup.py index a1be8a8eda7..656e164570a 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "protobuf>=3.20, <6", # Snowpark IR "python-dateutil", # Snowpark IR "tzlocal", # Snowpark IR + "snowflake.core>=1.0.0, <2", # Catalog ] REQUIRED_PYTHON_VERSION = ">=3.8, <3.12" diff --git a/src/snowflake/snowpark/catalog.py b/src/snowflake/snowpark/catalog.py new file mode 100644 index 00000000000..d0b9c226556 --- /dev/null +++ b/src/snowflake/snowpark/catalog.py @@ -0,0 +1,659 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import re +from typing import List, Optional, Union + +from snowflake.core import Root +from snowflake.core.database import Database +from snowflake.core.exceptions import NotFoundError +from snowflake.core.function import Function +from snowflake.core.procedure import Procedure +from snowflake.core.schema import Schema +from snowflake.core.table import Table, TableColumn +from snowflake.core.user_defined_function import UserDefinedFunction +from snowflake.core.view import View + + +import snowflake.snowpark +from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type +from snowflake.snowpark.functions import lit, parse_json +from snowflake.snowpark.types import DataType + + +class Catalog: + """The Catalog class provides methods to interact with and manage the Snowflake objects. + It allows users to list, get, and drop various database objects such as databases, schemas, tables, + views, functions, etc. + """ + + def __init__(self, session: "snowflake.snowpark.session.Session") -> None: + self._session = session + self._root = Root(session) + self._python_regex_udf = None + + def _parse_database( + self, + database: Optional[Union[str, Database]], + model_obj: Optional[ + Union[Schema, Table, View, Function, Procedure, UserDefinedFunction] + ] = None, + ) -> str: + if isinstance( + model_obj, (Schema, Table, View, Function, Procedure, UserDefinedFunction) + ): + return model_obj.database_name + + if isinstance(database, str): + return database + if isinstance(database, Database): + return database.name + if database is None: + return self._session.get_current_database() + raise ValueError( + f"Unexpected type. Expected str or Database, got '{type(database)}'" + ) + + def _parse_schema( + self, + schema: Optional[Union[str, Schema]], + model_obj: Optional[ + Union[Table, View, Function, Procedure, UserDefinedFunction] + ] = None, + ) -> str: + if isinstance( + model_obj, (Table, View, Function, Procedure, UserDefinedFunction) + ): + return model_obj.schema_name + + if isinstance(schema, str): + return schema + if isinstance(schema, Schema): + return schema.name + if schema is None: + return self._session.get_current_schema() + raise ValueError( + f"Unexpected type. Expected str or Schema, got '{type(schema)}'" + ) + + def _parse_function_or_procedure( + self, + fn: Union[str, Function, Procedure, UserDefinedFunction], + arg_types: Optional[List[DataType]], + ) -> str: + if isinstance(fn, str): + if arg_types is None: + raise ValueError( + "arg_types must be provided when function/procedure is a string" + ) + arg_types_str = ", ".join( + [convert_sp_to_sf_type(arg_type) for arg_type in arg_types] + ) + return f"{fn}({arg_types_str})" + + arg_types_str = ", ".join(arg.datatype for arg in fn.arguments) + return f"{fn.name}({arg_types_str})" + + def _initialize_regex_udf(self) -> None: + with self._session._lock: + if self._python_regex_udf is not None: + return + + def python_regex_filter(pattern: str, input: str) -> bool: + return bool(re.match(pattern, input)) + + self._python_regex_udf = self._session.udf.register(python_regex_filter) + + def _list_objects( + self, + *, + object_name: str, + object_class, + database: Optional[Union[str, Database]], + schema: Optional[Union[str, Schema]], + pattern: Optional[str], + ): + db_name = self._parse_database(database) + schema_name = self._parse_schema(schema) + + df = self._session.sql( + f"SHOW AS RESOURCE {object_name} IN {db_name}.{schema_name} -- catalog api" + ) + if pattern: + # initialize udf + self._initialize_regex_udf() + + # The result of SHOW AS RESOURCE query is a json string which contains + # key 'name' to store the name of the object. We parse json for the returned + # result and apply the filter on name. + df = df.filter( + self._python_regex_udf( + lit(pattern), parse_json('"As Resource"')["name"] + ) + ) + + return list(map(lambda row: object_class.from_json(row[0]), df.collect())) + + # List methods + def list_databases( + self, + *, + pattern: Optional[str] = None, + ) -> List[Database]: + """List databases in the current session. + + Args: + pattern: the pattern of name to match. Defaults to None. + """ + iter = self._root.databases.iter() + if pattern: + iter = filter(lambda x: re.match(pattern, x.name), iter) + + return list(iter) + + def list_schemas( + self, + *, + database: Optional[Union[str, Database]] = None, + pattern: Optional[str] = None, + ) -> List[Schema]: + """List schemas in the current session. If database is provided, list schemas in the + database, otherwise list schemas in the current database. + + Args: + database: database name or ``Database`` object. Defaults to None. + pattern: the pattern of name to match. Defaults to None. + """ + db_name = self._parse_database(database) + iter = self._root.databases[db_name].schemas.iter() + if pattern: + iter = filter(lambda x: re.match(pattern, x.name), iter) + return list(iter) + + def list_tables( + self, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + pattern: Optional[str] = None, + ) -> List[Table]: + """List tables in the current session. If database or schema are provided, list tables + in the given database or schema, otherwise list tables in the current database/schema. + + Args: + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + pattern: the pattern of name to match. Defaults to None. + """ + return self._list_objects( + object_name="TABLES", + object_class=Table, + database=database, + schema=schema, + pattern=pattern, + ) + + def list_views( + self, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + pattern: Optional[str] = None, + ) -> List[View]: + """List views in the current session. If database or schema are provided, list views + in the given database or schema, otherwise list views in the current database/schema. + + Args: + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + pattern: the pattern of name to match. Defaults to None. + """ + return self._list_objects( + object_name="VIEWS", + object_class=View, + database=database, + schema=schema, + pattern=pattern, + ) + + def list_columns( + self, + table_name: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> List[TableColumn]: + """List columns in the given table. + + Args: + table_name: table name. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + if isinstance(table_name, str): + table = self.get_table(table_name, database=database, schema=schema) + else: + table = table_name + return table.columns + + def list_procedures( + self, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + pattern: Optional[str] = None, + ) -> List[Procedure]: + """List of procedures in the given database and schema. If database or schema are not + provided, list procedures in the current database and schema. + + Args: + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + pattern: the pattern of name to match. Defaults to None. + """ + return self._list_objects( + object_name="PROCEDURES", + object_class=Procedure, + database=database, + schema=schema, + pattern=pattern, + ) + + def list_user_defined_functions( + self, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + pattern: Optional[str] = None, + ) -> List[UserDefinedFunction]: + """List of user defined functions in the given database and schema. If database or schema + are not provided, list user defined functions in the current database and schema. + Args: + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + pattern: the pattern of name to match. Defaults to None. + """ + return self._list_objects( + object_name="USER FUNCTIONS", + object_class=UserDefinedFunction, + database=database, + schema=schema, + pattern=pattern, + ) + + # get methods + def get_current_database(self) -> Database: + """Get the current database.""" + current_db_name = self._session.get_current_database() + return self._root.databases[current_db_name] + + def get_current_schema(self) -> Schema: + """Get the current schema.""" + current_db = self.get_current_database() + current_schema_name = self._session.get_current_schema() + return current_db.schemas[current_schema_name] + + def get_table( + self, + table_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Table: + """Get the table by name in given database and schema. If database or schema are not + provided, get the table in the current database and schema. + + Args: + table_name: name of the table. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database) + schema_name = self._parse_schema(schema) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .tables[table_name] + .fetch() + ) + + def get_view( + self, + view_name: str, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> View: + """Get the view by name in given database and schema. If database or schema are not + provided, get the view in the current database and schema. + + Args: + view_name: name of the view. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database) + schema_name = self._parse_schema(schema) + return ( + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + ) + + def get_procedure( + self, + procedure_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> Procedure: + """Get the procedure by name and argument types in given database and schema. If database or + schema are not provided, get the procedure in the current database and schema. + + Args: + procedure_name: name of the procedure. + arg_types: list of argument types to uniquely identify the procedure. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database) + schema_name = self._parse_schema(schema) + procedure_id = self._parse_function_or_procedure(procedure_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .procedures[procedure_id] + .fetch() + ) + + def get_user_defined_function( + self, + udf_name: str, + arg_types: List[DataType], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> UserDefinedFunction: + """Get the user defined function by name and argument types in given database and schema. + If database or schema are not provided, get the user defined function in the current + database and schema. + + Args: + udf_name: name of the user defined function. + arg_types: list of argument types to uniquely identify the user defined function. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database) + schema_name = self._parse_schema(schema) + function_id = self._parse_function_or_procedure(udf_name, arg_types) + return ( + self._root.databases[db_name] + .schemas[schema_name] + .user_defined_functions[function_id] + .fetch() + ) + + # set methods + def set_current_database(self, database: Union[str, Database]) -> None: + """Set the current default database for the session. + + Args: + database: database name or ``Database`` object. + """ + db_name = self._parse_database(database) + self._session.use_database(db_name) + + def set_current_schema(self, schema: Union[str, Schema]) -> None: + """Set the current default schema for the session. + + Args: + schema: schema name or ``Schema`` object. + """ + schema_name = self._parse_schema(schema) + self._session.use_schema(schema_name) + + # exists methods + def database_exists(self, database: Union[str, Database]) -> bool: + """Check if the given database exists. + + Args: + database: database name or ``Database`` object. + """ + db_name = self._parse_database(database) + try: + self._root.databases[db_name].fetch() + return True + except NotFoundError: + return False + + def schema_exists( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> bool: + """Check if the given schema exists in the given database. If database is not provided, + check if the schema exists in the current database. + + Args: + schema: schema name or ``Schema`` object. + database: database name or ``Database`` object. Defaults to None. + """ + db_name = self._parse_database(database, schema) + schema_name = self._parse_schema(schema) + try: + self._root.databases[db_name].schemas[schema_name].fetch() + return True + except NotFoundError: + return False + + def table_exists( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + """Check if the given table exists in the given database and schema. If database or schema + are not provided, check if the table exists in the current database and schema. + + Args: + table: table name or ``Table`` object. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database, table) + schema_name = self._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + try: + self._root.databases[db_name].schemas[schema_name].tables[ + table_name + ].fetch() + return True + except NotFoundError: + return False + + def view_exists( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + """Check if the given view exists in the given database and schema. If database or schema + are not provided, check if the view exists in the current database and schema. + + Args: + view: view name or ``View`` object. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database, view) + schema_name = self._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + try: + self._root.databases[db_name].schemas[schema_name].views[view_name].fetch() + return True + except NotFoundError: + return False + + def procedure_exists( + self, + procedure: Union[str, Procedure], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + """Check if the given procedure exists in the given database and schema. If database or + schema are not provided, check if the procedure exists in the current database and schema. + + Args: + procedure: procedure name or ``Procedure`` object. + arg_types: list of argument types to uniquely identify the procedure. Defaults to None. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database, procedure) + schema_name = self._parse_schema(schema, procedure) + procedure_id = self._parse_function_or_procedure(procedure, arg_types) + + try: + self._root.databases[db_name].schemas[schema_name].procedures[ + procedure_id + ].fetch() + return True + except NotFoundError: + return False + + def user_defined_function_exists( + self, + udf: Union[str, UserDefinedFunction], + arg_types: Optional[List[DataType]] = None, + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> bool: + """Check if the given user defined function exists in the given database and schema. If + database or schema are not provided, check if the user defined function exists in the + current database and schema. + + Args: + udf: user defined function name or ``UserDefinedFunction`` object. + arg_types: list of argument types to uniquely identify the user defined function. + Defaults to None. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database, udf) + schema_name = self._parse_schema(schema, udf) + function_id = self._parse_function_or_procedure(udf, arg_types) + + try: + self._root.databases[db_name].schemas[schema_name].user_defined_functions[ + function_id + ].fetch() + return True + except NotFoundError: + return False + + # drop methods + def drop_database(self, database: Union[str, Database]) -> None: + """Drop the given database. + + Args: + database: database name or ``Database`` object. + """ + db_name = self._parse_database(database) + self._root.databases[db_name].drop() + + def drop_schema( + self, + schema: Union[str, Schema], + *, + database: Optional[Union[str, Database]] = None, + ) -> None: + """Drop the given schema in the given database. If database is not provided, drop the + schema in the current database. + + Args: + schema: schema name or ``Schema`` object. + database: database name or ``Database`` object. Defaults to None. + """ + db_name = self._parse_database(database, schema) + schema_name = self._parse_schema(schema) + self._root.databases[db_name].schemas[schema_name].drop() + + def drop_table( + self, + table: Union[str, Table], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + """Drop the given table in the given database and schema. If database or schema are not + provided, drop the table in the current database and schema. + + Args: + table: table name or ``Table`` object. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database, table) + schema_name = self._parse_schema(schema, table) + table_name = table if isinstance(table, str) else table.name + + self._root.databases[db_name].schemas[schema_name].tables[table_name].drop() + + def drop_view( + self, + view: Union[str, View], + *, + database: Optional[Union[str, Database]] = None, + schema: Optional[Union[str, Schema]] = None, + ) -> None: + """Drop the given view in the given database and schema. If database or schema are not + provided, drop the view in the current database and schema. + + Args: + view: view name or ``View`` object. + database: database name or ``Database`` object. Defaults to None. + schema: schema name or ``Schema`` object. Defaults to None. + """ + db_name = self._parse_database(database, view) + schema_name = self._parse_schema(schema, view) + view_name = view if isinstance(view, str) else view.name + + self._root.databases[db_name].schemas[schema_name].views[view_name].drop() + + # aliases + listDatabases = list_databases + listSchemas = list_schemas + listTables = list_tables + listViews = list_views + listColumns = list_columns + listProcedures = list_procedures + listUserDefinedFunctions = list_user_defined_functions + + getCurrentDatabase = get_current_database + getCurrentSchema = get_current_schema + getTable = get_table + getView = get_view + getProcedure = get_procedure + getUserDefinedFunction = get_user_defined_function + + setCurrentDatabase = set_current_database + setCurrentSchema = set_current_schema + + databaseExists = database_exists + schemaExists = schema_exists + tableExists = table_exists + viewExists = view_exists + procedureExists = procedure_exists + userDefinedFunctionExists = user_defined_function_exists + + dropDatabase = drop_database + dropSchema = drop_schema + dropTable = drop_table + dropView = drop_view diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 0231ad9c870..0c57942ba2b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -134,6 +134,7 @@ zip_file_or_directory_to_stream, ) from snowflake.snowpark.async_job import AsyncJob +from snowflake.snowpark.catalog import Catalog from snowflake.snowpark.column import Column from snowflake.snowpark.context import ( _is_execution_environment_sandboxed_for_client, @@ -656,6 +657,7 @@ def __init__( self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) self._sp_profiler = StoredProcedureProfiler(session=self) + self._catalog = None self._ast_batch = AstBatch(self) @@ -735,6 +737,18 @@ def get_active_session(cls) -> Optional["Session"]: getActiveSession = get_active_session + @property + def catalog(self) -> Catalog: + """Returns the catalog object.""" + if self._catalog is None: + if isinstance(self._conn, MockServerConnection): + self._conn.log_not_supported_error( + external_feature_name="Session.catalog", + raise_error=NotImplementedError, + ) + self._catalog = Catalog(self) + return self._catalog + def close(self) -> None: """Close this session.""" if is_in_stored_procedure(): diff --git a/tests/integ/test_catalog.py b/tests/integ/test_catalog.py new file mode 100644 index 00000000000..1b6b8726673 --- /dev/null +++ b/tests/integ/test_catalog.py @@ -0,0 +1,469 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import uuid +import pytest + +from snowflake.snowpark.catalog import Catalog +from snowflake.snowpark.session import Session +from snowflake.snowpark.types import IntegerType + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ) +] + +CATALOG_TEMP_OBJECT_PREFIX = "SP_CATALOG_TEMP" + + +def get_temp_name(type: str) -> str: + return f"{CATALOG_TEMP_OBJECT_PREFIX}_{type}_{uuid.uuid4().hex[:6]}".upper() + + +def create_temp_db(session) -> str: + original_db = session.get_current_database() + temp_db = get_temp_name("DB") + session._run_query(f"create or replace database {temp_db}") + session.use_database(original_db) + return temp_db + + +@pytest.fixture(scope="module") +def temp_db1(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +@pytest.fixture(scope="module") +def temp_db2(session): + temp_db = create_temp_db(session) + yield temp_db + session._run_query(f"drop database if exists {temp_db}") + + +def create_temp_schema(session, db: str) -> str: + original_db = session.get_current_database() + original_schema = session.get_current_schema() + temp_schema = get_temp_name("SCHEMA") + session._run_query(f"create or replace schema {db}.{temp_schema}") + + session.use_database(original_db) + session.use_schema(original_schema) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_schema1(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +@pytest.fixture(scope="module") +def temp_schema2(session, temp_db1): + temp_schema = create_temp_schema(session, temp_db1) + yield temp_schema + session._run_query(f"drop schema if exists {temp_db1}.{temp_schema}") + + +def create_temp_table(session, db: str, schema: str) -> str: + temp_table = get_temp_name("TABLE") + session._run_query( + f"create or replace temp table {db}.{schema}.{temp_table} (a int, b string)" + ) + return temp_table + + +@pytest.fixture(scope="module") +def temp_table1(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +@pytest.fixture(scope="module") +def temp_table2(session, temp_db1, temp_schema1): + temp_table = create_temp_table(session, temp_db1, temp_schema1) + yield temp_table + session._run_query(f"drop table if exists {temp_db1}.{temp_schema1}.{temp_table}") + + +def create_temp_view(session, db: str, schema: str) -> str: + temp_schema = get_temp_name("SCHEMA") + session._run_query( + f"create or replace temp view {db}.{schema}.{temp_schema} as select 1 as a, '2' as b" + ) + return temp_schema + + +@pytest.fixture(scope="module") +def temp_view1(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +@pytest.fixture(scope="module") +def temp_view2(session, temp_db1, temp_schema1): + temp_view = create_temp_view(session, temp_db1, temp_schema1) + yield temp_view + session._run_query(f"drop view if exists {temp_db1}.{temp_schema1}.{temp_view}") + + +def create_temp_procedure(session: Session, db, schema) -> str: + temp_procedure = get_temp_name("PROCEDURE") + session.sproc.register( + lambda _, x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_procedure}", + packages=["snowflake-snowpark-python"], + ) + return temp_procedure + + +@pytest.fixture(scope="module") +def temp_procedure1(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_procedure2(session, temp_db1, temp_schema1): + temp_procedure = create_temp_procedure(session, temp_db1, temp_schema1) + yield temp_procedure + session._run_query( + f"drop procedure if exists {temp_db1}.{temp_schema1}.{temp_procedure}(int)" + ) + + +def create_temp_udf(session: Session, db, schema) -> str: + temp_udf = get_temp_name("UDF") + session.udf.register( + lambda x: x + 1, + return_type=IntegerType(), + input_types=[IntegerType()], + name=f"{db}.{schema}.{temp_udf}", + ) + return temp_udf + + +@pytest.fixture(scope="module") +def temp_udf1(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +@pytest.fixture(scope="module") +def temp_udf2(session, temp_db1, temp_schema1): + temp_udf = create_temp_udf(session, temp_db1, temp_schema1) + yield temp_udf + session._run_query( + f"drop function if exists {temp_db1}.{temp_schema1}.{temp_udf}(int)" + ) + + +def test_list_db(session, temp_db1, temp_db2): + catalog: Catalog = session.catalog + db_list = catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_DB_*") + assert {db.name for db in db_list} >= {temp_db1, temp_db2} + + +def test_list_schema(session, temp_db1, temp_schema1, temp_schema2): + catalog: Catalog = session.catalog + assert ( + len(catalog.list_databases(pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_*")) + == 0 + ) + schema_list = catalog.list_schemas( + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_SCHEMA_*", database=temp_db1 + ) + assert {schema.name for schema in schema_list} >= {temp_schema1, temp_schema2} + + +def test_list_tables(session, temp_db1, temp_schema1, temp_table1, temp_table2): + catalog: Catalog = session.catalog + + assert len(catalog.list_tables(pattern="does_not_exist_*")) == 0 + assert ( + len( + catalog.list_tables( + pattern="does_not_exist_*", database=temp_db1, schema=temp_schema1 + ) + ) + == 0 + ) + + table_list = catalog.list_tables(database=temp_db1, schema=temp_schema1) + assert {table.name for table in table_list} == {temp_table1, temp_table2} + + cols = catalog.list_columns(temp_table1, database=temp_db1, schema=temp_schema1) + assert len(cols) == 2 + assert cols[0].name == "A" + assert cols[0].datatype == "NUMBER(38,0)" + assert cols[0].nullable is True + assert cols[1].name == "B" + assert cols[1].datatype == "VARCHAR(16777216)" + assert cols[1].nullable is True + + +def test_list_views(session, temp_db1, temp_schema1, temp_view1, temp_view2): + catalog: Catalog = session.catalog + + assert len(catalog.list_views(pattern="does_not_exist_*")) == 0 + assert ( + len( + catalog.list_views( + pattern="does_not_exist_*", database=temp_db1, schema=temp_schema1 + ) + ) + == 0 + ) + + view_list = catalog.list_views(database=temp_db1, schema=temp_schema1) + assert {view.name for view in view_list} >= {temp_view1, temp_view2} + + +def test_list_procedures( + session, temp_db1, temp_schema1, temp_procedure1, temp_procedure2 +): + catalog: Catalog = session.catalog + + assert len(catalog.list_procedures(pattern="does_not_exist_*")) == 0 + assert ( + len( + catalog.list_procedures( + pattern="does_not_exist_*", database=temp_db1, schema=temp_schema1 + ) + ) + == 0 + ) + + procedure_list = catalog.list_procedures( + database=temp_db1, + schema=temp_schema1, + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_PROCEDURE_*", + ) + assert {procedure.name for procedure in procedure_list} >= { + temp_procedure1, + temp_procedure2, + } + + +@pytest.mark.xfail(reason="SNOW-1787268: Bug in snowflake api functions iter") +def test_list_udfs(session, temp_db1, temp_schema1, temp_udf1, temp_udf2): + catalog: Catalog = session.catalog + + assert len(catalog.list_functions(pattern="does_not_exist_*")) == 0 + assert ( + len( + catalog.list_functions( + pattern="does_not_exist_*", database=temp_db1, schema=temp_schema1 + ) + ) + == 0 + ) + udf_list = catalog.list_functions( + database=temp_db1, + schema=temp_schema1, + pattern=f"{CATALOG_TEMP_OBJECT_PREFIX}_UDF_*", + ) + assert {udf.name for udf in udf_list} >= {temp_udf1, temp_udf2} + + +def test_get_db_schema(session): + catalog: Catalog = session.catalog + current_db = session.get_current_database() + current_schema = session.get_current_schema() + assert catalog.get_current_database().name == current_db + assert catalog.get_current_schema().name == current_schema + + +def test_get_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): + catalog: Catalog = session.catalog + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert table.name == temp_table1 + assert table.database_name == temp_db1 + assert table.schema_name == temp_schema1 + + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert view.name == temp_view1 + assert view.database_name == temp_db1 + assert view.schema_name == temp_schema1 + + +def test_get_function_procedure_udf( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog: Catalog = session.catalog + + procedure = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert procedure.name == temp_procedure1 + assert procedure.database_name == temp_db1 + assert procedure.schema_name == temp_schema1 + + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert udf.name == temp_udf1 + assert udf.database_name == temp_db1 + assert udf.schema_name == temp_schema1 + + +def test_set_db_schema(session, temp_db1, temp_db2, temp_schema1, temp_schema2): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + catalog.set_current_database(temp_db1) + catalog.set_current_schema(temp_schema1) + assert session.get_current_database() == f'"{temp_db1}"' + assert session.get_current_schema() == f'"{temp_schema1}"' + + catalog.set_current_schema(temp_schema2) + assert session.get_current_schema() == f'"{temp_schema2}"' + + catalog.set_current_database(temp_db2) + assert session.get_current_database() == f'"{temp_db2}"' + finally: + session.use_database(original_db) + session.use_schema(original_schema) + + +def test_exists_db_schema(session, temp_db1, temp_schema1): + catalog = session.catalog + assert catalog.database_exists(temp_db1) + assert not catalog.database_exists("does_not_exist") + + assert catalog.schema_exists(temp_schema1, database=temp_db1) + assert not catalog.schema_exists(temp_schema1, database="does_not_exist") + + +def test_exists_table_view(session, temp_db1, temp_schema1, temp_table1, temp_view1): + catalog = session.catalog + db1_obj = catalog._root.databases[temp_db1].fetch() + schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() + + assert catalog.table_exists(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(temp_table1, database=db1_obj, schema=schema1_obj) + table = catalog.get_table(temp_table1, database=temp_db1, schema=temp_schema1) + assert catalog.table_exists(table) + assert not catalog.table_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + assert catalog.view_exists(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(temp_view1, database=db1_obj, schema=schema1_obj) + view = catalog.get_view(temp_view1, database=temp_db1, schema=temp_schema1) + assert catalog.view_exists(view) + assert not catalog.view_exists( + "does_not_exist", database=temp_db1, schema=temp_schema1 + ) + + +def test_exists_function_procedure_udf( + session, temp_db1, temp_schema1, temp_procedure1, temp_udf1 +): + catalog = session.catalog + db1_obj = catalog._root.databases[temp_db1].fetch() + schema1_obj = catalog._root.databases[temp_db1].schemas[temp_schema1].fetch() + + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists( + temp_procedure1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + proc = catalog.get_procedure( + temp_procedure1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.procedure_exists(proc) + assert not catalog.procedure_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists( + temp_udf1, [IntegerType()], database=db1_obj, schema=schema1_obj + ) + udf = catalog.get_user_defined_function( + temp_udf1, [IntegerType()], database=temp_db1, schema=temp_schema1 + ) + assert catalog.user_defined_function_exists(udf) + assert not catalog.user_defined_function_exists( + "does_not_exist", [], database=temp_db1, schema=temp_schema1 + ) + + +@pytest.mark.parametrize("use_object", [True, False]) +def test_drop(session, use_object): + catalog = session.catalog + + original_db = session.get_current_database() + original_schema = session.get_current_schema() + try: + temp_db = create_temp_db(session) + temp_schema = create_temp_schema(session, temp_db) + temp_table = create_temp_table(session, temp_db, temp_schema) + temp_view = create_temp_view(session, temp_db, temp_schema) + if use_object: + temp_schema = catalog._root.databases[temp_db].schemas[temp_schema].fetch() + temp_db = catalog._root.databases[temp_db].fetch() + + assert catalog.database_exists(temp_db) + assert catalog.schema_exists(temp_schema, database=temp_db) + assert catalog.table_exists(temp_table, database=temp_db, schema=temp_schema) + assert catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_table(temp_table, database=temp_db, schema=temp_schema) + catalog.drop_view(temp_view, database=temp_db, schema=temp_schema) + + assert not catalog.table_exists( + temp_table, database=temp_db, schema=temp_schema + ) + assert not catalog.view_exists(temp_view, database=temp_db, schema=temp_schema) + + catalog.drop_schema(temp_schema, database=temp_db) + assert not catalog.schema_exists(temp_schema, database=temp_db) + + catalog.drop_database(temp_db) + assert not catalog.database_exists(temp_db) + finally: + session.use_database(original_db) + session.use_schema(original_schema) + + +def test_parse_names_negative(session): + catalog = session.catalog + with pytest.raises( + ValueError, + match="Unexpected type. Expected str or Database, got '<class 'int'>'", + ): + catalog.database_exists(123) + + with pytest.raises( + ValueError, match="Unexpected type. Expected str or Schema, got '<class 'int'>'" + ): + catalog.schema_exists(123) + + with pytest.raises( + ValueError, + match="arg_types must be provided when function/procedure is a string", + ): + catalog.procedure_exists("proc")