Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature(nyz): add web search + ReAct demo in PsyDI #3

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions backend/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import tiktoken


def web_search(query: str) -> str:
return "这是一个虚拟的搜索结果"


class LLMAPI:

def __init__(self):
Expand All @@ -26,3 +30,96 @@ def call(self, messages: List[str]) -> Tuple[str, int]:
)
content = response.choices[0].message.content.strip()
return content, token_count


class ReActLLMAPI(LLMAPI):
react_prompt_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{tool_description}
互联网搜索引擎(Search):使用多种中文互联网搜索引擎查找信息,用于获取任何你不知道的词汇/特定概念的信息,从而基于查询到的信息提升回答效果。
如果使用工具请遵循以下格式回复:
```
思路(Thought):思考你当前步骤需要解决什么问题,是否需要使用工具
工具(Action):计划使用的工具名称,你的工具必须从 [Search, ] 中选择
工具输入(Action Input):工具输入参数
```
工具返回按照以下格式回复:
```
响应结果(Response):调用工具后的结果
```
如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
```
思路(Thought):给出最终答案的思考过程
结果(Finish):最终答案
```
开始!
"""
max_turns_prompt_template = """你需要基于历史消息整合返回一个最终答案"""

def __init__(self, max_turns: int, **kwargs):
super().__init__(**kwargs)
self.max_turns = max_turns

def call(self, messages: List[str]) -> Tuple[str, int]:
token_count = 0
for turn in range(self.max_turns):
messages = self._format_messages(messages)
content, token_count_per_turn = super().call(messages)
token_count += token_count_per_turn
thought, action, action_input, finish = self._parse_content(content)
if finish:
return finish, token_count

action_output = self._execute_action(action, action_input)
messages = self._parse_action_output(messages, action_output)

messages = self._format_messages_max_turns(messages)
content, token_count_per_turn = super().call(messages)
token_count += token_count_per_turn
return content, token_count

def _format_messages(self, messages: List[str]) -> str:
new_messages = []
new_messages.append(dict(role="system", content=self.react_prompt_template))
new_messages += messages
return new_messages

def _format_messages_max_turns(self, messages: List[str]) -> str:
new_messages = messages
new_messages.append(dict(role="system", content=self.max_turns_prompt_template))
return new_messages

def _parse_content(self, content: str) -> Tuple[str, str, str, str]:
lines = content.split("\n")
thought = ""
action = ""
action_input = ""
finish = ""
print('\ncontent:', content)
for line in lines:
if line.startswith("思路(Thought):"):
thought = line.split(":", 1)[1]
elif line.startswith("工具(Action):"):
action = line.split(":", 1)[1]
elif line.startswith("工具输入(Action Input):"):
action_input = line.split(":", 1)[1]
elif line.startswith("结果(Finish):"):
finish = line.split(":", 1)[1]
return thought, action, action_input, finish

def _parse_action_output(self, messages: List[str], action_output: str) -> List[str]:
new_messages = messages
new_messages.append(dict(role="system", content=f"调用工具的响应结果(Response):{action_output}\n请你结合这个结果继续对话\n"))
return new_messages

def _execute_action(self, action: str, action_input: str) -> str:
if action == "Search":
return web_search(action_input)
else:
raise RuntimeError("不支持的工具")


if __name__ == "__main__":
llm = LLMAPI()
print(llm.call([{"role": "user", "content": "你能跟我讲讲零元购吗"}])[0])
react_llm = ReActLLMAPI(max_turns=3)
print(react_llm.call([{"role": "user", "content": "你能跟我讲讲零元购吗"}])[0])