-
Notifications
You must be signed in to change notification settings - Fork 3
/
app.py
85 lines (78 loc) · 3.01 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import sqlparse
# Hugging Face checkpoint used for text-to-SQL generation.
model_name = "defog/sqlcoder-7b-2"

# Load the weights as float16; device_map="auto" lets the library decide
# where the model is placed.
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE: this rebinds the imported `pipeline` factory to the pipeline instance
# it creates; predict() calls the instance through this same name.
pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
def predict(question: str, ddl: str) -> str:
    """Generate a formatted SQL query answering *question* for the schema *ddl*.

    Args:
        question: Natural-language question to answer.
        ddl: Database schema text; it is inserted into the prompt verbatim.

    Returns:
        The first generated SQL statement, terminated with ``;`` and formatted
        by ``sqlparse`` (upper-case keywords, reindented). If the model replies
        with the refusal phrase requested in the prompt, that text is returned
        unchanged. If nothing usable is generated, an empty string is returned.
    """
    prompt = f"""### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
### Database Schema
The query will run on a database with the following schema:
{ddl}
### Answer
Given the database schema, here is the SQL query that [QUESTION]{question}[/QUESTION]
[SQL]
"""
    outputs = pipeline(
        prompt,
        max_new_tokens=300,
        do_sample=False,
        num_beams=4,
        num_return_sequences=1,
        return_full_text=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    generated = outputs[0]["generated_text"]

    # Keep only the first statement: drop everything after a code fence or
    # after the first semicolon.
    # NOTE(review): a ';' inside a string literal also truncates the query here.
    sql = generated.split("```")[0].split(";")[0].strip()

    # Nothing usable was generated; a lone ";" would not be a query.
    if not sql:
        return ""

    # The prompt tells the model to answer 'I do not know' when the schema is
    # insufficient. That reply is not SQL, so return it as-is instead of
    # appending ';' and letting sqlparse upper-case the keywords in it.
    if sql.strip("'\"").casefold().startswith("i do not know"):
        return sql

    return sqlparse.format(sql + ";", keyword_case="upper", reindent=True)
# Example inputs pre-filled in the two text boxes.
_EXAMPLE_QUESTION = "What are our top 3 products by sales in the New York region?"
_EXAMPLE_SCHEMA = """CREATE TABLE products (
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
name VARCHAR(50), -- Name of the product
price DECIMAL(10,2), -- Price of each unit of the product
quantity INTEGER -- Current quantity in stock
);
CREATE TABLE customers (
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
name VARCHAR(50), -- Name of the customer
address VARCHAR(100) -- Mailing address of the customer
);
CREATE TABLE salespeople (
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
name VARCHAR(50), -- Name of the salesperson
region VARCHAR(50) -- Geographic sales region
);
CREATE TABLE sales (
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
product_id INTEGER, -- ID of product sold
customer_id INTEGER, -- ID of customer who made purchase
salesperson_id INTEGER, -- ID of salesperson who made the sale
sale_date DATE, -- Date the sale occurred
quantity INTEGER -- Quantity of product sold
);
CREATE TABLE product_suppliers (
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
product_id INTEGER, -- Product ID supplied
supply_price DECIMAL(10,2) -- Unit price charged by supplier
);
-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id"""

# The two text boxes are passed to predict() as (question, ddl); its return
# value is displayed as plain text.
gradio_app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(lines=2, value=_EXAMPLE_QUESTION, label="question"),
        gr.Textbox(lines=20, value=_EXAMPLE_SCHEMA, label="database schema"),
    ],
    outputs="text",
    title="Convert an English question into SQL",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    gradio_app.launch(share=True)