forked from IntelLabs/Auto-Steer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
67 lines (54 loc) · 2.65 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# Copyright 2022 Intel Corporation
# SPDX-License-Identifier: MIT
#
"""Run AutoSteer in training mode (explore alternative query plans) or in inference mode"""
from typing import Type
import storage
import os
import sys
import connectors.connector
from connectors import mysql_connector, duckdb_connector, postgres_connector, presto_connector, spark_connector
from utils.arguments_parser import get_parser
from utils.custom_logging import logger
from autosteer.dp_exploration import explore_optimizer_configs
from autosteer.query_span import run_get_query_span
from inference.train import train_tcnn
def approx_query_span_and_run(connector: Type[connectors.connector.DBConnector], benchmark: str, query: str):
    """Approximate the query span for one query, then explore optimizer configs for it.

    Args:
        connector: Connector *class* (not an instance). It is handed unchanged
            to ``run_get_query_span`` and only instantiated afterwards for the
            exploration step.
        benchmark: Benchmark path; used as the prefix of the query path.
        query: Name of the query file inside ``benchmark``.
    """
    # Step 1: the query-span step receives the connector class itself.
    run_get_query_span(connector, benchmark, query)

    # Step 2: the exploration step works on a fresh connector instance and
    # addresses the query as '<benchmark>/<query>'.
    db_connector = connector()
    query_path = f'{benchmark}/{query}'
    explore_optimizer_configs(db_connector, query_path)
def inference_mode(connector, benchmark: str, retrain: bool, create_datasets: bool):
    """Run inference mode by delegating to ``train_tcnn``.

    All arguments are forwarded positionally and unchanged; nothing is returned.

    Args:
        connector: Forwarded to ``train_tcnn``. The script entry point in this
            file passes the connector class here.
        benchmark: Benchmark path.
        retrain: Forwarded flag (parsed from the command line by the entry point).
        create_datasets: Forwarded flag (parsed from the command line by the entry point).
    """
    train_tcnn(
        connector,
        benchmark,
        retrain,
        create_datasets,
    )
def get_connector_type(connector: str) -> Type[connectors.connector.DBConnector]:
    """Map a database name to its connector class.

    Args:
        connector: Database name; one of ``'postgres'``, ``'mysql'``,
            ``'spark'``, ``'presto'`` or ``'duckdb'``.

    Returns:
        The connector class (not an instance) for the given name.

    Raises:
        SystemExit: With status 1 when the name is not recognised. The failure
            is logged first. Previously the function fell through and returned
            ``None``, which made the caller crash later with an unrelated
            ``AttributeError``.
    """
    if connector == 'postgres':
        return postgres_connector.PostgresConnector
    if connector == 'mysql':
        return mysql_connector.MySqlConnector
    if connector == 'spark':
        return spark_connector.SparkConnector
    if connector == 'presto':
        return presto_connector.PrestoConnector
    if connector == 'duckdb':
        return duckdb_connector.DuckDBConnector

    # Unknown name: abort the same way the other fatal paths in this script do
    # (log, then exit) instead of silently returning None.
    logger.fatal('Unknown connector %s', connector)
    sys.exit(1)
if __name__ == '__main__':
    # Script entry point: parse the command line, validate it, then dispatch
    # to exactly one of inference mode or training mode.
    args = get_parser().parse_args()

    # Resolve the connector class and record which database is being tested.
    ConnectorType = get_connector_type(args.database)
    storage.TESTED_DATABASE = ConnectorType.get_name()

    # The benchmark argument must point at an existing directory.
    benchmark_dir = args.benchmark
    if benchmark_dir is None or not os.path.isdir(benchmark_dir):
        logger.fatal('Cannot access the benchmark directory containing the sql files with path=%s', benchmark_dir)
        sys.exit(1)
    storage.BENCHMARK_ID = storage.register_benchmark(benchmark_dir)

    # Exactly one mode must be selected: both set or both unset is an error.
    if bool(args.inference) == bool(args.training):
        logger.fatal('Specify either training or inference mode')
        sys.exit(1)

    if args.inference:
        logger.info('Run AutoSteer\'s inference mode')
        inference_mode(ConnectorType, benchmark_dir, args.retrain, args.create_datasets)
    else:
        # Past the check above, "not inference" implies training was requested.
        logger.info('Run AutoSteer\'s training mode')
        # Process every '.sql' file of the benchmark directory in sorted order.
        queries = sorted(name for name in os.listdir(benchmark_dir) if name.endswith('.sql'))
        logger.info('Found the following SQL files: %s', queries)
        for query in queries:
            logger.info('run Q%s...', query)
            approx_query_span_and_run(ConnectorType, benchmark_dir, query)