From 685937b161557f8d03477364cfc703df0da11bbf Mon Sep 17 00:00:00 2001 From: Eileen Li Date: Fri, 17 Oct 2025 10:23:30 -0700 Subject: [PATCH] feat: support thread --- src/components/helper.ts | 20 +- src/components/index.ts | 6 +- src/components/input/baseInput/baseInput.tsx | 1 + .../messageArchive/messageArchive.cy.tsx | 255 ++++++++++++++- .../messageArchive/messageArchive.tsx | 19 +- .../messageCanvas/messageCanvas.css | 17 + .../messageCanvas/messageCanvas.stories.tsx | 37 +++ .../messageCanvas/messageCanvas.tsx | 48 ++- .../messageSpace/messageSpace.cy.tsx | 251 +++++++++++++++ .../messageSpace/messageSpace.stories.tsx | 301 +++++++++++++++++- src/components/messageSpace/messageSpace.tsx | 36 ++- src/components/sharedDescription.ts | 7 + src/components/types.ts | 2 + 13 files changed, 963 insertions(+), 37 deletions(-) diff --git a/src/components/helper.ts b/src/components/helper.ts index 46b02e52..410357b7 100644 --- a/src/components/helper.ts +++ b/src/components/helper.ts @@ -202,19 +202,21 @@ export function toChatRequest( } } +export function getMessageIdentifier(message: Message): string { + if (message.format.startsWith('update') && message.data?.updateId) { + return message.data.updateId + } else if (message.format.includes('Response') && message.data?.inReplyTo) { + return message.data.inReplyTo + } else { + return message.id + } +} + export function getCombinedMessages( messages: { [key: string]: Message[] }, message: Message ) { - let key - - if (message.format.startsWith('update')) { - key = message.data.updateId - } else if (message.format.includes('Response')) { - key = message.data.inReplyTo - } else { - key = message.id - } + let key = getMessageIdentifier(message) if (!key) { return messages diff --git a/src/components/index.ts b/src/components/index.ts index 4b2ecded..481d55db 100644 --- a/src/components/index.ts +++ b/src/components/index.ts @@ -71,7 +71,11 @@ export { } export * from '../rusticTheme' -export { formatDateAndTime, getCombinedMessages } from './helper' +export { + formatDateAndTime, + getCombinedMessages, + getMessageIdentifier, +} from './helper' export type * from './types' export { ParticipantRole, ParticipantType } from './types' export type * from './visualization/mermaidViz/mermaidViz.types' diff --git a/src/components/input/baseInput/baseInput.tsx b/src/components/input/baseInput/baseInput.tsx index 5e783172..39e57562 100644 --- a/src/components/input/baseInput/baseInput.tsx +++ b/src/components/input/baseInput/baseInput.tsx @@ -298,6 +298,7 @@ function BaseInputElement( data: { text: messageText }, inReplyTo: props.lastMsg?.id, messageHistory: props.lastMsg?.messageHistory, + ...(props.threads && { threads: props.threads }), } props.send(formattedMessage) diff --git a/src/components/messageArchive/messageArchive.cy.tsx b/src/components/messageArchive/messageArchive.cy.tsx index bb072ad8..dfec4352 100644 --- a/src/components/messageArchive/messageArchive.cy.tsx +++ b/src/components/messageArchive/messageArchive.cy.tsx @@ -120,7 +120,7 @@ describe('MessageArchive Component', () => { }) }) - it.only(`scrolls to bottom when "Go to bottom" button is clicked on ${viewport} screen`, () => { + it(`scrolls to bottom when "Go to bottom" button is clicked on ${viewport} screen`, () => { const waitTime = 500 cy.viewport(viewport) @@ -171,5 +171,258 @@ describe('MessageArchive Component', () => { cy.get(infoMessage).should('exist') cy.get(infoMessage).should('contain', infoMessageText) }) + + it(`displays thread reply count when threadMessages are provided on ${viewport} screen`, () => { + const message1Id = 'message-1' + const message2Id = 'message-2' + + const threadMessages = { + [message1Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply 1', + }, + }, + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:03:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply 2', + }, + }, + ], + [message2Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:14:00.000Z', + format: 'TextFormat', + data: { + text: 'Single thread reply', + }, + }, + ], + } + + cy.viewport(viewport) + cy.mount( + { + return new Promise((resolve) => { + resolve([ + { + ...humanMessageData, + id: message1Id, + timestamp: '2024-01-02T00:00:00.000Z', + format: 'TextFormat', + data: { + text: 'First message with threads', + }, + }, + { + ...agentMessageData, + id: message2Id, + timestamp: '2024-01-02T00:13:00.000Z', + format: 'TextFormat', + data: { + text: 'Second message with thread', + }, + }, + ]) + }) + }} + threadMessages={threadMessages} + supportedElements={supportedElements} + /> + ) + + const expectedThreadCount = 2 + cy.get('.rustic-thread-reply-count').should( + 'have.length', + expectedThreadCount + ) + cy.get('.rustic-thread-reply-count').first().should('contain', '2') + cy.get('.rustic-thread-reply-count').first().should('contain', 'replies') + cy.get('.rustic-thread-reply-count').last().should('contain', '1') + cy.get('.rustic-thread-reply-count').last().should('contain', 'reply') + }) + + it(`calls onThreadOpen when thread reply count is clicked on ${viewport} screen`, () => { + const onThreadOpen = cy.stub() + const message1Id = 'message-1' + + const threadMessages = { + [message1Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply', + }, + }, + ], + } + + cy.viewport(viewport) + cy.mount( + { + return new Promise((resolve) => { + resolve([ + { + ...humanMessageData, + id: message1Id, + timestamp: '2024-01-02T00:00:00.000Z', + format: 'TextFormat', + data: { + text: 'Message with thread', + }, + }, + ]) + }) + }} + threadMessages={threadMessages} + onThreadOpen={onThreadOpen} + supportedElements={supportedElements} + /> + ) + + cy.get('.rustic-thread-reply-count').click() + cy.wrap(onThreadOpen).should('be.calledWith', message1Id) + }) + + it(`highlights active thread message on ${viewport} screen`, () => { + const message1Id = 'message-1' + const message2Id = 'message-2' + + const threadMessages = { + [message1Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply', + }, + }, + ], + [message2Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:14:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply', + }, + }, + ], + } + + cy.viewport(viewport) + cy.mount( + { + return new Promise((resolve) => { + resolve([ + { + ...humanMessageData, + id: message1Id, + timestamp: '2024-01-02T00:00:00.000Z', + format: 'TextFormat', + data: { + text: 'First message', + }, + }, + { + ...agentMessageData, + id: message2Id, + timestamp: '2024-01-02T00:13:00.000Z', + format: 'TextFormat', + data: { + text: 'Second message', + }, + }, + ]) + }) + }} + threadMessages={threadMessages} + activeThreadId={message1Id} + supportedElements={supportedElements} + /> + ) + + const messageCanvas = '[data-cy=message-canvas]' + cy.get(messageCanvas) + .first() + .find('.rustic-message-container') + .should('have.css', 'background-color') + .and('not.equal', 'rgb(255, 255, 255)') + }) + + it(`handles thread messages with update messages on ${viewport} screen`, () => { + const updateId = 'update-message-1' + + const threadMessages = { + [updateId]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply to update message', + }, + }, + ], + } + + cy.viewport(viewport) + cy.mount( + { + return new Promise((resolve) => { + resolve([ + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:00:00.000Z', + format: 'updateMarkdownFormat', + data: { + text: 'First part', + updateId: updateId, + }, + }, + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:01:00.000Z', + format: 'updateMarkdownFormat', + data: { + text: ' of message', + updateId: updateId, + }, + }, + ]) + }) + }} + threadMessages={threadMessages} + supportedElements={supportedElements} + /> + ) + + cy.get('.rustic-thread-reply-count').should('have.length', 1) + cy.get('.rustic-thread-reply-count').should('contain', '1') + cy.get('.rustic-thread-reply-count').should('contain', 'reply') + }) }) }) diff --git a/src/components/messageArchive/messageArchive.tsx b/src/components/messageArchive/messageArchive.tsx index 496c7221..fb7fa163 100644 --- a/src/components/messageArchive/messageArchive.tsx +++ b/src/components/messageArchive/messageArchive.tsx @@ -7,7 +7,7 @@ import Box from '@mui/system/Box' import React, { type ReactNode, useEffect, useRef, useState } from 'react' import ElementRenderer from '../elementRenderer/elementRenderer' -import { getCombinedMessages } from '../helper' +import { getCombinedMessages, getMessageIdentifier } from '../helper' import Icon from '../icon/icon' import MessageCanvas, { type MessageContainerProps, @@ -29,6 +29,10 @@ export interface MessageArchiveProps extends MessageContainerProps { disableAutoScroll?: boolean /** If true, disables the scroll down button */ disableScrollButton?: boolean + /** The ID of the active thread message */ + activeThreadId?: string + /** A record mapping message IDs to their thread messages */ + threadMessages?: Record } /** @@ -196,10 +200,16 @@ export default function MessageArchive({ {Object.keys(chatMessages).map((key, index) => { const messages = chatMessages[key] const latestMessage = messages[messages.length - 1] + const firstMessage = messages[0] const hasResponse = latestMessage.format.includes('Response') const inReplyTo = hasResponse && { - inReplyTo: messages[0], + inReplyTo: firstMessage, } + + const messageIdentifier = getMessageIdentifier(firstMessage) + const threadReplies = props.threadMessages?.[messageIdentifier] + const threadReplyCount = threadReplies?.length + return ( props.onThreadOpen?.(messageIdentifier), + isActiveThread: props.activeThreadId === messageIdentifier, + })} > + ), + message: messageFromHuman, + getProfileComponent: getProfileIconAndName, + threadReplyCount: 3, + onThreadOpen: (id: string) => { + alert('Thread opened for message: ' + id) + }, + getActionsComponent: (message: Message) => { + const copyButton = message.format === 'TextFormat' && ( + + ) + if (copyButton) { + return <>{copyButton} + } + }, + }, + parameters: { + docs: { + source: { + code: ` + ${elementRendererString} +`, + }, + }, + }, +} diff --git a/src/components/messageCanvas/messageCanvas.tsx b/src/components/messageCanvas/messageCanvas.tsx index b09f034a..758a4ce4 100644 --- a/src/components/messageCanvas/messageCanvas.tsx +++ b/src/components/messageCanvas/messageCanvas.tsx @@ -1,11 +1,14 @@ import './messageCanvas.css' -import { useTheme } from '@mui/material' import Card from '@mui/material/Card' import Stack from '@mui/material/Stack' +import type { Theme } from '@mui/material/styles' import Typography from '@mui/material/Typography' +import Box from '@mui/system/Box' +import useTheme from '@mui/system/useTheme' import React, { forwardRef, type ReactNode } from 'react' +import Icon from '../icon/icon' import Timestamp from '../timestamp/timestamp' import type { Message } from '../types' @@ -19,6 +22,8 @@ export interface MessageContainerProps { * One such example is the `CopyText` component. */ getActionsComponent?: (message: Message) => ReactNode | undefined + /** A function that is called when user clicks on the thread replies. Can be used to open a thread view. */ + onThreadOpen?: (messageId: string) => void } export interface MessageCanvasProps extends MessageContainerProps { @@ -28,6 +33,10 @@ export interface MessageCanvasProps extends MessageContainerProps { inReplyTo?: Message /** React component to be displayed in the message canvas. */ children: ReactNode + /** Number of thread replies */ + threadReplyCount?: number + /** Whether this message is the active thread */ + isActiveThread?: boolean } /** @@ -40,7 +49,6 @@ function MessageCanvasElement( ref: React.Ref ) { const theme = useTheme() - const messageInfo = props.inReplyTo ? props.inReplyTo : props.message return ( @@ -64,12 +72,20 @@ function MessageCanvasElement( {props.getActionsComponent(messageInfo)} )} - + {props.children} {props.inReplyTo && ( )} + {props.threadReplyCount && ( + props.onThreadOpen?.(props.message.id)} + > + + + {props.threadReplyCount}{' '} + {props.threadReplyCount === 1 ? 'reply' : 'replies'} + + + )} ) } diff --git a/src/components/messageSpace/messageSpace.cy.tsx b/src/components/messageSpace/messageSpace.cy.tsx index 2c6cf7c2..677571de 100644 --- a/src/components/messageSpace/messageSpace.cy.tsx +++ b/src/components/messageSpace/messageSpace.cy.tsx @@ -444,5 +444,256 @@ describe('MessageSpace Component', () => { cy.get(messageContainer).should('contain', 'top prompt') cy.get(messageSpace).should('not.contain', 'bottom prompt') }) + + it(`displays thread reply count when threadMessages are provided on ${viewport} screen`, () => { + const mockWsClient = { + send: cy.stub(), + close: cy.stub(), + reconnect: cy.stub(), + } + const message1Id = 'message-1' + const message2Id = 'message-2' + + const threadMessages = { + [message1Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply 1', + }, + }, + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:03:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply 2', + }, + }, + ], + [message2Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:14:00.000Z', + format: 'TextFormat', + data: { + text: 'Single thread reply', + }, + }, + ], + } + + cy.viewport(viewport) + cy.mount( + + ) + + const expectedThreadCount = 2 + cy.get('.rustic-thread-reply-count').should( + 'have.length', + expectedThreadCount + ) + cy.get('.rustic-thread-reply-count').first().should('contain', '2') + cy.get('.rustic-thread-reply-count').first().should('contain', 'replies') + cy.get('.rustic-thread-reply-count').last().should('contain', '1') + cy.get('.rustic-thread-reply-count').last().should('contain', 'reply') + }) + + it(`calls onThreadOpen when thread reply count is clicked on ${viewport} screen`, () => { + const mockWsClient = { + send: cy.stub(), + close: cy.stub(), + reconnect: cy.stub(), + } + const onThreadOpen = cy.stub() + const message1Id = 'message-1' + + const threadMessages = { + [message1Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply', + }, + }, + ], + } + + cy.viewport(viewport) + cy.mount( + + ) + + cy.get('.rustic-thread-reply-count').click() + cy.wrap(onThreadOpen).should('be.calledWith', message1Id) + }) + + it(`highlights active thread message on ${viewport} screen`, () => { + const mockWsClient = { + send: cy.stub(), + close: cy.stub(), + reconnect: cy.stub(), + } + const message1Id = 'message-1' + const message2Id = 'message-2' + + const threadMessages = { + [message1Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply', + }, + }, + ], + [message2Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:14:00.000Z', + format: 'TextFormat', + data: { + text: 'Thread reply', + }, + }, + ], + } + + cy.viewport(viewport) + cy.mount( + + ) + + cy.get(messageCanvas) + .first() + .find('.rustic-message-container') + .should('have.css', 'background-color') + .and('not.equal', 'rgb(255, 255, 255)') + }) + + it(`renders rootMessages at the top when provided on ${viewport} screen`, () => { + const mockWsClient = { + send: cy.stub(), + close: cy.stub(), + reconnect: cy.stub(), + } + + const rootMessages = [ + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:00:00.000Z', + format: 'TextFormat', + data: { + text: 'Root message', + }, + }, + ] + + cy.viewport(viewport) + cy.mount( + + ) + + const expectedMessageCount = 2 + cy.get(messageCanvas).should('have.length', expectedMessageCount) + cy.get(messageCanvas).first().should('contain', 'Root message') + cy.get(messageCanvas).last().should('contain', 'Reply message') + }) }) }) diff --git a/src/components/messageSpace/messageSpace.stories.tsx b/src/components/messageSpace/messageSpace.stories.tsx index 2ef13b6e..6edc3a3a 100644 --- a/src/components/messageSpace/messageSpace.stories.tsx +++ b/src/components/messageSpace/messageSpace.stories.tsx @@ -1,4 +1,9 @@ +import Box from '@mui/material/Box' +import Divider from '@mui/material/Divider' +import IconButton from '@mui/material/IconButton' +import { useTheme } from '@mui/material/styles' import Typography from '@mui/material/Typography' +import useMediaQuery from '@mui/material/useMediaQuery' import type { Meta } from '@storybook/react-webpack5' import type { StoryFn } from '@storybook/react-webpack5' import React from 'react' @@ -172,9 +177,32 @@ const tableData = [ ] const updateIdentifier = getUUID() +const message1Id = getUUID() + +const supportedElements = { + TextFormat: Text, + MarkdownFormat: MarkedMarkdown, + ImageFormat: Image, + LocationFormat: OpenLayersMap, + TableFormat: Table, + CalendarFormat: FCCalendar, + FormFormat: UniformsForm, + PromptsFormat: Prompts, + CodeFormat: CodeSnippet, + AudioFormat: Sound, + VideoFormat: Video, + FilesWithTextFormat: Multipart, +} + +const mockWs = { + send: () => {}, + close: () => {}, + reconnect: () => {}, +} + export const Default = { args: { - ws: { send: () => {} }, + ws: mockWs, sender: humanMessageData.sender, receivedMessages: [ { @@ -427,20 +455,7 @@ export const Default = { }, }, ], - supportedElements: { - TextFormat: Text, - MarkdownFormat: MarkedMarkdown, - ImageFormat: Image, - LocationFormat: OpenLayersMap, - TableFormat: Table, - CalendarFormat: FCCalendar, - FormFormat: UniformsForm, - PromptsFormat: Prompts, - CodeFormat: CodeSnippet, - AudioFormat: Sound, - VideoFormat: Video, - FilesWithTextFormat: Multipart, - }, + supportedElements: supportedElements, getProfileComponent: getProfileIconAndName, getActionsComponent: (message: Message) => { const copyButton = message.format === 'text' && ( @@ -453,6 +468,262 @@ export const Default = { }, } +export const ThreadView = { + decorators: [ + (Story: StoryFn) => { + return ( +
+ +
+ ) + }, + ], + render: () => { + const [activeThreadId, setActiveThreadId] = + React.useState(updateIdentifier) + const [isThreadOpen, setIsThreadOpen] = React.useState(false) + const theme = useTheme() + const isMobile = useMediaQuery(theme.breakpoints.down('md')) + + const threadRootMessages = { + [updateIdentifier]: [ + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:01:00.000Z', + format: 'updateMarkdownFormat', + data: { + text: '## Title', + updateId: updateIdentifier, + }, + }, + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:02:01.000Z', + format: 'updateMarkdownFormat', + data: { + text: '\nThis is a paragraph. Lorem Ipsum is simply dummy text of the printing and typesetting industry.', + updateId: updateIdentifier, + }, + }, + ], + [message1Id]: [ + { + ...agentMessageData, + id: message1Id, + timestamp: '2024-01-02T00:21:00.000Z', + format: 'FilesWithTextFormat', + data: { + text: 'Here is an example of the multipart component:', + files: [{ name: 'imageExample.png' }, { name: 'pdfExample.pdf' }], + }, + }, + ], + } + + const threadMessagesData = { + [updateIdentifier]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:03:00.000Z', + format: 'TextFormat', + data: { + text: 'This is a thread reply.', + }, + }, + { + ...agentMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:04:00.000Z', + format: 'TextFormat', + data: { + text: 'Agent response to the previous message.', + }, + }, + ], + [message1Id]: [ + { + ...humanMessageData, + id: getUUID(), + timestamp: '2024-01-02T00:22:00.000Z', + format: 'TextFormat', + data: { + text: 'This is a thread reply to the multipart message. User can start a thread from any message and continue the conversation there.', + }, + }, + ], + } + + const handleThreadOpen = (messageId: string) => { + setActiveThreadId(messageId) + if (isMobile) { + setIsThreadOpen(true) + } + } + + const handleThreadClose = () => { + setIsThreadOpen(false) + } + + return ( + <> +
+ + + Chat name + + + + { + const copyButton = message.format === 'text' && ( + + ) + if (copyButton) { + return <>{copyButton} + } + }} + /> + +
+ {!isMobile && } +
+ + + Thread + + + + + + + + +
+ + ) + }, +} + export const WithDisabledScroll = { args: { ...Default.args, diff --git a/src/components/messageSpace/messageSpace.tsx b/src/components/messageSpace/messageSpace.tsx index 365513b8..9fe24120 100644 --- a/src/components/messageSpace/messageSpace.tsx +++ b/src/components/messageSpace/messageSpace.tsx @@ -5,7 +5,7 @@ import Box from '@mui/system/Box' import React, { useEffect, useRef, useState } from 'react' import ElementRenderer from '../elementRenderer/elementRenderer' -import { getCombinedMessages } from '../helper' +import { getCombinedMessages, getMessageIdentifier } from '../helper' import Icon from '../icon/icon' import MessageCanvas, { type MessageContainerProps, @@ -27,6 +27,12 @@ export interface MessageSpaceProps extends MessageContainerProps { disableAutoScroll?: boolean /** If true, disables the scroll down button */ disableScrollButton?: boolean + /** Root messages to be displayed at the top of the message space */ + rootMessages?: Message[] + /** The ID of the active thread message */ + activeThreadId?: string + /** A record mapping message IDs to their thread messages */ + threadMessages?: Record } function usePrevious(value: number) { @@ -37,6 +43,14 @@ function usePrevious(value: number) { return ref.current } +function getInitialChatMessages(rootMessages?: Message[]) { + if (!rootMessages) { + return {} + } + const messageIdentifier = getMessageIdentifier(rootMessages[0]) + return { [messageIdentifier]: rootMessages } +} + /** The `MessageSpace` component uses `MessageCanvas` and `ElementRenderer` to render a list of messages. It serves as a container for individual message items, each encapsulated within a `MessageCanvas` for consistent styling and layout. @@ -162,6 +176,8 @@ export default function MessageSpace({ }, [isScrolledToBottom, Object.keys(chatMessages).length, disableAutoScroll]) useEffect(() => { + setIsScrolledToBottom(true) + const rootMessagesDict = getInitialChatMessages(props.rootMessages) let messageDict: { [messageId: string]: Message[] } = {} props.receivedMessages?.forEach((message) => { @@ -169,8 +185,8 @@ export default function MessageSpace({ messageDict = newMessageDict }) - setChatMessages(messageDict) - }, [props.receivedMessages?.length]) + setChatMessages({ ...rootMessagesDict, ...messageDict }) + }, [props.receivedMessages?.length, props.rootMessages]) function handleIncomingMessage(message: Message) { setChatMessages((prevMessages) => @@ -215,11 +231,16 @@ export default function MessageSpace({ {Object.keys(chatMessages).map((key, index) => { const messages = chatMessages[key] const latestMessage = messages[messages.length - 1] + const firstMessage = messages[0] const hasResponse = latestMessage.format.includes('Response') const inReplyTo = hasResponse && { - inReplyTo: messages[0], + inReplyTo: firstMessage, } + if (!latestMessage.format.toLowerCase().startsWith('prompts')) { + const messageIdentifier = getMessageIdentifier(firstMessage) + const threadReplies = props.threadMessages?.[messageIdentifier] + const threadReplyCount = threadReplies?.length return ( props.onThreadOpen?.(messageIdentifier), + isActiveThread: props.activeThreadId === messageIdentifier, + })} > diff --git a/src/components/sharedDescription.ts b/src/components/sharedDescription.ts index e81c66b3..568dc160 100644 --- a/src/components/sharedDescription.ts +++ b/src/components/sharedDescription.ts @@ -154,4 +154,11 @@ export const textInputDescription: InputType = { defaultValue: { summary: '5' }, }, }, + threads: { + description: + 'Optional array of thread IDs to associate with the message. If provided, the message will be linked to the specified threads.', + table: { + type: { summary: 'string[]' }, + }, + }, } diff --git a/src/components/types.ts b/src/components/types.ts index e302f814..1249dea8 100644 --- a/src/components/types.ts +++ b/src/components/types.ts @@ -325,6 +325,8 @@ export interface BaseInputProps { lastMsg?: Message /** If the input should be focused automatically **/ autoFocus?: boolean + /** Optional array of thread IDs to associate with the message */ + threads?: string[] } export interface TextInputProps