Skip to content

Commit

Permalink
generate voice message
Browse files Browse the repository at this point in the history
  • Loading branch information
AnniePacheco committed Sep 27, 2024
1 parent 1c4bf87 commit 0a008df
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 49 deletions.
96 changes: 89 additions & 7 deletions apis/paios/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,67 @@ paths:
headers:
X-Total-Count:
$ref: '#/components/headers/X-Total-Count'
post:
summary: Create new voice message.
description: Creates a new voice message
operationId: backend.api.VoicesFacesView.post
tags:
- Voice Management
responses:
'200':
description: OK
content:
application/json:
schema:
type: object
properties:
chat_response:
type: string
description: The chat response message.
headers:
Content-Disposition:
description: Indicates that the response includes an attachment.
schema:
type: string
'201':
description: Created
content:
application/json:
schema:
$ref: '#/components/schemas/VoiceMessage'
'206':
description: Partial Content (MP3 file streamed)
content:
audio/mpeg:
schema:
type: string
format: binary
headers:
Content-Disposition:
description: Indicates that the response includes an attachment.
schema:
type: string
'400':
description: Missing Required Information
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/VoiceMessageCreate'
'/voices/{id}':
delete:
summary: Delete voice message by id
description: Deletes the voice message with the given id.
operationId: backend.api.VoicesFacesView.delete
tags:
- Voice Management
parameters:
- $ref: '#/components/parameters/id'
responses:
'204':
description: No Content
'404':
description: Not Found
tags:
- name: Abilities Management
description: Installation and configuration of abilities
Expand Down Expand Up @@ -1386,22 +1447,38 @@ components:
type: object
title: VoiceCreate
description: Voice without id which is server-generated.
properties:
properties:
xi_id:
$ref: '#/components/schemas/textShort'
name:
$ref: '#/components/schemas/textShort'
properties:
xi_id:
$ref: '#/components/schemas/textShort'
name:
$ref: '#/components/schemas/textShort'
text_to_speak:
type: string
nullable: true
image_url:
$ref: '#/components/schemas/voice_image_url'
VoiceMessage:
type: object
title: VoiceMessage
properties:
audio_msg_path:
type: string
format: uri
nullable: true
sample_url:
$ref: '#/components/schemas/sample_mp3_url'
nullable: true
required:
- id
- audio_msg_path
VoiceMessageCreate:
type: object
title: VoiceMessageCreate
description: Identifies the stored chat message (msg_id) to convert into a voice message.
properties:
msg_id:
$ref: '#/components/schemas/uuid4'
required:
- msg_id
Face:
type: object
title: Face
Expand Down Expand Up @@ -1429,6 +1506,9 @@ components:
$ref: '#/components/schemas/textLong'
chat_response:
type: string
voice_active:
$ref: '#/components/schemas/boolean_str'
nullable: true
required:
- assistant_id
- conversation_id
Expand All @@ -1447,6 +1527,8 @@ components:
nullable: true
prompt:
$ref: '#/components/schemas/textLong'
voice_active:
$ref: '#/components/schemas/boolean_str'
required:
- assistant_id
- prompt
Expand Down
10 changes: 5 additions & 5 deletions backend/api/MessagesView.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from starlette.responses import JSONResponse, Response
from starlette.responses import JSONResponse
from backend.managers.MessagesManager import MessagesManager
from common.paths import api_base_url
from backend.pagination import parse_pagination_params
from backend.schemas import MessageCreateSchema

from starlette.responses import JSONResponse

class MessagesView:
def __init__(self):
Expand All @@ -12,10 +12,10 @@ def __init__(self):
async def post(self, body: MessageCreateSchema):
response, error_message = await self.mm.create_message(body)
if error_message:
return JSONResponse({"error": error_message}, status_code=404)
return JSONResponse({"error": error_message}, status_code=404)

if body.get("conversation_id"): # If conversation_id was provided, retrieve the full message
if body.get("conversation_id"):
message = await self.mm.retrieve_message(response)
return JSONResponse(message.dict(), status_code=201, headers={'Location': f'{api_base_url}/messages/{response}'})
else: # If conversation_id was not provided, return only the chat_response
else:
return JSONResponse({"chat_response": response}, status_code=200)
40 changes: 29 additions & 11 deletions backend/api/VoicesFacesView.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from starlette.responses import JSONResponse
from starlette.responses import JSONResponse, Response
from backend.managers.VoicesFacesManager import VoicesFacesManager
from backend.managers.MessagesManager import MessagesManager
from common.paths import api_base_url
from backend.pagination import parse_pagination_params
from backend.schemas import VoiceCreateSchema
from starlette.responses import JSONResponse, StreamingResponse
import os
from pathlib import Path
import shutil



class VoicesFacesView:
def __init__(self):
self.vfm = VoicesFacesManager()
# TODO: Finish text to speech
async def post(self, id: str, body: VoiceCreateSchema):
response, error_message = await self.vfm.text_to_speech(id, body)
if error_message:
return JSONResponse({"error": error_message}, status_code=404)
else:
return JSONResponse(response, status_code=200)

