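"""Demo: answer natural-language questions over a SQLite database by letting a
Qwen model (served through an OpenAI-compatible API) generate SQL via function
calling, executing the SQL locally, and returning the model's final answer."""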
import json
import sqlite3

import openai
from loguru import logger


def ask_database(conn, query):
"""Function to query SQLite database with a provided SQL query."""
try:
results = str(conn.execute(query).fetchall())
except Exception as e:
results = f"query failed with error: {e}"
    return results


class SqlQuerier:
def __init__(self, openai_api_base, openai_api_key="xxx", db_path="Chinook.db"):
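        # Point the legacy (pre-1.0) openai client at the self-hosted,
        # OpenAI-compatible server (here, an api-for-open-llm deployment).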
openai.api_base = openai_api_base
openai.api_key = openai_api_key
self.conn = sqlite3.connect(db_path)
        logger.info("Opened database successfully")

    def run(self, query, database_schema):
# Step 1: send the conversation and available functions to model
messages = [{"role": "user", "content": query}]
        # Qwen ReAct-style tool description (name_for_human / name_for_model /
        # description_for_model / parameters), as expected by the function-calling
        # wrapper in api-for-open-llm. English translations of the Chinese prompt
        # strings are given in the comments.
        functions = [
            {
                # "Database query tool"
                "name_for_human": "数据库查询工具",
                "name_for_model": "ask_database",
                # "This tool answers music-related questions; its output should be a
                # well-formed SQL query."
                "description_for_model": "该工具用来回答音乐相关的问题,输出应该是一个标准化的SQL查询语句。",
                "parameters": [
                    {
                        "name": "query",
                        # "A SQL query, based on the table schema below, that answers
                        # the user's question."
                        "description": f"基于下面数据库表结构的SQL查询语句,用来回答用户问题。\n\n{database_schema}",
                        "required": True,
                        "schema": {"type": "string"},
                    },
                ],
            }
        ]
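        # "Observation:" is the ReAct stop word: generation halts before the model can
        # invent the tool output, which we supply ourselves after running the SQL.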
response = openai.ChatCompletion.create(
model="qwen",
messages=messages,
temperature=0,
functions=functions,
stop=["Observation:"]
)
        logger.info(response["choices"][0]["message"].get("function_call"))
answer = ""
response_message = response["choices"][0]["message"]
# Step 2: check if model wanted to call a function
if response_message.get("function_call"):
logger.info(f"Function call: {response_message['function_call']}")
# Step 3: call the function
# Note: the JSON response may not always be valid; be sure to handle errors
available_functions = {
"ask_database": ask_database,
} # only one function in this example
function_name = response_message["function_call"]["name"]
            function_to_call = available_functions[function_name]
            function_args = json.loads(response_message["function_call"]["arguments"])
            logger.info(f"Function args: {function_args}")
            function_response = function_to_call(self.conn, function_args["query"])
logger.info(f"Function response: {function_response}")
# Step 4: send the info on the function call and function response to model
messages.append(response_message) # extend conversation with assistant's reply
            messages.append(
                {
                    "role": "function",
                    "name": function_name,  # identify which tool produced this result
                    "content": function_response,
                }
            )  # extend conversation with function response
second_response = openai.ChatCompletion.create(
model="qwen",
messages=messages,
temperature=0,
functions=functions,
) # get a new response from model where it can see the function response
            answer = second_response["choices"][0]["message"]["content"]
            logger.info(f"Model output: {answer}")
            # Keep only the text after the last "Final Answer:" marker, if present
            j = answer.rfind("Final Answer:")
            if j != -1:
                answer = answer[j + len("Final Answer:"):].strip()
        return answer


if __name__ == '__main__':
    # Schema snippet from the Chinook sample database (albums table), passed to the
    # model as context for SQL generation.
    database_schema = """create table albums
(
    AlbumId INTEGER not null primary key autoincrement, --专辑ID
    Title NVARCHAR(160) not null, --专辑名称
    ArtistId INTEGER not null references artists --艺术家ID
);
"""
openai_api_base = "http://192.168.0.53:7891/v1"
    query = "发行专辑最多的艺术家是谁?"  # "Which artist has released the most albums?"
sql_querier = SqlQuerier(openai_api_base)
answer = sql_querier.run(query, database_schema)
print(answer)