Commit
add API interfaces for train, predict and evaluate
qidanrui committed Dec 2, 2023
1 parent 55e7a0c commit b17e6bd
Showing 9 changed files with 205 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -184,4 +184,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
8 changes: 8 additions & 0 deletions dbgpt_hub/eval/__init__.py
@@ -0,0 +1,8 @@
"""
dbgpt_hub.eval
==============
"""

from .evaluation_api import start_evaluate

__all__ = ["start_evaluate"]
37 changes: 37 additions & 0 deletions dbgpt_hub/eval/evaluation.py
@@ -12,6 +12,7 @@
import subprocess
import json

from typing import Optional, Dict, Any
from process_sql import get_schema, Schema, get_sql
from exec_eval import eval_exec_match
from func_timeout import func_timeout, FunctionTimedOut
@@ -1152,6 +1153,42 @@ def build_foreign_key_map_from_json(table):
    return tables


def evaluate_api(args: Optional[Dict[str, Any]] = None):
    # When natsql is enabled, derive the prediction path by inserting "2sql"
    # before the file extension
    if args["natsql"]:
        pred_file_path = (
            args["input"].rsplit(".", 1)[0] + "2sql." + args["input"].rsplit(".", 1)[1]
        )
        gold_file_path = args["gold_natsql"]
        table_info_path = args["table_natsql"]
    else:
        pred_file_path = args["input"]
        gold_file_path = args["gold"]
        table_info_path = args["table"]

    # Only evaluating exact set match needs this argument
    kmaps = None
    if args["etype"] in ["all", "match"]:
        assert (
            args["table"] is not None
        ), "table argument must be non-None if exact set match is evaluated"
        kmaps = build_foreign_key_map_from_json(args["table"])

    # Print args
    print(f"params as follows:\n{args}")

    evaluate(
        gold_file_path,
        pred_file_path,
        args["db"],
        args["etype"],
        kmaps,
        args["plug_value"],
        args["keep_distinct"],
        args["progress_bar_for_each_datapoint"],
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
32 changes: 32 additions & 0 deletions dbgpt_hub/eval/evaluation_api.py
@@ -0,0 +1,32 @@
from typing import Optional, Dict, Any

from dbgpt_hub.eval import evaluation


def start_evaluate(
    args: Optional[Dict[str, Any]] = None,
):
    # Default arguments for evaluation
    if args is None:
        args = {
            "input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
            "gold": "./dbgpt_hub/data/eval_data/gold.txt",
            "gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
            "db": "./dbgpt_hub/data/spider/database",
            "table": "./dbgpt_hub/data/eval_data/tables.json",
            "table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
            "etype": "exec",
            "plug_value": True,
            "keep_distinct": False,
            "progress_bar_for_each_datapoint": False,
            "natsql": False,
        }

    # Execute evaluation
    evaluation.evaluate_api(args)


if __name__ == "__main__":
    start_evaluate()
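For reference, a minimal usage sketch of the new evaluation entry point. Note that evaluate_api indexes every key, so a caller-supplied dict must be complete; the values below mirror this commit's defaults, with "etype" switched to "all" purely for illustration.

from dbgpt_hub.eval import start_evaluate

# Hedged sketch: all keys are required because evaluate_api reads each one;
# values mirror the defaults above, "etype" is changed only for illustration.
start_evaluate(
    args={
        "input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
        "gold": "./dbgpt_hub/data/eval_data/gold.txt",
        "gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
        "db": "./dbgpt_hub/data/spider/database",
        "table": "./dbgpt_hub/data/eval_data/tables.json",
        "table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
        "etype": "all",  # evaluate execution accuracy and exact set match
        "plug_value": True,
        "keep_distinct": False,
        "progress_bar_for_each_datapoint": False,
        "natsql": False,
    }
)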
8 changes: 8 additions & 0 deletions dbgpt_hub/predict/__init__.py
@@ -0,0 +1,8 @@
"""
dbgpt_hub.predict
=================
"""

from .predict_api import start_predict

__all__ = ["start_predict"]
55 changes: 33 additions & 22 deletions dbgpt_hub/predict/predict.py
@@ -1,23 +1,21 @@
import os
import json
import sys
from tqdm import tqdm

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
from typing import List, Dict

from tqdm import tqdm
from typing import List, Dict, Optional, Any

from dbgpt_hub.data_process.data_utils import extract_sql_prompt_dataset
from dbgpt_hub.llm_base.chat_model import ChatModel
from dbgpt_hub.configs.config import (
PREDICTED_DATA_PATH,
OUT_DIR,
PREDICTED_OUT_FILENAME,
)


def prepare_dataset() -> List[Dict]:
    with open(PREDICTED_DATA_PATH, "r") as fp:
def prepare_dataset(
    predict_file_path: Optional[str] = None,
) -> List[Dict]:
    with open(predict_file_path, "r") as fp:
        data = json.load(fp)
    predict_data = [extract_sql_prompt_dataset(item) for item in data]
    return predict_data
@@ -33,21 +31,34 @@ def inference(model: ChatModel, predict_data: List[Dict], **input_kwargs):
    return res


def main():
    predict_data = prepare_dataset()
def predict(args: Optional[Dict[str, Any]] = None):
    predict_file_path = ""
    if args is None:
        predict_file_path = os.path.join(
            ROOT_PATH, "dbgpt_hub/data/eval_data/dev_sql.json"
        )
        predict_out_dir = os.path.join(
            os.path.join(ROOT_PATH, "dbgpt_hub/output/"), "pred"
        )
        if not os.path.exists(predict_out_dir):
            os.mkdir(predict_out_dir)
        predict_output_filename = os.path.join(predict_out_dir, "pred_sql.sql")
        print(f"predict_output_filename \t{predict_output_filename}")
    else:
        predict_file_path = os.path.join(ROOT_PATH, args["predict_file_path"])
        predict_out_dir = os.path.join(
            os.path.join(ROOT_PATH, args["predict_out_dir"]), "pred"
        )
        if not os.path.exists(predict_out_dir):
            os.mkdir(predict_out_dir)
        predict_output_filename = os.path.join(
            predict_out_dir, args["predicted_out_filename"]
        )
        print(f"predict_output_filename \t{predict_output_filename}")

    predict_data = prepare_dataset(predict_file_path=predict_file_path)
    model = ChatModel()
    result = inference(model, predict_data)

    predict_out_dir = os.path.join(OUT_DIR, "pred")
    if not os.path.exists(predict_out_dir):
        os.mkdir(predict_out_dir)

    predict_output_dir_name = os.path.join(
        predict_out_dir, model.data_args.predicted_out_filename
    )
    print(f"predict_output_dir_name \t{predict_output_dir_name}")

    with open(predict_output_dir_name, "w") as f:
    with open(predict_output_filename, "w") as f:
        for p in result:
            try:
                f.write(p.replace("\n", " ") + "\n")
@@ -56,4 +67,4 @@ def main():


if __name__ == "__main__":
    main()
    predict()
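As the diff above shows, predict() consumes only three keys from a caller-supplied dict, each resolved relative to ROOT_PATH; model configuration is assumed to be handled by ChatModel() itself. A minimal direct call might look like this sketch:

from dbgpt_hub.predict.predict import predict

# Hedged sketch: only these path keys are read by predict(); ChatModel()
# is assumed to pick up its own model arguments elsewhere.
predict(
    args={
        "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
        "predict_out_dir": "dbgpt_hub/output/",
        "predicted_out_filename": "pred_sql.sql",
    }
)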
31 changes: 31 additions & 0 deletions dbgpt_hub/predict/predict_api.py
@@ -0,0 +1,31 @@
import os
from dbgpt_hub.predict import predict
from typing import Optional, Dict, Any


def start_predict(
    args: Optional[Dict[str, Any]] = None, cuda_visible_devices: Optional[str] = "0"
):
    # Setting CUDA Device
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

    # Default Arguments
    if args is None:
        args = {
            "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
            "template": "llama2",
            "finetuning_type": "lora",
            "checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
            "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
            "predict_out_dir": "dbgpt_hub/output/",
            "predicted_out_filename": "pred_sql.sql",
        }

    # Execute prediction
    predict.predict(args)


if __name__ == "__main__":
    start_predict()
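A usage sketch for the wrapper: passing args=None falls back to the defaults above, and cuda_visible_devices is written to CUDA_VISIBLE_DEVICES before any model is loaded (the device id "1" below is only an example).

from dbgpt_hub.predict import start_predict

# Run prediction with this commit's default arguments on GPU 1
start_predict(args=None, cuda_visible_devices="1")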
8 changes: 8 additions & 0 deletions dbgpt_hub/train/__init__.py
@@ -0,0 +1,8 @@
"""
dbgpt_hub.train
===============
"""

from .sft_train_api import start_sft

__all__ = ["start_sft"]
47 changes: 47 additions & 0 deletions dbgpt_hub/train/sft_train_api.py
@@ -0,0 +1,47 @@
import os

from typing import Optional, Dict, Any
from dbgpt_hub.train import sft_train


def start_sft(
    args: Optional[Dict[str, Any]] = None, cuda_visible_devices: Optional[str] = "0"
):
    # Setting CUDA Device
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

    # Default Arguments
    if args is None:
        args = {
            "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
            "do_train": True,
            "dataset": "example_text2sql_train",
            "max_source_length": 2048,
            "max_target_length": 512,
            "finetuning_type": "lora",
            "lora_target": "q_proj,v_proj",
            "template": "llama2",
            "lora_rank": 64,
            "lora_alpha": 32,
            "output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
            "overwrite_cache": True,
            "overwrite_output_dir": True,
            "per_device_train_batch_size": 1,
            "gradient_accumulation_steps": 16,
            "lr_scheduler_type": "cosine_with_restarts",
            "logging_steps": 50,
            "save_steps": 2000,
            "learning_rate": 2e-4,
            "num_train_epochs": 8,
            "plot_loss": True,
            "bf16": True,
        }

    # Run SFT
    sft_train.train(args)


if __name__ == "__main__":
    start_sft()
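And a matching sketch for the training entry point; a caller-supplied dict is forwarded unchanged to sft_train.train, so an override should carry every key the trainer expects.

from dbgpt_hub.train import start_sft

# Fine-tune with the default LoRA configuration above on GPU 0
start_sft()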