self.vfm = VoicesFacesManager()
self.mm = MessagesManager()

async def search(self, filter: str = None, range: str = None, sort: str = None):
result = parse_pagination_params(filter, range, sort)
Expand All @@ -36,4 +34,24 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
'X-Total-Count': str(total_count),
'Content-Range': f'voices {offset}-{offset + len(voices) - 1}/{total_count}'
}
return JSONResponse([voice.dict() for voice in voices], status_code=200, headers=headers)
return JSONResponse([voice.dict() for voice in voices], status_code=200, headers=headers)


async def post(self, body: VoiceCreateSchema):
    """Generate and stream the spoken (MP3) version of a stored chat message.

    Args:
        body: Mapping-like request payload; only its ``msg_id`` entry is read.

    Returns:
        * ``400`` JSON error when ``msg_id`` is missing.
        * ``404`` JSON error when the message does not exist, when voice
          generation reports an error, or when the generated file is absent.
        * ``200`` JSON ``{"chat_response": ...}`` when the message does not
          have voice enabled (``voice_active`` is stored as the string
          ``'True'`` when enabled).
        * A ``StreamingResponse`` of ``audio/mpeg`` data otherwise.
    """
    msg_id = body.get('msg_id')
    if not msg_id:
        return JSONResponse({"error": "msg_id is required"}, status_code=400)

    # retrieve_message returns None for an unknown id; guard before
    # touching attributes so the client gets a 404 instead of a 500.
    message = await self.mm.retrieve_message(msg_id)
    if message is None:
        return JSONResponse({"error": "Message not found"}, status_code=404)

    # Voice disabled: fall back to the plain text answer rather than
    # returning None (which is not a valid response object).
    if message.voice_active != 'True':
        return JSONResponse({"chat_response": message.chat_response}, status_code=200)

    audio_msg_path, error_message = await self.vfm.generate_voice_response(
        message.assistant_id, message.chat_response, msg_id
    )
    if error_message:
        return JSONResponse({"error": error_message}, status_code=404)
    if not audio_msg_path or not os.path.exists(audio_msg_path):
        return JSONResponse({"error": "File not found"}, status_code=404)

    return StreamingResponse(
        self.vfm.async_file_generator(audio_msg_path),
        media_type='audio/mpeg',
        headers={
            "Content-Disposition": f"attachment; filename={os.path.basename(audio_msg_path)}"
        }
    )
9 changes: 5 additions & 4 deletions backend/managers/MessagesManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,11 @@ async def create_message(self, message_data: MessageCreateSchema) -> Tuple[Optio
rm = RagManager()
response = await rm.retrieve_and_generate(assistant_id, query, llm)
chat_response = response["answer"]

if conversation_id:
message_data["chat_response"] = chat_response
message_data['chat_response'] = chat_response
message_data['timestamp'] = timestamp

new_message = Message(id=str(uuid4()), **message_data)
session.add(new_message)
await session.commit()
Expand All @@ -100,7 +100,8 @@ async def retrieve_message(self, id:str) -> Optional[MessageSchema]:
conversation_id=message.conversation_id,
timestamp=message.timestamp,
prompt=message.prompt,
chat_response=message.chat_response,
chat_response=message.chat_response,
voice_active=message.voice_active
)
return None

Expand Down
70 changes: 51 additions & 19 deletions backend/managers/VoicesFacesManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import List, Tuple, Optional, Dict, Any
from sqlalchemy import select, func
from backend.schemas import VoiceSchema
from pathlib import Path
from backend.models import Resource, Persona


