Skip to content

Commit

Permalink
Allow passing database for pinot queries (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
shounakmk219 authored Mar 22, 2024
1 parent 741e6be commit 43079c0
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 22 deletions.
5 changes: 3 additions & 2 deletions examples/pinot_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

async def run_pinot_async_example():
async with connect_async(host='localhost', port=8000, path='/query/sql',
scheme='http', verify_ssl=False, timeout=10.0) as conn:
scheme='http', verify_ssl=False, timeout=10.0,
extra_request_headers="Database=default") as conn:
curs = await conn.execute("""
SELECT count(*)
FROM baseballStats
Expand All @@ -20,7 +21,7 @@ async def run_pinot_async_example():
session = httpx.AsyncClient(verify=False)
conn = connect_async(
host='localhost', port=8000, path='/query/sql', scheme='http',
verify_ssl=False, session=session)
verify_ssl=False, session=session, extra_request_headers="Database=default")

# launch 10 requests in parallel spanning a limit/offset range
reqs = []
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_live.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

def run_pinot_live_example() -> None:
# Query pinot.live with pinotdb connect
conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https")
conn = connect(host="pinot-broker.pinot.live", port=443, path="/query/sql", scheme="https",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand All @@ -21,7 +22,7 @@ def run_pinot_live_example() -> None:
"pinot+https://pinot-broker.pinot.live:443/query/sql?controller=https://pinot-controller.pinot.live/"
) # uses HTTP by default :(

airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True)
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query=select([func.count("*")], from_obj=airlineStats)
print(engine.execute(query).scalar())
Expand Down
3 changes: 2 additions & 1 deletion examples/pinot_quickstart_auth_zk.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def run_pinot_quickstart_batch_example() -> None:
scheme="http",
username="admin",
password="verysecret",
extra_request_headers="Database=default",
)
curs = conn.cursor()
tables = [
Expand Down Expand Up @@ -65,7 +66,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True)
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query = select([func.count("*")], from_obj=baseballStats)
print(engine.execute(query).scalar())
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_quickstart_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@


def run_pinot_quickstart_batch_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()

tables = [
Expand Down Expand Up @@ -52,7 +53,7 @@ def run_pinot_quickstart_batch_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True)
baseballStats = Table("baseballStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query = select([func.count("*")], from_obj=baseballStats)
print(engine.execute(query).scalar())
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_quickstart_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
## -d apachepinot/pinot:latest QuickStart -type hybrid

def run_pinot_quickstart_hybrid_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand Down Expand Up @@ -53,7 +54,7 @@ def run_pinot_quickstart_hybrid_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True)
airlineStats = Table("airlineStats", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot")
query=select([func.count("*")], from_obj=airlineStats)
print(engine.execute(query).scalar())
Expand Down
5 changes: 3 additions & 2 deletions examples/pinot_quickstart_json_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@


def run_quickstart_json_batch_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM githubEvents LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand Down Expand Up @@ -43,7 +44,7 @@ def run_quickstart_json_batch_sqlalchemy_example() -> None:
# engine = create_engine('pinot+http://localhost:8000/query/sql?controller=http://localhost:9000/')
# engine = create_engine('pinot+https://localhost:8000/query/sql?controller=http://localhost:9000/')

githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True)
githubEvents = Table("githubEvents", MetaData(bind=engine), autoload=True, schema="default")
print(f"\nSending Count(*) SQL to Pinot\nResults:")
query=select([func.count("*")], from_obj=githubEvents)
print(engine.execute(query).scalar())
Expand Down
3 changes: 2 additions & 1 deletion examples/pinot_quickstart_multi_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
## -d apachepinot/pinot:latest QuickStart -type MULTI_STAGE

def run_pinot_quickstart_multi_stage_example() -> None:
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True)
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", use_multistage_engine=True,
extra_request_headers="Database=default")
curs = conn.cursor()

sql = "SELECT a.playerID, a.runs, a.yearID, b.runs, b.yearID FROM baseballStats_OFFLINE AS a JOIN baseballStats_OFFLINE AS b ON a.playerID = b.playerID WHERE a.runs > 160 AND b.runs < 2 LIMIT 10"
Expand Down
9 changes: 6 additions & 3 deletions examples/pinot_quickstart_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def run_pinot_quickstart_timeout_example() -> None:

#Test 1 : Try without timeout. The request should succeed.

conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http")
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http",
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT * FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand All @@ -20,7 +21,8 @@ def run_pinot_quickstart_timeout_example() -> None:

#Test 2 : Try with timeout=None. The request should succeed.

conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None)
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=None,
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT count(*) FROM airlineStats LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand All @@ -29,7 +31,8 @@ def run_pinot_quickstart_timeout_example() -> None:

#Test 3 : Try with a really small timeout. The query should raise an exception.

conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001)
conn = connect(host="localhost", port=8000, path="/query/sql", scheme="http", timeout=0.001,
extra_request_headers="Database=default")
curs = conn.cursor()
sql = "SELECT AirlineID, sum(Cancelled) FROM airlineStats WHERE Year > 2010 GROUP BY AirlineID LIMIT 5"
print(f"Sending SQL to Pinot: {sql}")
Expand Down
5 changes: 3 additions & 2 deletions pinotdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def close(self):
except exceptions.Error:
pass # already closed
# if we're managing the httpx session, attempt to close it
if not self.is_session_external:
if not self.is_session_external and self.session:
self.session.close()

@check_closed
Expand Down Expand Up @@ -334,7 +334,8 @@ def __init__(
for header in extra_request_headers.split(","):
k, v = header.split("=", 1)
extra_headers[k] = v

if 'database' in kwargs:
extra_headers['database'] = kwargs['database']
self.session.headers.update(extra_headers)

@check_closed
Expand Down
24 changes: 20 additions & 4 deletions pinotdb/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ def __init__(
)


def extract_table_name(fqn):
    """Return the table component of a possibly dot-qualified name.

    A name without a dot is returned unchanged. Otherwise the text between
    the first and second dot (or up to the end of the string when there is
    only one dot) is returned.
    """
    _, dot, remainder = fqn.partition(".")
    if not dot:
        # No qualifier present: the input already is the table name.
        return fqn
    # Anything after a second dot is dropped.
    return remainder.partition(".")[0]


class PinotDialect(default.DefaultDialect):

name = "pinot"
Expand All @@ -132,6 +137,7 @@ class PinotDialect(default.DefaultDialect):
preparer = PinotIdentifierPareparer
statement_compiler = PinotCompiler
type_compiler = PinotTypeCompiler
supports_schemas = False
supports_statement_cache = False
supports_alter = False
supports_pk_autoincrement = False
Expand All @@ -154,6 +160,7 @@ def __init__(self, *args, **kwargs):
self._password = None
self._debug = False
self._verify_ssl = True
self._database = None
self.update_from_kwargs(kwargs)

def update_from_kwargs(self, givenkw):
Expand All @@ -167,6 +174,8 @@ def update_from_kwargs(self, givenkw):
kwargs["username"] = self._username = kwargs.pop("username")
if "password" in kwargs:
kwargs["password"] = self._password = kwargs.pop("password")
if "database" in kwargs:
kwargs["database"] = self._database = kwargs.pop("database")
kwargs["debug"] = self._debug = bool(kwargs.get("debug", False))
kwargs["verify_ssl"] = self._verify_ssl = (str(kwargs.get("verify_ssl", "true")).lower() in ['true'])
logger.info(
Expand Down Expand Up @@ -206,7 +215,7 @@ def create_connect_args(self, url):

def get_metadata_from_controller(self, path):
url = parse.urljoin(self._controller, path)
r = requests.get(url, headers={"Accept": "application/json"}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password))
r = requests.get(url, headers={"Accept": "application/json", "Database": self._database}, verify=self._verify_ssl, auth= HTTPBasicAuth(self._username, self._password))
try:
result = r.json()
except ValueError as e:
Expand All @@ -221,13 +230,20 @@ def get_metadata_from_controller(self, path):
return result

def get_schema_names(self, connection, **kwargs):
    """Return the single schema name exposed by this dialect.

    The configured database (``self._database``) is reported when it is
    set to a truthy value; otherwise the name ``"default"`` is used.
    ``connection`` and any extra keyword arguments are accepted but not
    used.
    """
    # A leftover unconditional ``return ["default"]`` used to precede this
    # logic, which made the configured database unreachable.
    return [self._database or "default"]

def has_table(self, connection, table_name, schema=None, **kwargs):
    """Return True when ``table_name`` is among the dialect's table names.

    Membership is checked against ``self.get_table_names(connection, schema)``.
    Extra keyword arguments are accepted and ignored, matching the other
    reflection methods of this dialect, so callers that pass additional
    keywords (for example an ``info_cache``) do not raise ``TypeError``.
    """
    known_tables = self.get_table_names(connection, schema)
    return table_name in known_tables

def get_table_names(self, connection, schema=None, **kwargs):
    """Return the table names listed by the controller's ``/tables`` path.

    Every entry of the response's ``"tables"`` value is passed through
    ``extract_table_name``. When the response has no ``"tables"`` key an
    empty list is returned instead of raising ``KeyError``.
    ``connection``, ``schema`` and extra keyword arguments are accepted but
    not used.
    """
    resp = self.get_metadata_from_controller("/tables")
    # A leftover unconditional ``return resp["tables"]`` used to precede this
    # guard, bypassing both the missing-key handling and the name extraction.
    if "tables" not in resp:
        return []
    return [extract_table_name(name) for name in resp["tables"]]

def get_view_names(self, connection, schema=None, **kwargs):
    """Return an empty list; no view names are reported by this dialect.

    All arguments are accepted but not used.
    """
    no_views = []
    return no_views
Expand Down Expand Up @@ -296,7 +312,7 @@ def _check_unicode_returns(self, connection, additional_tests=None):

def _check_unicode_description(self, connection):
return True

# Fix for SQLAlchemy error
def _json_deserializer(self, content: any):
"""
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ def test_cannot_get_metadata_if_broken_json(self):

def test_gets_schema_names(self):
    """Schema names fall back to 'default' and follow the configured database."""
    connection = 'some connection'

    # No database configured on the dialect: the default schema is reported.
    self.assertEqual(self.dialect.get_schema_names(connection), ['default'])

    # Once a database is set, it replaces the default schema name.
    self.dialect._database = 'foo'
    self.assertEqual(self.dialect.get_schema_names(connection), ['foo'])

@responses.activate
def test_gets_table_names_from_controller(self):
Expand Down

0 comments on commit 43079c0

Please sign in to comment.