diff --git a/packages/hyper_file.py b/packages/hyper_file.py
index 03fd871..5fac392 100644
--- a/packages/hyper_file.py
+++ b/packages/hyper_file.py
@@ -1,15 +1,12 @@
-from tableauhyperapi import HyperProcess, Telemetry, Connection, CreateMode, \
-    escape_name
+from tableauhyperapi import HyperProcess, Telemetry, Connection, CreateMode, escape_name
 from pyarrow.parquet import ParquetFile
 import os
 import logging as log
 import packages.hyper_utils as hu


-class HyperFile():
-
-    def __init__(self, parquet_folder: str,
-                 file_extension: str = None) -> None:
+class HyperFile:
+    def __init__(self, parquet_folder: str, file_extension: str = None) -> None:
         """Requires parquet folder and parquet file extension if any

         Args:
@@ -36,9 +33,9 @@ def create_hyper_file(self, hyper_path: str) -> int:
         telemetry = Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
         create_mode = CreateMode.CREATE_AND_REPLACE
         with HyperProcess(telemetry=telemetry) as hp:
-            with Connection(endpoint=hp.endpoint,
-                            database=hyper_path,
-                            create_mode=create_mode) as conn:
+            with Connection(
+                endpoint=hp.endpoint, database=hyper_path, create_mode=create_mode
+            ) as conn:
                 table_definition = hu.get_table_def(ParquetFile(files[0]))
                 schema = table_definition.table_name.schema_name
                 conn.catalog.create_schema(schema=schema)
@@ -46,18 +43,20 @@
                 total_rows = 0
                 for file in files:
                     try:
-                        copy_command = f"COPY \"Extract\".\"Extract\" from '{file}' with (format parquet)"  # noqa
+                        copy_command = f'COPY "Extract"."Extract" from \'{file}\' with (format parquet)'
                         count = conn.execute_command(copy_command)
                         total_rows += count
                     except Exception as e:
-                        log.warning(f'File {os.path.basename(file)} \
-                            could not be processed. {e}')
-                        log.info(f'Error message: {e}')
-            log.info(f'Process completed with {total_rows} rows added.')
+                        log.warning(
+                            f"File {os.path.basename(file)} could not be processed. {e}"
+                        )
+                        log.info(f"Error message: {e}")
+            log.info(f"Process completed with {total_rows} rows added.")
         return total_rows

-    def delete_rows(self, hyper_path: str, date_column: str,
-                    days_to_delete: int) -> int:
+    def delete_rows(
+        self, hyper_path: str, date_column: str, days_to_delete: int
+    ) -> int:
         """Delete rows from a hyper based on days before a date to delete.

         Args:
@@ -69,12 +68,12 @@ def delete_rows(self, hyper_path: str, date_column: str,
         """
         telemetry = Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
         with HyperProcess(telemetry=telemetry) as hp:
-            with Connection(endpoint=hp.endpoint,
-                            database=hyper_path,
-                            create_mode=CreateMode.NONE) as connection:
-                qry = f'DELETE FROM \"Extract\".\"Extract\" WHERE {escape_name(date_column)} >= CURRENT_DATE - {days_to_delete}'  # noqa
+            with Connection(
+                endpoint=hp.endpoint, database=hyper_path, create_mode=CreateMode.NONE
+            ) as connection:
+                qry = f'DELETE FROM "Extract"."Extract" WHERE {escape_name(date_column)} >= CURRENT_DATE - {days_to_delete}'  # noqa
                 count = connection.execute_command(qry)
-                log.info(f'Process completed with {count} rows deleted.')
+                log.info(f"Process completed with {count} rows deleted.")
         return count

     def append_rows(self, hyper_path: str) -> int:
@@ -88,20 +87,20 @@ def append_rows(self, hyper_path: str) -> int:
         """
         telemetry = Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
         with HyperProcess(telemetry=telemetry) as hp:
-            with Connection(endpoint=hp.endpoint,
-                            database=hyper_path,
-                            create_mode=CreateMode.NONE) as connection:
+            with Connection(
+                endpoint=hp.endpoint, database=hyper_path, create_mode=CreateMode.NONE
+            ) as connection:
                 total_rows = 0
-                files = hu.get_parquet_files(self.parquet_folder,
-                                             self.file_extension)
+                files = hu.get_parquet_files(self.parquet_folder, self.file_extension)
                 for parquet_path in files:
                     try:
-                        copy_command = f"COPY \"Extract\".\"Extract\" from '{parquet_path}' with (format parquet)"  # noqa
+                        copy_command = f'COPY "Extract"."Extract" from \'{parquet_path}\' with (format parquet)'
                         count = connection.execute_command(copy_command)
                         total_rows += count
                     except Exception as e:
-                        log.warning(f'File {os.path.basename(parquet_path)}\
-                            could not be processed. {e}')
-                        log.info(f'Error message: {e}')
-            log.info(f'Process completed with {total_rows} rows added.')
+                        log.warning(
+                            f"File {os.path.basename(parquet_path)} could not be processed. {e}"
+                        )
+                        log.info(f"Error message: {e}")
+            log.info(f"Process completed with {total_rows} rows added.")
         return total_rows
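
A minimal sketch of how the reformatted HyperFile API is driven; the folder, file, and column names below are illustrative assumptions, not repo fixtures (note the trailing slash on the folder, since get_parquet_files concatenates the glob pattern directly):

    from packages.hyper_file import HyperFile

    hf = HyperFile("data/", "parquet")                 # scans data/*.parquet
    created = hf.create_hyper_file("extract.hyper")    # fresh extract; returns rows copied
    appended = hf.append_rows("extract.hyper")         # COPYs the same files in again
    # removes rows whose "order_date" falls on or after CURRENT_DATE - 7
    deleted = hf.delete_rows("extract.hyper", "order_date", 7)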
diff --git a/packages/hyper_utils.py b/packages/hyper_utils.py
index cb66cd0..149c66c 100644
--- a/packages/hyper_utils.py
+++ b/packages/hyper_utils.py
@@ -1,6 +1,5 @@
 from pyarrow import ChunkedArray
-from tableauhyperapi import SqlType, TableDefinition, NULLABLE, NOT_NULLABLE, \
-    TableName
+from tableauhyperapi import SqlType, TableDefinition, NULLABLE, NOT_NULLABLE, TableName
 from pyarrow.parquet import ParquetFile
 import glob
 import pyarrow as pa
@@ -18,6 +17,11 @@ def _convert_struct_field(column: ChunkedArray) -> TableDefinition.Column:
     Returns:
         Column: Column with Hyper SqlType
     """
+    S = "s"
+    MS = "ms"
+    NS = "ns"
+    US = "us"
+    DECIMAL = "decimal"
     if column.type == pa.string():
         sql_type = SqlType.text()
     elif column.type in [pa.date32(), pa.date64()]:
@@ -32,22 +36,25 @@
         sql_type = SqlType.big_int()
     elif column.type == pa.bool_():
         sql_type = SqlType.bool()
-    elif column.type in [pa.timestamp('s'), pa.timestamp('ms'),
-                         pa.timestamp('us'), pa.timestamp('ns')]:
+    elif column.type in [
+        pa.timestamp(S),
+        pa.timestamp(MS),
+        pa.timestamp(US),
+        pa.timestamp(NS),
+    ]:
         sql_type = SqlType.timestamp()
     elif column.type == pa.binary():
         sql_type = SqlType.bytes()
-    elif str(column.type).startswith("decimal"):
+    elif str(column.type).startswith(DECIMAL):
         precision = column.type.precision
         scale = column.type.scale
         sql_type = SqlType.numeric(precision, scale)
     else:
-        raise ValueError(f'Invalid StructField datatype for column \
-            `{column.name}` : {column.type}')
+        raise ValueError(
+            f"Invalid StructField datatype for column `{column.name}` : {column.type}"
+        )
     nullable = NULLABLE if column.nullable else NOT_NULLABLE
-    return TableDefinition.Column(name=column.name,
-                                  type=sql_type,
-                                  nullability=nullable)
+    return TableDefinition.Column(name=column.name, type=sql_type, nullability=nullable)


 def get_table_def(df: ParquetFile) -> TableDefinition:
@@ -61,14 +68,10 @@
     """
     schema = df.schema_arrow
     cols = list(map(_convert_struct_field, schema))
-    return TableDefinition(
-        table_name=TableName("Extract", "Extract"),
-        columns=cols
-    )
+    return TableDefinition(table_name=TableName("Extract", "Extract"), columns=cols)


-def get_parquet_files(parquet_folder: str,
-                      parquet_extension: str = None) -> list[str]:
+def get_parquet_files(parquet_folder: str, parquet_extension: str = None) -> list[str]:
     """Get list of parquet files in a folder

     Args:
@@ -82,9 +85,10 @@ def get_parquet_files(parquet_folder: str,
     Returns:
         list[str]: list of filenames
     """
-    ext = f"*.{parquet_extension}" if parquet_extension is not None else '*'
+    ext = f"*.{parquet_extension}" if parquet_extension is not None else "*"
     files = glob.glob(parquet_folder + ext)
     if len(files) == 0:
-        raise ValueError(f'Error! The parquet_folder: {parquet_folder} \
-            returned no files!')
+        raise ValueError(
+            f"Error! The parquet_folder: {parquet_folder} returned no files!"
+        )
     return files
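
To sanity-check the Arrow-to-Hyper mapping above, get_table_def can be pointed at a small throwaway parquet file; the file name and columns here are assumptions for illustration:

    import pyarrow as pa
    import pyarrow.parquet as pq
    import packages.hyper_utils as hu

    table = pa.table({
        "id": pa.array([1, 2], type=pa.int64()),
        "price": pa.array([1234, 1234], type=pa.decimal128(7, 3)),
    })
    pq.write_table(table, "tmp.parquet")
    table_def = hu.get_table_def(pq.ParquetFile("tmp.parquet"))
    for col in table_def.columns:
        print(col.name, col.type)  # expect big_int for "id", numeric(7, 3) for "price"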
diff --git a/packages/tableau_server_utils.py b/packages/tableau_server_utils.py
index 0d64c7f..92e9074 100644
--- a/packages/tableau_server_utils.py
+++ b/packages/tableau_server_utils.py
@@ -3,9 +3,7 @@


 class TableauServerUtils:
-
-    def __init__(self, server_address: str, token_name: str,
-                 token_value: str) -> None:
+    def __init__(self, server_address: str, token_name: str, token_value: str) -> None:
         """TableauServerUtils constructor

         Args:
@@ -14,8 +12,7 @@ def __init__(self, server_address: str, token_name: str,
             token_value (str): token value
         """
         self.server = TSC.Server(server_address, use_server_version=True)
-        self.tableau_auth = TSC.PersonalAccessTokenAuth(token_name,
-                                                        token_value)
+        self.tableau_auth = TSC.PersonalAccessTokenAuth(token_name, token_value)

     def get_project_id(self, project_name: str) -> str:
         """Get project id by name
@@ -30,19 +27,24 @@ def get_project_id(self, project_name: str) -> str:
         Returns:
             str: project id
         """
-        logging.info(f'Signing into the server {self.server.baseurl}')
+        logging.info(f"Signing into the server {self.server.baseurl}")
         with self.server.auth.sign_in(self.tableau_auth):
             req_options = TSC.RequestOptions()
             req_options.filter.add(
-                TSC.Filter(TSC.RequestOptions.Field.Name,
-                           TSC.RequestOptions.Operator.Equals, project_name))
+                TSC.Filter(
+                    TSC.RequestOptions.Field.Name,
+                    TSC.RequestOptions.Operator.Equals,
+                    project_name,
+                )
+            )
             projects = list(TSC.Pager(self.server.projects, req_options))
             if len(projects) > 1:
                 raise ValueError("The project name is not unique.")
             return projects[0].id

-    def publish_hyper(self, project_id: str, hyper_path: str,
-                      mode: str = 'overwrite') -> None:
+    def publish_hyper(
+        self, project_id: str, hyper_path: str, mode: str = "overwrite"
+    ) -> None:
         """Publish hyper file into the tableau server

         Args:
@@ -51,18 +53,18 @@ def publish_hyper(self, project_id: str, hyper_path: str,
             mode (str): publish mode. Accept overwrite or append mode.
                 Defaults to overwrite.
         """
-        logging.info(f'Signing into the server {self.server.baseurl}')
+        OVERWRITE = "overwrite"
+        APPEND = "append"
+        logging.info(f"Signing into the server {self.server.baseurl}")
         with self.server.auth.sign_in(self.tableau_auth):
-            if mode == 'overwrite':
+            if mode == OVERWRITE:
                 publish_mode = TSC.Server.PublishMode.Overwrite
-            elif mode == 'append':
+            elif mode == APPEND:
                 publish_mode = TSC.Server.PublishMode.Append
             else:
-                raise ValueError(f'Error! Mode must be overwrite or append.\
-                    Received {mode}')
+                raise ValueError(f"Mode must be overwrite or append. Received {mode}")
             datasource = TSC.DatasourceItem(project_id=project_id)
-            logging.info('Publishing Hyper file into the server!')
-            ds = self.server.datasources.publish(
-                datasource, hyper_path, publish_mode)
-            logging.info(f'Datasource published on ID: {ds.id}')
-            logging.info('Job Finished.')
+            logging.info("Publishing Hyper file into the server!")
+            ds = self.server.datasources.publish(datasource, hyper_path, publish_mode)
+            logging.info(f"Datasource published on ID: {ds.id}")
+            logging.info("Job Finished.")
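
A publish flow sketch under assumed credentials; the server URL, token pair, and project name are placeholders:

    from packages.tableau_server_utils import TableauServerUtils

    tsu = TableauServerUtils("https://tableau.example.com", "token-name", "token-secret")
    project_id = tsu.get_project_id("Analytics")  # raises if the name matches several projects
    tsu.publish_hyper(project_id, "extract.hyper", mode="append")  # default mode is "overwrite"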
diff --git a/packages/time_decorator.py b/packages/time_decorator.py
index f66cfc5..334597e 100644
--- a/packages/time_decorator.py
+++ b/packages/time_decorator.py
@@ -13,12 +13,14 @@ def timeit(func: Callable) -> Callable:
     Returns:
         Callable: measured function
     """
+
     @wraps(func)
     def timeit_wrapper(*args, **kwargs) -> Callable:
         start_time = time.perf_counter()
         result = func(*args, **kwargs)
         end_time = time.perf_counter()
         total_time = end_time - start_time
-        logging.info(f'Function {func.__name__} took {total_time:.2f} seconds')
+        logging.info(f"Function {func.__name__} took {total_time:.2f} seconds")
         return result
+
     return timeit_wrapper
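
The decorator in use; logging must be configured at INFO level for the timing line to appear, and the function below is a made-up example:

    import logging
    from packages.time_decorator import timeit

    logging.basicConfig(level=logging.INFO)

    @timeit
    def slow_sum(n: int) -> int:
        return sum(range(n))

    slow_sum(10_000_000)  # logs e.g. "Function slow_sum took 0.35 seconds"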
diff --git a/tests/test_hyper_file.py b/tests/test_hyper_file.py
index 02b1816..aaf9be7 100644
--- a/tests/test_hyper_file.py
+++ b/tests/test_hyper_file.py
@@ -6,45 +6,46 @@
 import pytest
 import os

+FILENAME = "test.hyper"
+QRY = 'SELECT COUNT(*) FROM "Extract"."Extract"'
+

 @pytest.fixture
 def create_hyper_file(get_pyarrow_table):
     def _method(hyper_filename):
+        PARQUET = "parquet"
         df = get_pyarrow_table
-        filename = str(dt.datetime.today().strftime("%Y-%m-%d")) + '.parquet'
+        filename = str(dt.datetime.today().strftime("%Y-%m-%d")) + "." + PARQUET
         pa.parquet.write_table(df, filename)
-        hf = HyperFile('', 'parquet')
+        hf = HyperFile("", PARQUET)
         hf.create_hyper_file(hyper_filename)
         return hf
+
     return _method


 def test_create_hyper_file(create_hyper_file):
-    filename = 'test.hyper'
-    create_hyper_file(filename)
+    create_hyper_file(FILENAME)
     with HyperProcess(Telemetry.SEND_USAGE_DATA_TO_TABLEAU) as hyper:
-        with Connection(hyper.endpoint, filename, CreateMode.NONE) as con:
-            rows = con.execute_scalar_query(
-                'SELECT COUNT(*) FROM "Extract"."Extract"')
-    os.remove(filename)
+        with Connection(hyper.endpoint, FILENAME, CreateMode.NONE) as con:
+            rows = con.execute_scalar_query(QRY)
+    os.remove(FILENAME)
     assert rows == 2


 def test_delete_rows(create_hyper_file):
-    filename = 'test.hyper'
-    hf = create_hyper_file(filename)
-    count = hf.delete_rows(filename, 'date32', 1)
-    os.remove(filename)
+    COLUMN_NAME = "date32"
+    hf = create_hyper_file(FILENAME)
+    count = hf.delete_rows(FILENAME, COLUMN_NAME, 1)
+    os.remove(FILENAME)
     assert count == 1


 def test_append_rows(create_hyper_file):
-    filename = 'test.hyper'
-    hf = create_hyper_file(filename)
-    hf.append_rows(filename)
+    hf = create_hyper_file(FILENAME)
+    hf.append_rows(FILENAME)
     with HyperProcess(Telemetry.SEND_USAGE_DATA_TO_TABLEAU) as hyper:
-        with Connection(hyper.endpoint, filename, CreateMode.NONE) as con:
-            rows = con.execute_scalar_query(
-                'SELECT COUNT(*) FROM "Extract"."Extract"')
-    os.remove(filename)
+        with Connection(hyper.endpoint, FILENAME, CreateMode.NONE) as con:
+            rows = con.execute_scalar_query(QRY)
+    os.remove(FILENAME)
     assert rows == 4
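
The fixture chain behind these tests: get_pyarrow_table yields a two-row table, so a fresh extract holds 2 rows and one append doubles it to 4. The count the tests read can be checked standalone with a helper along these lines (the function name is ours, not the repo's):

    from tableauhyperapi import HyperProcess, Telemetry, Connection, CreateMode

    def extract_row_count(path: str) -> int:
        telemetry = Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU
        with HyperProcess(telemetry=telemetry) as hyper:
            with Connection(hyper.endpoint, path, CreateMode.NONE) as con:
                return con.execute_scalar_query('SELECT COUNT(*) FROM "Extract"."Extract"')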
diff --git a/tests/test_hyper_utils.py b/tests/test_hyper_utils.py
index 7be5e8c..ecd5275 100644
--- a/tests/test_hyper_utils.py
+++ b/tests/test_hyper_utils.py
@@ -5,31 +5,48 @@
 import os
 from tableauhyperapi import SqlType

+STRF_TIME = "%Y-%m-%d"
+

 @pytest.fixture
 def get_pyarrow_table():
+    A = "a"
+    B = "b"
+    US = "us"
+    UTF8 = "utf-8"
+    COL_NAMES = [
+        "int8",
+        "int16",
+        "int32",
+        "int64",
+        "string",
+        "float32",
+        "float64",
+        "bool",
+        "timestamp",
+        "date32",
+        "date64",
+        "binary",
+        "decimal128",
+    ]
     array = [
         pa.array([1, 2], type=pa.int8()),
         pa.array([1, 2], type=pa.int16()),
         pa.array([1, 2], type=pa.int32()),
         pa.array([1, 2], type=pa.int64()),
-        pa.array(['a', 'b'], type=pa.string()),
+        pa.array([A, B], type=pa.string()),
         pa.array([1.0, 1.5], type=pa.float32()),
         pa.array([1.0, 1.5], type=pa.float64()),
         pa.array([True, False], type=pa.bool_()),
-        pa.array([dt.datetime(2023, 1, 1, 0, 0, 0), dt.datetime.now()],
-                 type=pa.timestamp('us')),
+        pa.array(
+            [dt.datetime(2023, 1, 1, 0, 0, 0), dt.datetime.now()], type=pa.timestamp(US)
+        ),
         pa.array([dt.date(2023, 1, 1), dt.date.today()], type=pa.date32()),
         pa.array([dt.date(2023, 1, 1), dt.date.today()], type=pa.date64()),
-        pa.array([b'a', b'b'], type=pa.binary()),
-        pa.array([1234, 1234], type=pa.decimal128(7, 3))
-    ]
-    names = [
-        'int8', 'int16', 'int32', 'int64', 'string', 'float32',
-        'float64', 'bool', 'timestamp', 'date32', 'date64',
-        'binary', 'decimal128'
+        pa.array([A.encode(UTF8), B.encode(UTF8)], type=pa.binary()),
+        pa.array([1234, 1234], type=pa.decimal128(7, 3)),
     ]
-    yield pa.table(array, names=names)
+    yield pa.table(array, names=COL_NAMES)


 @pytest.fixture
@@ -56,7 +73,7 @@ def test_convert_struct_field(get_pyarrow_schema):

 def test_get_table_def(get_pyarrow_table):
     df = get_pyarrow_table
-    now = str(dt.datetime.today().strftime("%Y-%m-%d"))
+    now = str(dt.datetime.today().strftime(STRF_TIME))
     pa.parquet.write_table(df, now)
     with pa.parquet.ParquetFile(now) as file:
         table_def = hu.get_table_def(file)
@@ -77,11 +94,12 @@ def test_get_table_def(get_pyarrow_table):

 def test_get_parquet_files(get_pyarrow_table):
+    PARQUET = "parquet"
     df = get_pyarrow_table
-    now = str(dt.datetime.today().strftime("%Y-%m-%d"))
-    extension = '.parquet'
+    now = str(dt.datetime.today().strftime(STRF_TIME))
+    extension = "." + PARQUET
     filename = now + extension
     pa.parquet.write_table(df, filename)
-    files = hu.get_parquet_files('', extension.replace('.', ''))
+    files = hu.get_parquet_files("", extension.replace(".", ""))
     os.remove(filename)
     assert len(files) == 1
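
How the pieces compose end to end, with the whole refresh timed by the decorator; all paths, tokens, and names below are illustrative assumptions, not values from this repo:

    from packages.hyper_file import HyperFile
    from packages.tableau_server_utils import TableauServerUtils
    from packages.time_decorator import timeit

    @timeit
    def refresh() -> None:
        # build the extract locally from data/*.parquet, then push it to the server
        hf = HyperFile("data/", "parquet")
        hf.create_hyper_file("extract.hyper")
        tsu = TableauServerUtils("https://tableau.example.com", "token-name", "token-secret")
        tsu.publish_hyper(tsu.get_project_id("Analytics"), "extract.hyper")

    if __name__ == "__main__":
        refresh()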