docs/source/building_applications/rag.md (2 changes: 1 addition & 1 deletion)

@@ -109,7 +109,7 @@ response = agent.create_turn(
     ],
     documents=[
         {
-            "content": "https://raw.githubusercontent.com/example/doc.rst",
+            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
             "mime_type": "text/plain",
         }
     ],
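The hunk shows only the documents parameter; for context, a hedged sketch of the full create_turn call as it reads in the RAG guide, with the user message and session handling assumed rather than taken from this diff:

response = agent.create_turn(
    messages=[{"role": "user", "content": "How do I reduce memory usage when fine-tuning?"}],
    documents=[
        {
            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
            "mime_type": "text/plain",
        }
    ],
    session_id=session_id,
)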
llama_stack/distribution/routers/routing_tables.py (11 changes: 6 additions & 5 deletions)

@@ -309,13 +309,14 @@ async def register_vector_db(
         if provider_vector_db_id is None:
             provider_vector_db_id = vector_db_id
         if provider_id is None:
-            # If provider_id not specified, use the only provider if it supports this shield type
-            if len(self.impls_by_provider_id) == 1:
+            if len(self.impls_by_provider_id) > 0:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
+                if len(self.impls_by_provider_id) > 1:
+                    logger.warning(
+                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
+                    )
             else:
-                raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
-                )
+                raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
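A minimal, self-contained sketch of the selection rule this hunk introduces, useful for checking the three cases at a glance; resolve_provider_id and the sample provider names are illustrative stand-ins, not the router's actual interface:

import logging

logger = logging.getLogger(__name__)

def resolve_provider_id(provider_id, impls_by_provider_id):
    # Caller specified a provider explicitly: nothing to resolve.
    if provider_id is not None:
        return provider_id
    # Otherwise fall back to the first registered provider, warning if the choice is arbitrary.
    if len(impls_by_provider_id) > 0:
        provider_id = list(impls_by_provider_id.keys())[0]
        if len(impls_by_provider_id) > 1:
            logger.warning(
                f"No provider specified and multiple providers available. "
                f"Arbitrarily selected the first provider {provider_id}."
            )
        return provider_id
    raise ValueError("No provider available. Please configure a vector_io provider.")

# Two registered providers: the first wins and a warning is logged; none registered raises.
assert resolve_provider_id(None, {"provider_a": object(), "provider_b": object()}) == "provider_a"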

@@ -16,31 +16,35 @@
     AgentTurnResponseTurnCompletePayload,
     StepType,
 )
-from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.content_types import URL, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
     LogProbConfig,
     Message,
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
     UserMessage,
 )
 from llama_stack.apis.safety import RunShieldResponse
 from llama_stack.apis.tools import (
+    ListToolGroupsResponse,
+    ListToolsResponse,
     Tool,
     ToolDef,
     ToolGroup,
     ToolHost,
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -54,36 +58,37 @@
 class MockInferenceAPI:
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
-        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
+        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         async def stream_response():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="start",
-                    delta="",
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta=TextDelta(text=""),
                 )
             )

             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="progress",
-                    delta="AI is a fascinating field...",
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=TextDelta(text="AI is a fascinating field..."),
                 )
             )

             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="complete",
-                    delta="",
-                    stop_reason="end_of_turn",
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=TextDelta(text=""),
+                    stop_reason=StopReason.end_of_turn,
                 )
             )

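As a quick sanity check on the mock's new typed events, a hedged sketch of how a test could drain this stream; it assumes the llama_stack response models expose event_type, delta.text, and stop_reason as plain attributes (they are pydantic models), and the helper name is made up:

async def collect_progress_text(stream) -> str:
    # Gather only progress deltas; the start and complete events carry empty TextDelta payloads.
    parts = []
    async for chunk in stream:
        if chunk.event.event_type == ChatCompletionResponseEventType.progress:
            parts.append(chunk.event.delta.text)
    return "".join(parts)

# Hypothetical usage against the mock above:
#   stream = await MockInferenceAPI().chat_completion(model_id="test", messages=[], stream=True)
#   assert await collect_progress_text(stream) == "AI is a fascinating field..."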
@@ -133,35 +138,39 @@ async def get_tool_group(self, toolgroup_id: str) -> ToolGroup:
             provider_resource_id=toolgroup_id,
         )

-    async def list_tool_groups(self) -> List[ToolGroup]:
-        return []
-
-    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
-        if tool_group_id == MEMORY_TOOLGROUP:
-            return [
-                Tool(
-                    identifier=MEMORY_QUERY_TOOL,
-                    provider_resource_id=MEMORY_QUERY_TOOL,
-                    toolgroup_id=MEMORY_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::rag",
-                    parameters=[],
-                )
-            ]
-        if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
-            return [
-                Tool(
-                    identifier="code_interpreter",
-                    provider_resource_id="code_interpreter",
-                    toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::code_interpreter",
-                    parameters=[],
-                )
-            ]
-        return []
+    async def list_tool_groups(self) -> ListToolGroupsResponse:
+        return ListToolGroupsResponse(data=[])
+
+    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+        if toolgroup_id == MEMORY_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier=MEMORY_QUERY_TOOL,
+                        provider_resource_id=MEMORY_QUERY_TOOL,
+                        toolgroup_id=MEMORY_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::rag",
+                        parameters=[],
+                    )
+                ]
+            )
+        if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier="code_interpreter",
+                        provider_resource_id="code_interpreter",
+                        toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::code_interpreter",
+                        parameters=[],
+                    )
+                ]
+            )
+        return ListToolsResponse(data=[])

     async def get_tool(self, tool_name: str) -> Tool:
         return Tool(
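Because these mock methods now return response wrappers rather than bare lists, call sites unwrap .data. A hedged sketch of that pattern follows; the surrounding API object and helper name are assumptions, only list_tools and ListToolsResponse come from this hunk:

async def tool_identifiers(tool_groups_api, toolgroup_id: str) -> list:
    # list_tools now returns a ListToolsResponse; the tools themselves live under .data.
    response = await tool_groups_api.list_tools(toolgroup_id=toolgroup_id)
    return [tool.identifier for tool in response.data]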
@@ -174,7 +183,7 @@ async def get_tool(self, tool_name: str) -> Tool:
             parameters=[],
         )

-    async def unregister_tool_group(self, tool_group_id: str) -> None:
+    async def unregister_tool_group(self, toolgroup_id: str) -> None:
         pass

@@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
     chat_agent = await impl.get_agent(response.agent_id)

     tool_defs, _ = await chat_agent._get_tool_defs()
+    tool_defs_names = [t.tool_name for t in tool_defs]
     if expected_memory:
-        assert MEMORY_QUERY_TOOL in tool_defs
+        assert MEMORY_QUERY_TOOL in tool_defs_names
     if expected_code_interpreter:
-        assert BuiltinTool.code_interpreter in tool_defs
+        assert BuiltinTool.code_interpreter in tool_defs_names
     if expected_memory and expected_code_interpreter:
         # override the tools for turn
         new_tool_defs, _ = await chat_agent._get_tool_defs(
@@ -396,5 +406,6 @@
                 )
             ]
         )
-        assert MEMORY_QUERY_TOOL in new_tool_defs
-        assert BuiltinTool.code_interpreter not in new_tool_defs
+        new_tool_defs_names = [t.tool_name for t in new_tool_defs]
+        assert MEMORY_QUERY_TOOL in new_tool_defs_names
+        assert BuiltinTool.code_interpreter not in new_tool_defs_names
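A closing note on why the assertions extract names first: tool_name on the returned tool definitions can be either a plain string (as for MEMORY_QUERY_TOOL) or a BuiltinTool enum member, so membership is checked against the collected names rather than the definition objects themselves. A hedged, self-contained sketch of that mixed check, with stand-in values:

from enum import Enum

class BuiltinTool(Enum):  # stand-in for llama_stack.models.llama.datatypes.BuiltinTool
    code_interpreter = "code_interpreter"

class FakeToolDef:  # stand-in exposing only the attribute the assertions read
    def __init__(self, tool_name):
        self.tool_name = tool_name

MEMORY_QUERY_TOOL = "memory_query"  # placeholder value for the real constant

tool_defs = [FakeToolDef(MEMORY_QUERY_TOOL), FakeToolDef(BuiltinTool.code_interpreter)]
names = [t.tool_name for t in tool_defs]
assert MEMORY_QUERY_TOOL in names               # string entry
assert BuiltinTool.code_interpreter in names    # enum entry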