Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion agent/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.db.services.task_service import has_canceled
from api.utils import get_uuid, hash_str2int
from rag.prompts.generator import chunks_format
from rag.svr.task_executor import TaskCanceledException
from rag.utils.redis_conn import REDIS_CONN

class Graph:
Expand Down Expand Up @@ -126,6 +128,7 @@ def reset(self):
self.components[k]["obj"].reset()
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
REDIS_CONN.delete(f"{self.task_id}-cancel")
except Exception as e:
logging.exception(e)

Expand Down Expand Up @@ -163,6 +166,17 @@ def get_variable_value(self, exp: str) -> Any:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
return cpn["obj"].output(var_nm)

def is_canceled(self) -> bool:
return has_canceled(self.task_id)

def cancel_task(self) -> bool:
try:
REDIS_CONN.set(f"{self.task_id}-cancel", "x")
except Exception as e:
logging.exception(e)
return False
return True


class Canvas(Graph):

Expand All @@ -187,7 +201,7 @@ def load(self):
"sys.conversation_turns": 0,
"sys.files": []
}

self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])

Expand Down Expand Up @@ -250,10 +264,20 @@ def decorate(event, dt):
self.path.append("begin")
self.retrieval.append({"chunks": [], "doc_aggs": []})

if self.is_canceled():
msg = f"Task {self.task_id} has been canceled before starting."
logging.info(msg)
raise TaskCanceledException(msg)

yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})

def _run_batch(f, t):
if self.is_canceled():
msg = f"Task {self.task_id} has been canceled during batch execution."
logging.info(msg)
raise TaskCanceledException(msg)

with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for i in range(f, t):
Expand Down Expand Up @@ -401,6 +425,14 @@ def _extend_path(cpn_ids):
"created_at": st,
})
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
elif "Task has been canceled" in self.error:
yield decorate("workflow_finished",
{
"inputs": kwargs.get("inputs"),
"outputs": "Task has been canceled",
"elapsed_time": time.perf_counter() - st,
"created_at": st,
})

def is_reff(self, exp: str) -> bool:
exp = exp.strip("{").strip("}")
Expand Down
17 changes: 17 additions & 0 deletions agent/component/agent_with_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ def get_input_form(self) -> dict[str, dict]:

@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Agent processing"):
return

if kwargs.get("user_prompt"):
usr_pmt = ""
if kwargs.get("reasoning"):
Expand All @@ -152,6 +155,8 @@ def _invoke(self, **kwargs):
self._param.prompts = [{"role": "user", "content": usr_pmt}]

if not self.tools:
if self.check_if_canceled("Agent processing"):
return
return LLM._invoke(self, **kwargs)

prompt, msg, user_defined_prompt = self._prepare_prompt_variables()
Expand All @@ -166,6 +171,8 @@ def _invoke(self, **kwargs):
use_tools = []
ans = ""
for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent processing"):
return
ans += delta_ans

if ans.find("**ERROR**") >= 0:
Expand All @@ -186,12 +193,16 @@ def stream_output_with_tools(self, prompt, msg, user_defined_prompt={}):
answer_without_toolcall = ""
use_tools = []
for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt):
if self.check_if_canceled("Agent streaming"):
return

if delta_ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
yield self.get_exception_default_value()
else:
self.set_output("_ERROR", delta_ans)
return
answer_without_toolcall += delta_ans
yield delta_ans

Expand Down Expand Up @@ -266,6 +277,8 @@ def complete():
st = timer()
txt = ""
for delta_ans in self._gen_citations(entire_txt):
if self.check_if_canceled("Agent streaming"):
return
yield delta_ans, 0
txt += delta_ans

