-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathinterrupt_user.py
88 lines (69 loc) · 3.21 KB
/
interrupt_user.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# This agent keeps track of the number of sentences the user has spoken
# and interrupts them if they've said a certain number of sentences.
# We use session.say() to interrupt the user, and set allow_interruptions=False
# on that specific call to prevent the user from interrupting the agent.
# After the agent has spoken, allow_interruptions is once again True so the agent
# can listen for the user's response.
import logging
from pathlib import Path
from dotenv import load_dotenv
from livekit.agents import JobContext, WorkerOptions, cli
from livekit.agents.voice import Agent, AgentSession
from livekit.plugins import openai, deepgram, silero
from livekit.agents.llm import ChatContext, ChatMessage
import re
import asyncio
load_dotenv(dotenv_path=Path(__file__).parent.parent / '.env')
logger = logging.getLogger("interrupt-user")
logger.setLevel(logging.INFO)
def count_sentences(text):
"""Count the number of sentences in text"""
sentences = re.findall(r'[^.!?]+[.!?](?:\s|$)', text)
return len(sentences)
async def entrypoint(ctx: JobContext):
await ctx.connect()
session = AgentSession()
agent = Agent(
instructions="You are a helpful agent that politely interrupts users when they talk too much.",
stt=deepgram.STT(),
llm=openai.LLM(),
tts=openai.TTS(),
vad=silero.VAD.load()
)
async def handle_interruption(context):
await agent.update_chat_ctx(context)
session.say("Sorry, can I pause you there?", allow_interruptions=False)
await session.generate_reply(allow_interruptions=False)
transcript_buffer = ""
max_sentences = 3
@session.on("user_input_transcribed")
def on_transcript(transcript):
nonlocal transcript_buffer
if transcript.is_final:
logger.info(f"Received final transcript: {transcript.transcript}")
return
transcript_buffer += " " + transcript.transcript
transcript_buffer = transcript_buffer.strip()
logger.info(f"Buffer: {transcript_buffer}")
sentence_count = count_sentences(transcript_buffer)
logger.info(f"Sentence count: {sentence_count}")
if sentence_count >= max_sentences:
logger.info("Interrupting user...")
interruption_ctx = ChatContext([
ChatMessage(
type="message",
role="system",
content=["You are an agent that politely interrupts users who speak too much. Create a brief response that acknowledges what they've said so far, then redirects to get more focused information."]
),
ChatMessage(type="message", role="user", content=[f"User has been speaking and said: {transcript_buffer}"])
])
asyncio.create_task(handle_interruption(interruption_ctx))
transcript_buffer = ""
@session.on("session_start")
def on_session_start():
nonlocal transcript_buffer
transcript_buffer = ""
session.generate_reply()
await session.start(agent=agent, room=ctx.room)
if __name__ == "__main__":
cli.run_app(WorkerOptions(entrypoint_fnc=entrypoint))