Skip to content

Commit

Permalink
Basic Syllabus Bot for demo (#1961)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertrand authored Jan 14, 2025
1 parent f7417a5 commit 8275ce3
Show file tree
Hide file tree
Showing 14 changed files with 1,277 additions and 33 deletions.
111 changes: 111 additions & 0 deletions ai_chat/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import requests
from django.conf import settings
from django.core.cache import caches
from django.urls import reverse
from django.utils.module_loading import import_string
from llama_index.agent.openai import OpenAIAgent
from llama_index.core.agent import AgentRunner
Expand Down Expand Up @@ -481,3 +482,113 @@ def get_comment_metadata(self) -> str:
}
}
return json.dumps(metadata)


class SyllabusAgent(SearchAgent):
    """Service class for the AI syllabus agent.

    Answers user questions about a specific course by querying the MIT
    contentfile vector-search API and summarizing the returned chunks.
    """

    JOB_ID = "syllabus_agent"
    TASK_NAME = "syllabus_task"

    INSTRUCTIONS = """You are an assistant helping users answer questions related
to a syllabus.

Your job:
1. Use the available function to gather relevant information about the user's question.
2. Provide a clear, user-friendly summary of the information retrieved by the tool to
answer the user's question.

The tool knows which course the user is asking about, so you don't need to ask for it.

Always run the tool to answer questions, and answer only based on the tool
output. VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE TOOL OUTPUT TO
ANSWER QUESTIONS.  If no results are returned, say you could not find any relevant
information.
"""

    class SyllabusToolSchema(pydantic.BaseModel):
        """Schema for searching MIT contentfile chunks.

        Intentionally empty: the query and course id come from agent state
        (``self.user_message`` / ``self.readable_id``), not from the LLM.
        """

    def __init__(  # noqa: PLR0913
        self,
        name: str,
        *,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        instructions: Optional[str] = None,
        user_id: Optional[str] = None,
        save_history: Optional[bool] = False,
        cache_key: Optional[str] = None,
        cache_timeout: Optional[int] = None,
    ):
        """Initialize the AI syllabus agent service.

        Args:
            name: Display name for the agent.
            model: LLM model id; falls back to ``settings.AI_MODEL``.
            temperature: Sampling temperature passed through to the LLM.
            instructions: Optional system-prompt override.
            user_id: Identifier used for LiteLLM budget/rate limiting.
            save_history: Whether to persist chat history.
            cache_key: Cache key under which history is stored.
            cache_timeout: History TTL; falls back to ``settings.AI_CACHE_TIMEOUT``.
        """
        super().__init__(
            name,
            model=model or settings.AI_MODEL,
            temperature=temperature,
            instructions=instructions,
            save_history=save_history,
            user_id=user_id,
            cache_key=cache_key,
            cache_timeout=cache_timeout or settings.AI_CACHE_TIMEOUT,
        )
        # Accumulate the parameters/results of each tool call for debugging.
        self.search_parameters = []
        self.search_results = []
        # BUG FIX: create_agent() was previously invoked twice (the result of
        # the second call was discarded); a single call is sufficient.
        self.agent = self.create_agent()

    def search_content_files(self) -> str:
        """
        Query the MIT contentfile chunks API, and
        return results as a JSON string.

        Uses ``self.user_message`` as the query and ``self.readable_id`` to
        scope results to the course set by :meth:`get_completion`.
        """
        url = settings.AI_MIT_SYLLABUS_URL or reverse(
            "vector_search:v0:vector_content_files_search"
        )
        params = {
            "q": self.user_message,
            "resource_readable_id": self.readable_id,
            "limit": 20,
        }
        self.search_parameters.append(params)
        try:
            response = requests.get(url, params=params, timeout=30)
            response.raise_for_status()
            raw_results = response.json().get("results", [])
            # Simplify the response to only include the main properties
            simplified_results = [
                {"chunk_content": result.get("chunk_content")}
                for result in raw_results
            ]
            self.search_results.extend(simplified_results)
            return json.dumps(simplified_results)
        except requests.exceptions.RequestException:
            log.exception("Error querying MIT API")
            return json.dumps({"error": "An error occurred while searching"})

    def create_tools(self):
        """Create tools required by the agent."""
        return [self.create_search_tool()]

    def create_search_tool(self) -> FunctionTool:
        """Create the search tool for the AI agent."""
        metadata = ToolMetadata(
            name="search_content_files",
            description="Search for learning resources in the MIT catalog",
            fn_schema=self.SyllabusToolSchema,
        )
        return FunctionTool.from_defaults(
            fn=self.search_content_files, tool_metadata=metadata
        )

    def get_completion(
        self, message: str, readable_id: str, *, debug: bool = settings.AI_DEBUG
    ) -> str:
        """
        Get a response to the user's message. Use the exact user message as the
        q parameter value for the vector search.

        Args:
            message: The user's question, used verbatim as the search query.
            readable_id: Course readable id scoping the contentfile search.
            debug: Emit debug metadata with the response.
        """
        # Stash the raw inputs so search_content_files can read them when the
        # LLM invokes the (argument-less) tool.
        self.user_message = message
        self.readable_id = readable_id
        historical_message = f"{message}\n\ncourse readable_id: {readable_id}"
        return super().get_completion(historical_message, debug=debug)