Expand All @@ -281,6 +294,8 @@ def append_user_content(hist, content):
task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
for _ in range(self._param.max_rounds + 1):
if self.check_if_canceled("Agent streaming"):
return
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
# self.callback("next_step", {}, str(response)[:256]+"...")
token_count += tk
Expand Down Expand Up @@ -328,6 +343,8 @@ def append_user_content(hist, content):
6. Focus on delivering VALUE with the information already gathered
Respond immediately with your final comprehensive answer.
"""
if self.check_if_canceled("Agent final instruction"):
return
append_user_content(hist, final_instruction)

for txt, tkcnt in complete():
Expand Down
14 changes: 14 additions & 0 deletions agent/component/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,20 @@ def __init__(self, canvas, id, param: ComponentParamBase):
self._param = param
self._param.check()

def is_canceled(self) -> bool:
return self._canvas.is_canceled()

def check_if_canceled(self, message: str = "") -> bool:
if self.is_canceled():
task_id = getattr(self._canvas, 'task_id', 'unknown')
log_message = f"Task {task_id} has been canceled"
if message:
log_message += f" during {message}"
logging.info(log_message)
self.set_output("_ERROR", "Task has been canceled")
return True
return False

def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
try:
Expand Down
6 changes: 6 additions & 0 deletions agent/component/begin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ class Begin(UserFillUp):
component_name = "Begin"

def _invoke(self, **kwargs):
if self.check_if_canceled("Begin processing"):
return

for k, v in kwargs.get("inputs", {}).items():
if self.check_if_canceled("Begin processing"):
return

if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
if v.get("optional") and v.get("value", None) is None:
v = None
Expand Down
11 changes: 11 additions & 0 deletions agent/component/categorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class Categorize(LLM, ABC):

@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Categorize processing"):
return

msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg:
msg = [{"role": "user", "content": ""}]
Expand All @@ -114,10 +117,18 @@ def _invoke(self, **kwargs):
---- Real Data ----
{} →
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))

if self.check_if_canceled("Categorize processing"):
return

ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
if ERROR_PREFIX in ans:
raise Exception(ans)

if self.check_if_canceled("Categorize processing"):
return

# Count the number of times each category appears in the answer.
category_counts = {}
for c in self._param.category_description.keys():
Expand Down
9 changes: 6 additions & 3 deletions agent/component/fillup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from agent.component.base import ComponentBase, ComponentParamBase


class UserFillUpParam(ComponentParamBase):

def __init__(self):
super().__init__()
self.enable_tips = True
Expand All @@ -31,10 +31,13 @@ class UserFillUp(ComponentBase):
component_name = "UserFillUp"

def _invoke(self, **kwargs):
if self.check_if_canceled("UserFillUp processing"):
return

for k, v in kwargs.get("inputs", {}).items():
if self.check_if_canceled("UserFillUp processing"):
return
self.set_output(k, v)

def thoughts(self) -> str:
return "Waiting for your input..."


9 changes: 9 additions & 0 deletions agent/component/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class Invoke(ComponentBase, ABC):

@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Invoke processing"):
return

args = {}
for para in self._param.variables:
if para.get("value"):
Expand Down Expand Up @@ -89,6 +92,9 @@ def replace_variable(match):

last_e = ""
for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("Invoke processing"):
return

try:
if method == "get":
response = requests.get(url=url, params=args, headers=headers, proxies=proxies, timeout=self._param.timeout)
Expand Down Expand Up @@ -121,6 +127,9 @@ def replace_variable(match):

return self.output("result")
except Exception as e:
if self.check_if_canceled("Invoke processing"):
return

last_e = e
logging.exception(f"Http request error: {e}")
time.sleep(self._param.delay_after_error)
Expand Down
3 changes: 3 additions & 0 deletions agent/component/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def get_start(self):
return cid

def _invoke(self, **kwargs):
if self.check_if_canceled("Iteration processing"):
return

arr = self._canvas.get_variable_value(self._param.items_ref)
if not isinstance(arr, list):
self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr)))
Expand Down
10 changes: 9 additions & 1 deletion agent/component/iterationitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,27 @@ def __init__(self, canvas, id, param: ComponentParamBase):
self._idx = 0

def _invoke(self, **kwargs):
if self.check_if_canceled("IterationItem processing"):
return

parent = self.get_parent()
arr = self._canvas.get_variable_value(parent._param.items_ref)
if not isinstance(arr, list):
self._idx = -1
raise Exception(parent._param.items_ref + " must be an array, but its type is "+str(type(arr)))

if self._idx > 0:
if self.check_if_canceled("IterationItem processing"):
return
self.output_collation()

if self._idx >= len(arr):
self._idx = -1
return

if self.check_if_canceled("IterationItem processing"):
return

self.set_output("item", arr[self._idx])
self.set_output("index", self._idx)

Expand Down Expand Up @@ -80,4 +88,4 @@ def end(self):
return self._idx == -1

def thoughts(self) -> str:
return "Next turn..."
return "Next turn..."
14 changes: 13 additions & 1 deletion agent/component/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ def delta(txt):

@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("LLM processing"):
return

def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
Expand All @@ -219,6 +222,9 @@ def clean_formated_answer(ans: str) -> str:
prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
prompt += "\nRedundant information is FORBIDDEN."
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("LLM processing"):
return

_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = self._generate(msg)
Expand All @@ -244,6 +250,9 @@ def clean_formated_answer(ans: str) -> str:
return

for _ in range(self._param.max_retries+1):
if self.check_if_canceled("LLM processing"):
return

_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = self._generate(msg)
Expand All @@ -265,6 +274,9 @@ def _stream_output(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
if self.check_if_canceled("LLM streaming"):
return

if ans.find("**ERROR**") >= 0:
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
Expand All @@ -283,4 +295,4 @@ def add_memory(self, user:str, assist:str, func_name: str, params: dict, results

def thoughts(self) -> str:
_, msg,_ = self._prepare_prompt_variables()
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move."
Loading