Skip to content

Commit

Permalink
improv
Browse files Browse the repository at this point in the history
  • Loading branch information
apocas committed Feb 1, 2025
1 parent 4b866c3 commit d4f2091
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 49 deletions.
5 changes: 2 additions & 3 deletions app/models/databasemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ class ProjectDatabase(Base):
creator = Column(Integer)
public = Column(Boolean, default=False)
default_prompt = Column(Text)
users = relationship('UserDatabase', secondary=users_projects, back_populates='projects')
entrances = relationship("RouterEntrancesDatabase", back_populates="project")
#outputs = relationship("OutputDatabase", back_populates="project")
users = relationship('UserDatabase', secondary=users_projects, back_populates='projects', lazy="select")
entrances = relationship("RouterEntrancesDatabase", back_populates="project", lazy="select")

class UserDatabase(Base):
__tablename__ = "users"
Expand Down
7 changes: 4 additions & 3 deletions app/projects/inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from app import tools
from app.chat import Chat
from app.database import DBWrapper
from app.guard import Guard
Expand Down Expand Up @@ -54,10 +53,12 @@ def chat(self, project: Project, chat_model: ChatModel, user: User, db: DBWrappe
try:
if chat_model.stream:
resp_gen = model.llm.stream_chat(messages)
response = ""
# Collect parts instead of appending string repeatedly
parts = []
for text in resp_gen:
response += text.delta
parts.append(text.delta)
yield "data: " + json.dumps({"text": text.delta}) + "\n\n"
response = "".join(parts)
output["answer"] = response
chat.memory.chat_store.add_message(chat.memory.chat_store_key,
ChatMessage(role=MessageRole.ASSISTANT, content=response))
Expand Down
71 changes: 28 additions & 43 deletions app/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,26 @@

router = APIRouter()

# New helper to fetch project once
def get_project(projectName: str, db_wrapper: DBWrapper, brain):
project = brain.find_project(projectName, db_wrapper)
if project is None:
raise HTTPException(status_code=404, detail="Project not found")
return project

@router.get("/projects", response_model=ProjectsResponse)
async def route_get_projects(_: Request,
v_filter: str = Query("own", alias="filter"),
user: User = Depends(get_current_username),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
projects = []
# get projects only once instead of iterating multiple times
all_projects = db_wrapper.get_projects()
if v_filter == "own":
if user.is_admin:
projects = db_wrapper.get_projects()
else:
for project in user.projects:
for p in db_wrapper.get_projects():
if project.name == p.name:
projects.append(p)
projects = all_projects if user.is_admin else [p for p in all_projects if p.name in {proj.name for proj in user.projects}]
elif v_filter == "public":
for project in db_wrapper.get_projects():
if project.public:
projects.append(project)

projects = [p for p in all_projects if p.public]
else:
projects = []
return {"projects": projects}


Expand All @@ -64,11 +64,7 @@ async def route_get_project(request: Request,
user: User = Depends(get_current_username_project_public),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)

if project is None:
raise HTTPException(
status_code=404, detail='Project not found')
project = get_project(projectName, db_wrapper, request.app.state.brain)

output = project.model.model_dump()
final_output = {}
Expand Down Expand Up @@ -141,14 +137,10 @@ async def route_delete_project(request: Request,
_: User = Depends(get_current_username_project),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
proj = request.app.state.brain.find_project(projectName, db_wrapper)
proj = get_project(projectName, db_wrapper, request.app.state.brain)

if proj is not None:
db_wrapper.delete_project(db_wrapper.get_project_by_name(projectName))
proj.delete()
else:
raise HTTPException(
status_code=404, detail='Project not found')
db_wrapper.delete_project(db_wrapper.get_project_by_name(projectName))
proj.delete()

return {"project": projectName}

Expand Down Expand Up @@ -278,7 +270,7 @@ async def reset_embeddings(
_: User = Depends(get_current_username_project),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand All @@ -298,10 +290,7 @@ async def clone_project(request: Request, projectName: str, newProjectName: str,
_: User = Depends(get_current_username_project),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
if project is None:
raise HTTPException(
status_code=404, detail='Project not found')
project = get_project(projectName, db_wrapper, request.app.state.brain)

newProject = db_wrapper.get_project_by_name(newProjectName)
if newProject is not None:
Expand Down Expand Up @@ -352,7 +341,7 @@ async def find_embedding(request: Request, projectName: str, embedding: FindMode
_: User = Depends(get_current_username_project_public),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand Down Expand Up @@ -401,7 +390,7 @@ async def get_embedding(request: Request, projectName: str, source: str,
_: User = Depends(get_current_username_project_public),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand All @@ -425,7 +414,7 @@ async def get_embedding(request: Request, projectName: str,
_: User = Depends(get_current_username_project_public),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand All @@ -444,7 +433,7 @@ async def ingest_text(request: Request, projectName: str, ingest: TextIngestMode
_: User = Depends(get_current_username_project),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand Down Expand Up @@ -486,7 +475,7 @@ async def ingest_url(request: Request, projectName: str, ingest: URLIngestModel,
raise HTTPException(
status_code=400, detail="Specify the protocol http:// or https://")

project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand Down Expand Up @@ -526,7 +515,7 @@ async def ingest_file(
from llama_index.readers.docling import DoclingReader


project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand Down Expand Up @@ -604,7 +593,7 @@ async def get_embeddings(
_: User = Depends(get_current_username_project_public),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand All @@ -630,7 +619,7 @@ async def delete_embedding(
_: User = Depends(get_current_username_project),
db_wrapper: DBWrapper = Depends(get_db_wrapper)):
try:
project = request.app.state.brain.find_project(projectName, db_wrapper)
project = get_project(projectName, db_wrapper, request.app.state.brain)

if project.model.type != "rag":
raise HTTPException(
Expand Down Expand Up @@ -658,9 +647,7 @@ async def chat_query(
raise HTTPException(
status_code=400, detail="Missing question")

project = request.app.state.brain.find_project(projectName, db_wrapper)
if project is None:
raise Exception("Project not found")
project = get_project(projectName, db_wrapper, request.app.state.brain)

return await chat_main(request, request.app.state.brain, project, q_input, user, db_wrapper, background_tasks)
except Exception as e:
Expand All @@ -682,9 +669,7 @@ async def question_query_endpoint(
raise HTTPException(
status_code=400, detail="Question missing")

project = request.app.state.brain.find_project(projectName, db_wrapper)
if project is None:
raise Exception("Project not found")
project = get_project(projectName, db_wrapper, request.app.state.brain)

if user.level == "public":
q_input = QuestionModel(question=q_input.question, image=q_input.image, negative=q_input.negative)
Expand Down

0 comments on commit d4f2091

Please sign in to comment.