diff --git a/pydbapi/api/trino.py b/pydbapi/api/trino.py index 9df962d..9c36bb2 100644 --- a/pydbapi/api/trino.py +++ b/pydbapi/api/trino.py @@ -1,176 +1,163 @@ -# -*- coding: utf-8 -*- -# @Author: longfengpili -# @Date: 2023-06-02 15:27:41 -# @Last Modified by: longfengpili -# @Last Modified time: 2024-02-28 16:26:03 -# @github: https://github.com/longfengpili - - -import threading -from datetime import date - -from trino.dbapi import connect -from trino.auth import BasicAuthentication - -from pydbapi.db import DBMixin, DBFileExec -from pydbapi.sql import SqlCompile -# from pydbapi.col import ColumnModel, ColumnsModel -from pydbapi.conf import AUTO_RULES - - -import logging -mytrinologger = logging.getLogger(__name__) - - -class SqlTrinoCompile(SqlCompile): - '''[summary] - - [description] - 构造mysql sql - Extends: - SqlCompile - ''' - - def __init__(self, tablename): - super(SqlTrinoCompile, self).__init__(tablename) - - def create_partition(self, partition): - coltype = partition.coltype - if not (coltype.startswith('varchar') or coltype == 'date'): - raise TypeError(f"{partition} only support varchar, date !") - partition = f"with (partitioned_by = ARRAY['{partition.newname}'])" - return partition - - def create(self, columns, partition=None): - partition_sql = None - if partition: - partition_key = columns.get_column_by_name(partition) - if not partition_key: - raise ValueError(f"<{partition}> not in {columns}") - - columns.remove(partition) - columns.append(partition_key) - partition_sql = self.create_partition(partition_key) - - sql = self.create_nonindex(columns) - - if partition_sql: - sql = sql.replace(';', f'\n{partition_sql};') - - return sql - - -class TrinoDB(DBMixin, DBFileExec): - _instance_lock = threading.Lock() - - def __init__(self, host, user, password, database, catalog='hive', port=8443, safe_rule=True): - '''[summary] - - [init] - - Args: - host ([str]): [host] - user ([str]): [username] - password ([str]): [password] - database ([str]): [database] - isolation_level (number): [isolation_level] (default: `0`) - AUTOCOMMIT = 0 # 每个事务单独执行 - READ_UNCOMMITTED = 1 # 脏读(dirty read),一个事务可以读取到另一个事务未提交的事务记录 - READ_COMMITTED = 2 # 不可重复读(non-repeatable read),一个事务只能读取到已经提交的记录,不能读取到未提交的记录 - REPEATABLE_READ = 3 # 幻读(phantom read),一个事务可以多次从数据库读取某条记录,而且多次读取的那条记录都是一致的,相同的 - SERIALIZABLE = 4 # 事务执行时,会在所有级别上加锁,比如read和write时都会加锁,仿佛事务是以串行的方式进行的,而不是一起发生的。这会防止脏读、不可重复读和幻读的出现,但是,会带来性能的下降 - 数据库默认的隔离级别:mysql为可重复读,oracle为提交后读 - trino不支持多个事务组合操作 - catalog (str): [cataglog] (default: `'hive'`) - port (number): [port] (default: `8443`) - safe_rule (bool): [safe rule] (default: `True`) - ''' - - self.host = host - self.port = port - self.user = user - self.password = password - self.catalog = catalog - self.database = database - super(TrinoDB, self).__init__() - self.auto_rules = AUTO_RULES if safe_rule else None - self.dbtype = 'trino' - - # def __new__(cls, *args, **kwargs): - # if not hasattr(TrinoDB, '_instance'): - # with TrinoDB._instance_lock: - # if not hasattr(TrinoDB, '_instance'): - # TrinoDB._instance = super().__new__(cls) - - # return TrinoDB._instance - - @classmethod - def get_instance(cls, *args, **kwargs): - # mytrinologger.info(TrinoDB._instance_lock) - if not hasattr(TrinoDB, '_instance'): - # mytrinologger.info(TrinoDB._instance_lock) - with TrinoDB._instance_lock: - if not hasattr(TrinoDB, '_instance'): - TrinoDB._instance = cls(*args, **kwargs) - - return TrinoDB._instance - - def get_conn(self): - if not hasattr(TrinoDB, '_conn'): - with TrinoDB._instance_lock: - if not hasattr(TrinoDB, '_conn'): - auth = BasicAuthentication(self.user, self.password) - conn = connect(host=self.host, user=self.user, auth=auth, - catalog=self.catalog, schema=self.database, - port=self.port, http_scheme="https" - ) - mytrinologger.info(f'connect {self.__class__.__name__}({self.user}@{self.host}:{self.port}/{self.catalog}.{self.database})') # noqa: E501 - TrinoDB._conn = conn - return TrinoDB._conn - - def create(self, tablename, columns, partition=None, verbose=0): - # tablename = f"{self.database}.{tablename}" - sqlcompile = SqlTrinoCompile(tablename) - sql_for_create = sqlcompile.create(columns, partition=partition) - rows, action, result = self.execute(sql_for_create, verbose=verbose) - return rows, action, result - - def insert(self, tablename, columns, inserttype='value', values=None, chunksize=1000, - fromtable=None, condition=None, verbose=0): - if values: - vlength = len(values) - - if self._check_isauto(tablename): - sqlcompile = SqlCompile(tablename) - sql_for_insert = sqlcompile.insert(columns, inserttype=inserttype, values=values, - chunksize=chunksize, fromtable=fromtable, condition=condition) - rows, action, result = self.execute(sql_for_insert, verbose=verbose) - - rows = vlength if values else rows - mytrinologger.info(f'【{action}】{tablename} insert succeed !') - return rows, action, result - - def alter_table(self, tablename: str, colname: str, newname: str = None, newtype: str = None, - partition: str = 'part_date', verbose: int = 0): - - old_columns, alter_columns = self.alter_column(tablename, colname, newname, newtype) - - # tablename - today = date.today() - today_str = today.strftime('%Y%m%d') - tablename_backup = f"{tablename}_{today_str}_{self.user}" - tablename_tmp = f"{tablename}_tmp" - - # alter tablename to backup - altersql = f'alter table {tablename} rename to {tablename_backup};' - self.execute(altersql, verbose=verbose) - - # create tmp table - self.create(tablename_tmp, alter_columns, partition=partition, verbose=verbose) - - # move data to tmp table - self.insert(tablename_tmp, alter_columns, fromtable=tablename_backup, inserttype='select', verbose=verbose) - - # alter tmp to tablename - altersql = f'alter table {tablename_tmp} rename to {tablename};' - self.execute(altersql, verbose=verbose) +# -*- coding: utf-8 -*- +# @Author: longfengpili +# @Date: 2023-06-02 15:27:41 +# @Last Modified by: longfengpili +# @Last Modified time: 2024-02-28 18:41:09 +# @github: https://github.com/longfengpili + + +import threading +from datetime import date + +from trino.dbapi import connect +from trino.auth import BasicAuthentication + +from pydbapi.db import DBMixin, DBFileExec +from pydbapi.sql import SqlCompile +# from pydbapi.col import ColumnModel, ColumnsModel +from pydbapi.conf import AUTO_RULES + + +import logging +mytrinologger = logging.getLogger(__name__) + + +class SqlTrinoCompile(SqlCompile): + '''[summary] + + [description] + 构造mysql sql + Extends: + SqlCompile + ''' + + def __init__(self, tablename): + super(SqlTrinoCompile, self).__init__(tablename) + + def create_partition(self, partition): + coltype = partition.coltype + if not (coltype.startswith('varchar') or coltype == 'date'): + raise TypeError(f"{partition} only support varchar, date !") + partition = f"with (partitioned_by = ARRAY['{partition.newname}'])" + return partition + + def create(self, columns, partition=None): + partition_sql = None + if partition: + partition_key = columns.get_column_by_name(partition) + if not partition_key: + raise ValueError(f"<{partition}> not in {columns}") + + columns.remove(partition) + columns.append(partition_key) + partition_sql = self.create_partition(partition_key) + + sql = self.create_nonindex(columns) + + if partition_sql: + sql = sql.replace(';', f'\n{partition_sql};') + + return sql + + +class TrinoDB(DBMixin, DBFileExec): + _instance_lock = threading.Lock() + + def __init__(self, host, user, password, database, catalog='hive', port=8443, safe_rule=True): + '''[summary] + + [init] + + Args: + host ([str]): [host] + user ([str]): [username] + password ([str]): [password] + database ([str]): [database] + isolation_level (number): [isolation_level] (default: `0`) + AUTOCOMMIT = 0 # 每个事务单独执行 + READ_UNCOMMITTED = 1 # 脏读(dirty read),一个事务可以读取到另一个事务未提交的事务记录 + READ_COMMITTED = 2 # 不可重复读(non-repeatable read),一个事务只能读取到已经提交的记录,不能读取到未提交的记录 + REPEATABLE_READ = 3 # 幻读(phantom read),一个事务可以多次从数据库读取某条记录,而且多次读取的那条记录都是一致的,相同的 + SERIALIZABLE = 4 # 事务执行时,会在所有级别上加锁,比如read和write时都会加锁,仿佛事务是以串行的方式进行的,而不是一起发生的。这会防止脏读、不可重复读和幻读的出现,但是,会带来性能的下降 + 数据库默认的隔离级别:mysql为可重复读,oracle为提交后读 + trino不支持多个事务组合操作 + catalog (str): [cataglog] (default: `'hive'`) + port (number): [port] (default: `8443`) + safe_rule (bool): [safe rule] (default: `True`) + ''' + + self.host = host + self.port = port + self.user = user + self.password = password + self.catalog = catalog + self.database = database + super(TrinoDB, self).__init__() + self.auto_rules = AUTO_RULES if safe_rule else None + self.dbtype = 'trino' + + # def __new__(cls, *args, **kwargs): + # if not hasattr(TrinoDB, '_instance'): + # with TrinoDB._instance_lock: + # if not hasattr(TrinoDB, '_instance'): + # TrinoDB._instance = super().__new__(cls) + + # return TrinoDB._instance + + @classmethod + def get_instance(cls, *args, **kwargs): + # mytrinologger.info(TrinoDB._instance_lock) + if not hasattr(TrinoDB, '_instance'): + # mytrinologger.info(TrinoDB._instance_lock) + with TrinoDB._instance_lock: + if not hasattr(TrinoDB, '_instance'): + TrinoDB._instance = cls(*args, **kwargs) + + return TrinoDB._instance + + def get_conn(self): + if not hasattr(TrinoDB, '_conn'): + with TrinoDB._instance_lock: + if not hasattr(TrinoDB, '_conn'): + auth = BasicAuthentication(self.user, self.password) + conn = connect(host=self.host, user=self.user, auth=auth, + catalog=self.catalog, schema=self.database, + port=self.port, http_scheme="https" + ) + mytrinologger.info(f'connect {self.__class__.__name__}({self.user}@{self.host}:{self.port}/{self.catalog}.{self.database})') # noqa: E501 + TrinoDB._conn = conn + return TrinoDB._conn + + def create(self, tablename, columns, partition=None, verbose=0): + # tablename = f"{self.database}.{tablename}" + sqlcompile = SqlTrinoCompile(tablename) + sql_for_create = sqlcompile.create(columns, partition=partition) + rows, action, result = self.execute(sql_for_create, verbose=verbose) + return rows, action, result + + def insert(self, tablename, columns, inserttype='value', values=None, chunksize=1000, + fromtable=None, condition=None, verbose=0): + if values: + vlength = len(values) + + if self._check_isauto(tablename): + sqlcompile = SqlCompile(tablename) + sql_for_insert = sqlcompile.insert(columns, inserttype=inserttype, values=values, + chunksize=chunksize, fromtable=fromtable, condition=condition) + rows, action, result = self.execute(sql_for_insert, verbose=verbose) + + rows = vlength if values else rows + mytrinologger.info(f'【{action}】{tablename} insert succeed !') + return rows, action, result + + def alter_table(self, tablename: str, colname: str, newname: str = None, newtype: str = None, + partition: str = 'part_date', verbose: int = 0): + + alter_columns = self.alter_column(tablename, colname, newname, newtype) + + # create middle table + mtablename = f"{tablename}_middle" + self.create(mtablename, alter_columns, partition=partition, verbose=verbose) + + # alter table + self.alter_table_base(tablename, mtablename, alter_columns, verbose=verbose) diff --git a/tests/trino/test_trino.py b/tests/trino/test_trino.py index a99da13..ef5d588 100644 --- a/tests/trino/test_trino.py +++ b/tests/trino/test_trino.py @@ -2,7 +2,7 @@ # @Author: longfengpili # @Date: 2023-06-02 15:27:41 # @Last Modified by: longfengpili -# @Last Modified time: 2024-02-28 15:59:17 +# @Last Modified time: 2024-02-29 17:09:55 # @github: https://github.com/longfengpili @@ -24,8 +24,8 @@ class TestTrino: def setup_method(self, method): self.trinodb = TrinoDB(TRINO_HOST, TRINO_USER, TRINO_PASSWORD, TRINO_DATABASE, safe_rule=False) - self.tablename = 'dow_jp_w.test_friut_xu' - self.id = ColumnModel('id', 'integer') + self.tablename = 'warship_jp_w.test_friut_xu' + self.id = ColumnModel('id', 'varchar') self.name = ColumnModel('name', 'varchar(1024)') self.address = ColumnModel('address', 'varchar(1024)') self.birthday = ColumnModel('birthday', 'varchar(1024)') @@ -49,26 +49,34 @@ def test_drop(self): rows, action, result = self.trinodb.drop(tablename) print(f"【rows】: {rows}, 【action】: {action}, 【result】: {result}") + def test_get_columns(self): + try: + columns = self.trinodb.get_columns(self.tablename) + print(columns) + self.trinodb.drop(self.tablename) + except Exception as e: + print(e) + def test_create(self): rows, action, result = self.trinodb.create(self.tablename, self.columns, partition='birthday') print(f"【rows】: {rows}, 【action】: {action}, 【result】: {result}") def test_insertsql(self): - values = [[1, 'apple', 'beijing', '2012-01-23', '{"yuwen": 90, "shuxue": 20}'], - [2, 'banana', 'shanghai', '2020-02-25 01:00:00', '{"yuwen": 91, "shuxue": 80}'], - [3, 'chocolate', 'yunnan', '2020-06-14 23:00:05', '{"yuwen": 90, "shuxue": 90}'], - [4, 'pizza', 'taiwan', '2020-05-15 23:08:25', '{"yuwen": 10, "shuxue": 21}'], - [5, 'pizza', 'hebei', '2020-08-12 14:05:36', '{"yuwen": 30, "shuxue": 23}']] + values = [['1', 'apple', 'beijing', '2012-01-23', '{"yuwen": 90, "shuxue": 20}'], + ['2', 'banana', 'shanghai', '2020-02-25 01:00:00', '{"yuwen": 91, "shuxue": 80}'], + ['3', 'chocolate', 'yunnan', '2020-06-14 23:00:05', '{"yuwen": 90, "shuxue": 90}'], + ['4', 'pizza', 'taiwan', '2020-05-15 23:08:25', '{"yuwen": 10, "shuxue": 21}'], + ['5', 'pizza', 'hebei', '2020-08-12 14:05:36', '{"yuwen": 30, "shuxue": 23}']] sqlcompile = SqlTrinoCompile(self.tablename) sql = sqlcompile.insert(self.columns, inserttype='value', values=values, chunksize=1000) print(sql) def test_insert(self): - values = [[1, 'apple', 'beijing', '2012-01-23', '{"yuwen": 90, "shuxue": 20}'], - [2, 'banana', 'shanghai', '2020-02-25 01:00:00', '{"yuwen": 91, "shuxue": 80}'], - [3, 'chocolate', 'yunnan', '2020-06-14 23:00:05', '{"yuwen": 90, "shuxue": 90}'], - [4, 'pizza', 'taiwan', '2020-05-15 23:08:25', '{"yuwen": 10, "shuxue": 21}'], - [5, 'pizza', 'hebei', '2020-08-12 14:05:36', '{"yuwen": 30, "shuxue": 23}']] + values = [['1', 'apple', 'beijing', '2012-01-23', '{"yuwen": 90, "shuxue": 20}'], + ['2', 'banana', 'shanghai', '2020-02-25 01:00:00', '{"yuwen": 91, "shuxue": 80}'], + ['3', 'chocolate', 'yunnan', '2020-06-14 23:00:05', '{"yuwen": 90, "shuxue": 90}'], + ['4', 'pizza', 'taiwan', '2020-05-15 23:08:25', '{"yuwen": 10, "shuxue": 21}'], + ['5', 'pizza', 'hebei', '2020-08-12 14:05:36', '{"yuwen": 30, "shuxue": 23}']] rows, action, result = self.trinodb.insert(self.tablename, self.columns, values=values, chunksize=1, verbose=1) print(f"【rows】: {rows}, 【action】: {action}, 【result】: {result}") @@ -190,9 +198,8 @@ def test_select_by_sql(self): print(f"【rows】: {rows}, 【action】: {action}, 【result】: {result}") def test_alter_col(self): - old_cols, alter_cols = self.trinodb.alter_column(self.tablename, 'name', 'b', 'int') - print(old_cols) + alter_cols = self.trinodb.alter_column(self.tablename, 'id', 'idx', 'int') print(alter_cols) def test_alter_table(self): - self.trinodb.alter_table(self.tablename, colname='name', newname='b', newtype='int', partition='birthday') + self.trinodb.alter_table(self.tablename, colname='id', newname='idx', newtype='int', partition='birthday')