Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions nemoguardrails/actions/llm/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import random
import re
import sys
import threading
from dataclasses import asdict
from functools import lru_cache
from time import time
Expand Down Expand Up @@ -59,12 +58,10 @@
from nemoguardrails.llm.taskmanager import LLMTaskManager
from nemoguardrails.llm.types import Task
from nemoguardrails.logging.explain import LLMCallInfo
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions
from nemoguardrails.streaming import StreamingHandler
from nemoguardrails.utils import (
get_or_create_event_loop,
new_event_dict,
new_uuid,
safe_eval,
Expand Down Expand Up @@ -99,16 +96,12 @@ def __init__(
self.user_message_index = None
self.bot_message_index = None
self.flows_index = None
self._init_lock = asyncio.Lock()

self.get_embedding_search_provider_instance = get_embedding_search_provider_instance

# There are still some edge cases not covered by nest_asyncio.
# Using a separate thread always for now.
loop = get_or_create_event_loop()
if True or check_sync_call_from_async_loop():
t = threading.Thread(target=asyncio.run, args=(self.init(),))
t.start()
t.join()
if self.config.colang_version == "2.x":
self._process_flows()

self.llm_task_manager = llm_task_manager

Expand All @@ -119,17 +112,6 @@ def __init__(
# calling the LLM with the user input.
self.passthrough_fn: Optional[Callable[..., Awaitable[str]]] = None

async def init(self):
# For Colang 2.x we need to do some initial processing
if self.config.colang_version == "2.x":
self._process_flows()

await asyncio.gather(
self._init_user_message_index(),
self._init_bot_message_index(),
self._init_flows_index(),
)

def _extract_user_message_example(self, flow: Flow) -> None:
"""Heuristic to extract user message examples from a flow."""
elements = [item for item in flow.elements if item["_type"] != "doc_string_stmt" and item["_type"] != "stmt"]
Expand Down Expand Up @@ -247,6 +229,9 @@ async def _init_bot_message_index(self):
if not self.bot_messages:
return

if not self.user_messages:
return

items = []
for intent, utterances in self.bot_messages.items():
for text in utterances:
Expand Down Expand Up @@ -297,6 +282,24 @@ async def _init_flows_index(self):
# NOTE: this should be very fast, otherwise needs to be moved to separate thread.
await self.flows_index.build()

async def _ensure_user_message_index(self):
    """Lazily build the user-message index on first use.

    Uses the double-checked pattern: a cheap unlocked test first, then a
    re-check under ``self._init_lock`` so concurrent coroutines build the
    index at most once. A no-op when there are no user messages configured.
    """
    # Fast path: already built, or nothing to index.
    if self.user_message_index is not None or not self.user_messages:
        return
    async with self._init_lock:
        # Re-check: another coroutine may have built it while we waited.
        if self.user_message_index is None:
            await self._init_user_message_index()

async def _ensure_bot_message_index(self):
    """Lazily build the bot-message index on first use.

    Skipped unless both bot messages and user messages are configured
    (the bot index is intentionally disabled without user messages).
    Double-checked under ``self._init_lock`` so it is built at most once.
    """
    # Fast path: already built, or preconditions not met.
    if self.bot_message_index is not None:
        return
    if not (self.bot_messages and self.user_messages):
        return
    async with self._init_lock:
        # Re-check: another coroutine may have built it while we waited.
        if self.bot_message_index is None:
            await self._init_bot_message_index()

async def _ensure_flows_index(self):
    """Lazily build the flows index on first use.

    A no-op when the configuration defines no flows. Double-checked under
    ``self._init_lock`` so concurrent callers build the index at most once.
    """
    # Fast path: already built, or no flows configured.
    if self.flows_index is not None or not self.config.flows:
        return
    async with self._init_lock:
        # Re-check: another coroutine may have built it while we waited.
        if self.flows_index is None:
            await self._init_flows_index()

def _get_general_instructions(self):
"""Helper to extract the general instruction."""
text = ""
Expand Down Expand Up @@ -370,6 +373,8 @@ async def generate_user_intent(

streaming_handler = streaming_handler_var.get()

await self._ensure_user_message_index()

# TODO: check for an explicit way of enabling the canonical form detection

if self.user_messages:
Expand Down Expand Up @@ -621,6 +626,8 @@ async def generate_next_step(self, events: List[dict], llm: Optional[BaseLLM] =

user_intent = event["intent"]

await self._ensure_flows_index()

# We search for the most relevant similar flows
examples = ""
if self.flows_index:
Expand Down Expand Up @@ -919,6 +926,8 @@ async def generate_bot_message(self, events: List[dict], context: dict, llm: Opt
# Otherwise, we go through the process of creating the altered prompt,
# which includes examples, relevant chunks, etc.

await self._ensure_bot_message_index()

# We search for the most relevant similar bot utterance
examples = ""
# NOTE: disabling bot message index when there are no user messages
Expand Down Expand Up @@ -1042,6 +1051,8 @@ async def generate_value(
if not var_name:
var_name = last_event["action_result_key"]

await self._ensure_flows_index()

# We search for the most relevant flows.
examples = ""
if self.flows_index:
Expand Down Expand Up @@ -1119,6 +1130,10 @@ async def generate_intent_steps_message(

streaming_handler = streaming_handler_var.get()

await self._ensure_user_message_index()
await self._ensure_bot_message_index()
await self._ensure_flows_index()

if self.config.user_messages:
# TODO: based on the config we can use a specific canonical forms model
# or use the LLM to detect the canonical form. The below implementation
Expand Down
25 changes: 25 additions & 0 deletions nemoguardrails/actions/v2_x/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,24 @@ async def _init_flows_index(self) -> None:
if self.instruction_flows_index is None:
self.instruction_flows_index = self.flows_index

async def _ensure_flows_index(self):
    """Lazily build the flows index on first use.

    A no-op when the configuration defines no flows. Double-checked under
    ``self._init_lock`` so concurrent callers build the index at most once.
    """
    # Fast path: already built, or no flows configured.
    if self.flows_index is not None or not self.config.flows:
        return
    async with self._init_lock:
        # Re-check: another coroutine may have built it while we waited.
        if self.flows_index is None:
            await self._init_flows_index()

async def _ensure_instruction_flows_index(self):
    """Lazily initialize the instruction-flows index on first use.

    ``getattr(..., None)`` collapses the original "attribute missing OR
    attribute is None" test into one expression. Initialization only runs
    when flows are configured; double-checked under ``self._init_lock`` so
    concurrent callers trigger ``_init_flows_index`` at most once.
    """
    # Fast path: index already present.
    if getattr(self, "instruction_flows_index", None) is not None:
        return
    async with self._init_lock:
        # Re-check under the lock; only initialize when flows exist.
        if getattr(self, "instruction_flows_index", None) is None and self.config.flows:
            await self._init_flows_index()

async def _collect_user_intent_and_examples(
self, state: State, user_action: str, max_example_flows: int
) -> Tuple[List[str], str, bool]:
await self._ensure_user_message_index()

# We search for the most relevant similar user intents
examples = ""
potential_user_intents = []
Expand Down Expand Up @@ -479,6 +494,8 @@ async def generate_flow_from_instructions(
) -> dict:
"""Generate a flow from the provided instructions."""

await self._ensure_instruction_flows_index()

if self.instruction_flows_index is None:
raise RuntimeError("No instruction flows index has been created.")

Expand Down Expand Up @@ -549,6 +566,9 @@ async def generate_flow_from_name(
) -> str:
"""Generate a flow from the provided NAME."""

await self._ensure_flows_index()
await self._ensure_instruction_flows_index()

if self.flows_index is None:
raise RuntimeError("No flows index has been created.")

Expand Down Expand Up @@ -609,6 +629,9 @@ async def generate_flow_continuation(
if temperature is None:
temperature = 0.0

await self._ensure_flows_index()
await self._ensure_instruction_flows_index()

if self.instruction_flows_index is None:
raise RuntimeError("No instruction flows index has been created.")

Expand Down Expand Up @@ -735,6 +758,8 @@ async def generate_value( # pyright: ignore (TODO - different arguments to base
# Use action specific llm if registered else fallback to main llm
generation_llm: Optional[Union[BaseLLM, BaseChatModel]] = llm if llm else self.llm

await self._ensure_flows_index()

# We search for the most relevant flows.
examples = ""
if self.flows_index:
Expand Down
Loading