From 43079c001a744d04640f333ffa2f20e718281170 Mon Sep 17 00:00:00 2001 From: Shounak kulkarni Date: Fri, 22 Mar 2024 23:57:10 +0500 Subject: [PATCH] Allow passing database for pinot queries (#89) --- examples/pinot_async.py | 5 +++-- examples/pinot_live.py | 5 +++-- examples/pinot_quickstart_auth_zk.py | 3 ++- examples/pinot_quickstart_batch.py | 5 +++-- examples/pinot_quickstart_hybrid.py | 5 +++-- examples/pinot_quickstart_json_batch.py | 5 +++-- examples/pinot_quickstart_multi_stage.py | 3 ++- examples/pinot_quickstart_timeout.py | 9 ++++++--- pinotdb/db.py | 5 +++-- pinotdb/sqlalchemy.py | 24 ++++++++++++++++++++---- tests/unit/test_sqlalchemy.py | 4 +++- 11 files changed, 51 insertions(+), 22 deletions(-) diff --git a/examples/pinot_async.py b/examples/pinot_async.py index 2752417..ed01ad3 100644 --- a/examples/pinot_async.py +++ b/examples/pinot_async.py @@ -7,7 +7,8 @@ async def run_pinot_async_example(): async with connect_async(host='localhost', port=8000, path='/query/sql', - scheme='http', verify_ssl=False, timeout=10.0) as conn: + scheme='http', verify_ssl=False, timeout=10.0, + extra_request_headers="Database=default") as conn: curs = await conn.execute(""" SELECT count(*) FROM baseballStats @@ -20,7 +21,7 @@ async def run_pinot_async_example(): session = httpx.AsyncClient(verify=False) conn = connect_async( host='localhost', port=8000, path='/query/sql', scheme='http', - verify_ssl=False, session=session) + verify_ssl=False, session=session, extra_request_headers="Database=default") # launch 10 requests in parallel spanning a limit/offset range reqs = [] diff --git a/examples/pinot_live.py b/examples/pinot_live.py index 3b1564c..be0e77d 100644 --- a/examples/pinot_live.py +++ b/examples/pinot_live.py @@ -8,7 +8,8 @@ def run_pinot_live_example() -> None: # Query pinot.live with pinotdb connect - conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https") + conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -21,7 +22,7 @@ def run_pinot_live_example() -> None: "pinot+https://pinot-broker.pinot.live:443/query/sql?controller=https://pinot-controller.pinot.live/" ) # uses HTTP by default :( - airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True) + airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query=select([func.count("*")], from_obj=airlineStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_auth_zk.py b/examples/pinot_quickstart_auth_zk.py index e59dd89..b074d12 100644 --- a/examples/pinot_quickstart_auth_zk.py +++ b/examples/pinot_quickstart_auth_zk.py @@ -20,6 +20,7 @@ def run_pinot_quickstart_batch_example() -> None: scheme="http", username="admin", password="verysecret", + extra_request_headers="Database=default", ) curs = conn.cursor() tables = [ @@ -65,7 +66,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True) + baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query = select([func.count("*")], from_obj=baseballStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_batch.py b/examples/pinot_quickstart_batch.py index 4d18cea..af6d5b8 100644 --- a/examples/pinot_quickstart_batch.py +++ b/examples/pinot_quickstart_batch.py @@ -12,7 +12,8 @@ def run_pinot_quickstart_batch_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() tables = [ @@ -52,7 +53,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True) + baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query = select([func.count("*")], from_obj=baseballStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_hybrid.py b/examples/pinot_quickstart_hybrid.py index f7ad687..957de71 100644 --- a/examples/pinot_quickstart_hybrid.py +++ b/examples/pinot_quickstart_hybrid.py @@ -11,7 +11,8 @@ ## -d apachepinot/pinot:latest QuickStart -type hybrid def run_pinot_quickstart_hybrid_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -53,7 +54,7 @@ def run_pinot_quickstart_hybrid_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True) + airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot") query=select([func.count("*")], from_obj=airlineStats) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_json_batch.py b/examples/pinot_quickstart_json_batch.py index 74937b5..e7caad5 100644 --- a/examples/pinot_quickstart_json_batch.py +++ b/examples/pinot_quickstart_json_batch.py @@ -13,7 +13,8 @@ def run_quickstart_json_batch_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM githubEvents LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -43,7 +44,7 @@ def run_quickstart_json_batch_sqlalchemy_example() -> None: # engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/') # engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/') - githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True) + githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True, schema="default") print(f"\nSending Count(*) SQL to Pinot\nResults:") query=select([func.count("*")], from_obj=githubEvents) print(engine.execute(query).scalar()) diff --git a/examples/pinot_quickstart_multi_stage.py b/examples/pinot_quickstart_multi_stage.py index 42a1e51..2219b42 100644 --- a/examples/pinot_quickstart_multi_stage.py +++ b/examples/pinot_quickstart_multi_stage.py @@ -11,7 +11,8 @@ ## -d apachepinot/pinot:latest QuickStart -type MULTI_STAGE def run_pinot_quickstart_multi_stage_example() -> None: - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True) + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True, + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT a.playerID, a.runs, a.yearID, b.runs, b.yearID FROM baseballStats_OFFLINE AS a JOIN baseballStats_OFFLINE AS b ON a.playerID = b.playerID WHERE a.runs > 160 AND b.runs < 2 LIMIT 10" diff --git a/examples/pinot_quickstart_timeout.py b/examples/pinot_quickstart_timeout.py index 5a687be..677d112 100644 --- a/examples/pinot_quickstart_timeout.py +++ b/examples/pinot_quickstart_timeout.py @@ -11,7 +11,8 @@ def run_pinot_quickstart_timeout_example() -> None: #Test 1 : Try without timeout. The request should succeed. - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http") + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT * FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -20,7 +21,8 @@ def run_pinot_quickstart_timeout_example() -> None: #Test 2 : Try with timeout=None. The request should succeed. - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None) + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None, + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT count(*) FROM airlineStats LIMIT 5" print(f"Sending SQL to Pinot: {sql}") @@ -29,7 +31,8 @@ def run_pinot_quickstart_timeout_example() -> None: #Test 3 : Try with a really small timeout. The query should raise an exception. - conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001) + conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001, + extra_request_headers="Database=default") curs = conn.cursor() sql = "SELECT AirlineID, sum(Cancelled) FROM airlineStats WHERE Year > 2010 GROUP BY AirlineID LIMIT 5" print(f"Sending SQL to Pinot: {sql}") diff --git a/pinotdb/db.py b/pinotdb/db.py index 55e59bd..ed206f3 100644 --- a/pinotdb/db.py +++ b/pinotdb/db.py @@ -162,7 +162,7 @@ def close(self): except exceptions.Error: pass # already closed # if we're managing the httpx session, attempt to close it - if not self.is_session_external: + if not self.is_session_external and self.session: self.session.close() @check_closed @@ -334,7 +334,8 @@ def __init__( for header in extra_request_headers.split(","): k, v = header.split("=", 1) extra_headers[k] = v - + if 'database' in kwargs: + extra_headers['database'] = kwargs['database'] self.session.headers.update(extra_headers) @check_closed diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index 8ea44ea..bfa8af5 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -123,6 +123,11 @@ def __init__( ) +def extract_table_name(fqn): + split = fqn.split(".", 2) + return fqn if len(split) == 1 else split[1] + + class PinotDialect(default.DefaultDialect): name = "pinot" @@ -132,6 +137,7 @@ class PinotDialect(default.DefaultDialect): preparer = PinotIdentifierPareparer statement_compiler = PinotCompiler type_compiler = PinotTypeCompiler + supports_schemas = False supports_statement_cache = False supports_alter = False supports_pk_autoincrement = False @@ -154,6 +160,7 @@ def __init__(self, *args, **kwargs): self._password = None self._debug = False self._verify_ssl = True + self._database = None self.update_from_kwargs(kwargs) def update_from_kwargs(self, givenkw): @@ -167,6 +174,8 @@ def update_from_kwargs(self, givenkw): kwargs["username"] = self._username = kwargs.pop("username") if "password" in kwargs: kwargs["password"] = self._password = kwargs.pop("password") + if "database" in kwargs: + kwargs["database"] = self._database = kwargs.pop("database") kwargs["debug"] = self._debug = bool(kwargs.get("debug", False)) kwargs["verify_ssl"] = self._verify_ssl = (str(kwargs.get("verify_ssl", "true")).lower() in ['true']) logger.info( @@ -206,7 +215,7 @@ def create_connect_args(self, url): def get_metadata_from_controller(self, path): url = parse.urljoin(self._controller, path) - r = requests.get(url, headers={"Accept": "application/json"}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password)) + r = requests.get(url, headers={"Accept": "application/json", "Database": self._database}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password)) try: result = r.json() except ValueError as e: @@ -221,13 +230,20 @@ def get_metadata_from_controller(self, path): return result def get_schema_names(self, connection, **kwargs): - return ["default"] + if self._database: + return [self._database] + else: + return ['default'] def has_table(self, connection, table_name, schema=None): return table_name in self.get_table_names(connection, schema) def get_table_names(self, connection, schema=None, **kwargs): - return self.get_metadata_from_controller("/tables")["tables"] + resp = self.get_metadata_from_controller("/tables") + if 'tables' in resp: + return list(map(extract_table_name, resp["tables"])) + else: + return [] def get_view_names(self, connection, schema=None, **kwargs): return [] @@ -296,7 +312,7 @@ def _check_unicode_returns(self, connection, additional_tests=None): def _check_unicode_description(self, connection): return True - + # Fix for SQL Alchemy error def _json_deserializer(self, content: any): """ diff --git a/tests/unit/test_sqlalchemy.py b/tests/unit/test_sqlalchemy.py index 0005008..31b7c4f 100644 --- a/tests/unit/test_sqlalchemy.py +++ b/tests/unit/test_sqlalchemy.py @@ -105,8 +105,10 @@ def test_cannot_get_metadata_if_broken_json(self): def test_gets_schema_names(self): names = self.dialect.get_schema_names('some connection') - self.assertEqual(names, ['default']) + self.dialect._database = 'foo' + names = self.dialect.get_schema_names('some connection') + self.assertEqual(names, ['foo']) @responses.activate def test_gets_table_names_from_controller(self):