6 changes: 6 additions & 0 deletions ai_chat/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,9 @@ def validate_instructions(self, value):
msg = "You are not allowed to modify the AI system prompt."
raise serializers.ValidationError(msg)
return value


class SyllabusChatRequestSerializer(ChatRequestSerializer):
    """Serializer for syllabus chatbot requests.

    Extends the base chat request payload with the course identifier the
    question is about.
    """

    # The readable id of the course whose syllabus is being queried;
    # required so the agent can scope the contentfile search.
    readable_id = serializers.CharField(required=True)
5 changes: 5 additions & 0 deletions ai_chat/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

# v0 API routes for the chat agents.
v0_urls = [
    re_path(r"^chat_agent/", views.SearchAgentView.as_view(), name="chatbot_agent_api"),
    # Streaming endpoint backed by the syllabus chatbot agent view.
    re_path(
        r"^syllabus_agent/",
        views.SyllabusAgentView.as_view(),
        name="syllabus_agent_api",
    ),
]
urlpatterns = [
re_path(r"^api/v0/", include((v0_urls, "v0"))),
Expand Down
55 changes: 55 additions & 0 deletions ai_chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,58 @@ def post(self, request: Request) -> StreamingHttpResponse:
content_type="text/event-stream",
headers={"X-Accel-Buffering": "no"},
)


class SyllabusAgentView(views.APIView):
    """
    DRF view for an AI agent that answers user queries
    by performing a relevant contentfile search for a
    specified course.
    """

    http_method_names = ["post"]
    serializer_class = serializers.SyllabusChatRequestSerializer
    permission_classes = (SearchAgentPermissions,)  # Add IsAuthenticated

    @extend_schema(
        responses={
            (200, "text/event-stream"): {
                "description": "Chatbot response",
                "type": "string",
            }
        }
    )
    def post(self, request: Request) -> StreamingHttpResponse:
        """Handle a POST request to the chatbot agent.

        Validates the payload, builds a SyllabusAgent keyed to the user (or
        session for anonymous users), and streams the agent's answer back as
        a server-sent event stream.
        """
        # Local import keeps heavy llama_index machinery out of module load.
        from ai_chat.agents import SyllabusAgent

        serializer = serializers.SyllabusChatRequestSerializer(
            data=request.data, context={"request": request}
        )
        serializer.is_valid(raise_exception=True)
        # Anonymous users need a session key to scope their history cache.
        if not request.session.session_key:
            request.session.save()
        cache_id = (
            request.user.email
            if request.user.is_authenticated
            else request.session.session_key
        )
        # Make anonymous users share a common LiteLLM budget/rate limit.
        user_id = request.user.email if request.user.is_authenticated else "anonymous"
        message = serializer.validated_data.pop("message", "")
        # BUG FIX: a stray trailing comma previously wrapped this value in a
        # 1-tuple, so the agent received ("course-id",) instead of "course-id"
        # as the resource_readable_id for the vector search.
        readable_id = serializer.validated_data.pop("readable_id")
        clear_history = serializer.validated_data.pop("clear_history", False)
        agent = SyllabusAgent(
            "Learning Resource Search AI Assistant",
            user_id=user_id,
            cache_key=f"{cache_id}_search_chat_history",
            save_history=True,
            **serializer.validated_data,
        )
        if clear_history:
            agent.clear_chat_history()
        return StreamingHttpResponse(
            agent.get_completion(message, readable_id),
            content_type="text/event-stream",
            headers={"X-Accel-Buffering": "no"},
        )
Loading

0 comments on commit 8275ce3

Please sign in to comment.