class VoicesFacesManager:
Expand Down Expand Up @@ -112,21 +114,25 @@ async def retrieve_voices(self, offset: int = 0, limit: int = 100, sort_by: Opti

return voices, total_count


async def text_to_speech(self, voice_id: str, body) -> str:
async def text_to_speech(self, voice_id: str, body: str, assistant_id: str, msg_id: str) -> Tuple[Optional[dict], Optional[str]]:
voice = await self.retrieve_voice(voice_id)
xi_api_key = os.environ.get('XI_API_KEY')
xi_id = body['xi_id']
OUTPUT_PATH = f"{xi_id}.mp3" # Path to save the output audio file

tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{voice_id}/stream"
xi_id = voice.xi_id
temp = os.path.dirname(os.path.realpath(__file__))
directory = Path(os.path.join(os.path.dirname(temp), f'public/{assistant_id}'))
directory.mkdir(parents=True, exist_ok=True)
file_path = directory / f'{msg_id}.mp3'
audio_msg_path = str(file_path)
print('audio_msg_path: ', audio_msg_path)
tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{xi_id}/stream"

headers = {
"Accept": "application/json",
"xi-api-key": xi_api_key
}

data = {
"text": "Hola" ,
"text": body,
"model_id": "eleven_multilingual_v2",
"voice_settings": {
"stability": 0.5,
Expand All @@ -135,16 +141,42 @@ async def text_to_speech(self, voice_id: str, body) -> str:
"use_speaker_boost": True
}
}

response = requests.post(tts_url, headers=headers, json=data, stream=True)

if response.ok:
with open(OUTPUT_PATH, "wb") as f:
for chunk in response.iter_content(chunk_size=os.environ.get('XI_CHUNK_SIZE')):
f.write(chunk)
print("Audio stream saved successfully.")
else:
print(response.text)
return response.text
try:
response = requests.post(tts_url, headers=headers, json=data, stream=True)
if response.ok:
with open(audio_msg_path, "wb") as f:
for chunk in response.iter_content(chunk_size=int(os.environ.get('XI_CHUNK_SIZE'))):
f.write(chunk)
print("Audio stream saved successfully.")
return {"message": "Audio stream saved successfully.", "audio_msg_path": audio_msg_path}, None
else:
error_message = response.text
print(error_message)
return None, error_message
except Exception as e:
error_message = str(e)
print(f"An error occurred: {error_message}")
return None, error_message


async def generate_voice_response(self, assistant_id, chat_response, message_id) -> Tuple[Optional[str], Optional[str]]:
    """Synthesize speech for a chat response using the assistant's persona voice.

    Looks up the Resource with id ``assistant_id``, follows its
    ``persona_id`` to the Persona row, and passes that persona's
    ``voice_id`` to ``text_to_speech``.

    Args:
        assistant_id: Id of the Resource (assistant) that produced the reply.
        chat_response: Text to convert to speech.
        message_id: Id of the message; forwarded to ``text_to_speech``.

    Returns:
        ``(audio_msg_path, None)`` on success, or ``(None, error_message)``
        on any failure. A two-item tuple is always returned so callers can
        unpack the result unconditionally.
    """
    try:
        # Keep the DB session open only for the lookups; the text-to-speech
        # request below does not need it.
        async with db_session_context() as session:
            result = await session.execute(select(Resource).filter(Resource.id == assistant_id))
            resource = result.scalar_one_or_none()
            if resource is None:
                return None, f"Assistant {assistant_id} not found"

            persona_id = resource.persona_id
            if not persona_id:
                return None, f"Assistant {assistant_id} has no persona assigned"

            result = await session.execute(select(Persona).filter(Persona.id == persona_id))
            persona = result.scalar_one_or_none()
            if persona is None:
                return None, f"Persona {persona_id} not found"

            voice_id = persona.voice_id

        if not voice_id:
            return None, f"Persona {persona_id} has no voice assigned"

        # Use this manager instance rather than constructing another one.
        response, error_message = await self.text_to_speech(voice_id, chat_response, assistant_id, message_id)
        if error_message:
            return None, error_message
        return response.get("audio_msg_path"), None
    except Exception as e:
        return None, f"An unexpected error occurred while generating a voice response: {str(e)}"

async def async_file_generator(self, file_path, chunk_size: int = 1024):
    """Yield the contents of a file as successive byte chunks.

    Args:
        file_path: Path of the file to read; opened in binary mode.
        chunk_size: Maximum number of bytes per yielded chunk. Defaults to
            1024, the size that was previously hard-coded.

    Yields:
        Non-empty ``bytes`` objects, in file order, until end of file.

    Raises:
        ValueError: If ``chunk_size`` is not positive. ``read(0)`` would
            yield nothing and a negative size would read the whole file at
            once, so both are rejected.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # The file is closed when the generator finishes or is closed early.
    with open(file_path, 'rb') as audio_file:
        while chunk := audio_file.read(chunk_size):
            yield chunk
1 change: 1 addition & 0 deletions backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Message(Base):
timestamp = Column(String, nullable=False)
prompt = Column(String, nullable=False)
chat_response = Column(String, nullable=False)
voice_active = Column(String, nullable=False)

class Conversation(Base):
__tablename__ = "conversation"
Expand Down
9 changes: 6 additions & 3 deletions backend/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ class VoiceBaseSchema(BaseModel):
name: str
text_to_speak: Optional[str] = None
image_url: Optional[str] = None
sample_url: Optional[str] = None
sample_url: Optional[str] = None
msg_id: Optional[str] = None
audio_msg_path: Optional[str] = None
class Config:
orm_mode = True
from_attributes = True
Expand Down Expand Up @@ -119,6 +121,7 @@ class MessageBaseSchema(BaseModel):
timestamp: str
prompt: str
chat_response: str
voice_active: str
class Config:
orm_mode = True
from_attributes = True
Expand All @@ -136,7 +139,7 @@ class ConversationBaseSchema(BaseModel):
last_updated_timestamp: str
archive: str
assistant_id: str
messages: Optional[List[MessageBaseSchema]] = None
messages: Optional[List[MessageSchema]] = None

class Config:
orm_mode = True
Expand All @@ -162,4 +165,4 @@ class FileCreateSchema(FileBaseSchema):
pass

class FileSchema(FileBaseSchema):
id: str
id: str
1 change: 1 addition & 0 deletions migrations/versions/54bbca89da7b_added_message_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def upgrade() -> None:
sa.Column('timestamp', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('prompt', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('chat_response', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('voice_active', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint('id')
)

Expand Down

0 comments on commit 0a008df

Please sign in to comment.