Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/client/VZCodeContext/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ export type VZCodeContextValue = {
messageToSend?: string,
options?: Record<string, string>,
) => void;
handleStopGeneration: (chatId: string) => void;

// Message history navigation
navigateMessageHistoryUp: () => void;
Expand Down
27 changes: 27 additions & 0 deletions src/client/VZCodeContext/useVZCodeState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,32 @@ export const useVZCodeState = ({
],
);

const handleStopGeneration = useCallback(
async (chatId: string) => {
try {
const response = await fetch('/ai-chat-stop', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
chatId,
}),
});

if (!response.ok) {
console.error(
'Failed to stop generation:',
response.statusText,
);
}
} catch (error) {
console.error('Error stopping generation:', error);
}
},
[],
);

// Config.json change detection and iframe notification
// This logic was moved from VisualEditor.tsx to ensure it runs
// even when the visual editor is not open, fixing cross-client propagation
Expand Down Expand Up @@ -685,6 +711,7 @@ export const useVZCodeState = ({
aiErrorMessage,
setAIErrorMessage,
handleSendMessage,
handleStopGeneration,

// Message history navigation
navigateMessageHistoryUp,
Expand Down
59 changes: 44 additions & 15 deletions src/client/VZSidebar/AIChat/ChatInput.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
} from '../../bootstrap';
import { enableAskMode } from '../../featureFlags';
import { useSpeechRecognition } from './useSpeechRecognition';
import { MicSVG, MicOffSVG } from '../../Icons';
import { MicSVG, MicOffSVG, StopSVG } from '../../Icons';

interface ChatInputProps {
aiChatMessage: string;
Expand All @@ -24,6 +24,9 @@ interface ChatInputProps {
navigateMessageHistoryUp: () => void;
navigateMessageHistoryDown: () => void;
resetMessageHistoryNavigation: () => void;
isStreaming?: boolean;
onStopGeneration?: () => void;
lastUserMessage?: string;
}

const ChatInputComponent = ({
Expand All @@ -36,6 +39,9 @@ const ChatInputComponent = ({
navigateMessageHistoryUp,
navigateMessageHistoryDown,
resetMessageHistoryNavigation,
isStreaming = false,
onStopGeneration,
lastUserMessage,
}: ChatInputProps) => {
const inputRef = useRef<HTMLTextAreaElement>(null);

Expand Down Expand Up @@ -110,6 +116,16 @@ const ChatInputComponent = ({
onSendMessage();
}, [onSendMessage, isSpeaking, stopSpeaking]);

const handleStopClick = useCallback(() => {
if (onStopGeneration) {
// Restore the last user message to the input field for refinement
if (lastUserMessage) {
setAIChatMessage(lastUserMessage);
}
onStopGeneration();
}
}, [onStopGeneration, lastUserMessage, setAIChatMessage]);

const handleChange = useCallback(
(event: React.ChangeEvent<HTMLTextAreaElement>) => {
setAIChatMessage(event.target.value);
Expand Down Expand Up @@ -213,20 +229,33 @@ const ChatInputComponent = ({
>
{isSpeaking ? <MicOffSVG /> : <MicSVG />}
</Button>
<Button
variant={
aiChatMessage.trim()
? 'primary'
: 'outline-secondary'
}
onClick={handleSendClick}
disabled={!aiChatMessage.trim()}
className="ai-chat-send-button"
aria-label="Send message"
title="Send message (Enter)"
>
Send
</Button>
{isStreaming ? (
<Button
variant="danger"
onClick={handleStopClick}
className="ai-chat-stop-button"
aria-label="Stop generation"
title="Stop generation"
>
<StopSVG />
Stop
</Button>
) : (
<Button
variant={
aiChatMessage.trim()
? 'primary'
: 'outline-secondary'
}
onClick={handleSendClick}
disabled={!aiChatMessage.trim()}
className="ai-chat-send-button"
aria-label="Send message"
title="Send message (Enter)"
>
Send
</Button>
)}
</div>
</div>
</Form.Group>
Expand Down
18 changes: 18 additions & 0 deletions src/client/VZSidebar/AIChat/StreamingMessage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,24 @@ export const StreamingMessage: React.FC<
case 'file_start':
// File start events are now handled by centralized status logic in MessageList
return null;
case 'stopped':
return (
<div
key={`stopped-${index}`}
className="text-chunk stopped-message"
>
<em>Generation stopped by user.</em>
</div>
);
case 'error':
return (
<div
key={`error-${index}`}
className="text-chunk error-message"
>
<em>Error: {event.message}</em>
</div>
);
default:
return null;
}
Expand Down
25 changes: 25 additions & 0 deletions src/client/VZSidebar/AIChat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ export const AIChat = () => {
getStoredAIPrompt,
setAIChatMessage,
handleSendMessage,
handleStopGeneration,
setAIErrorMessage,
navigateMessageHistoryUp,
navigateMessageHistoryDown,
Expand Down Expand Up @@ -104,6 +105,27 @@ export const AIChat = () => {
const isEmptyState =
!selectedChatId || rawMessages.length === 0;

// Determine if streaming is active
const isStreaming =
(currentChat as ExtendedVizChat)?.isStreaming || false;

// Get the last user message for restoration on stop
const lastUserMessage = useMemo(() => {
const userMessages = rawMessages.filter(
(msg) => msg.role === 'user',
);
return userMessages.length > 0
? userMessages[userMessages.length - 1].content
: '';
}, [rawMessages]);

// Create stop handler for the current chat
const handleStopCurrentGeneration = useCallback(() => {
if (activeChatId) {
handleStopGeneration(activeChatId);
}
}, [activeChatId, handleStopGeneration]);

// Get all existing chats
const allChats = content?.chats || {};
const existingChats = Object.values(allChats).filter(
Expand Down Expand Up @@ -357,6 +379,9 @@ export const AIChat = () => {
resetMessageHistoryNavigation={
resetMessageHistoryNavigation
}
isStreaming={isStreaming}
onStopGeneration={handleStopCurrentGeneration}
lastUserMessage={lastUserMessage}
/>
</div>
</div>
Expand Down
31 changes: 31 additions & 0 deletions src/server/aiChatHandler/chatStopFlag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { ShareDBDoc } from '../../types.js';
import { VizChatId } from '@vizhub/viz-types';
import { diff } from '../../ot.js';

export const setStopRequested = (
shareDBDoc: ShareDBDoc<any>,
chatId: VizChatId,
value: boolean,
) => {
// Store under your existing chat state, e.g. data.chats[chatId].stopRequested = value
const currentStopRequested =
!!shareDBDoc.data.chats?.[chatId]?.stopRequested;
if (currentStopRequested !== value) {
const op = diff(shareDBDoc.data, {
...shareDBDoc.data,
chats: {
...shareDBDoc.data.chats,
[chatId]: {
...shareDBDoc.data.chats[chatId],
stopRequested: value,
},
},
});
shareDBDoc.submitOp(op);
}
};

export const isStopRequested = (
shareDBDoc: ShareDBDoc<any>,
chatId: VizChatId,
) => !!shareDBDoc.data.chats?.[chatId]?.stopRequested;
20 changes: 20 additions & 0 deletions src/server/aiChatHandler/generationControl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { VizChatId } from '@vizhub/viz-types';

const controllers = new Map<VizChatId, AbortController>();

export const registerController = (
chatId: VizChatId,
controller: AbortController,
) => {
controllers.set(chatId, controller);
};

export const deregisterController = (chatId: VizChatId) => {
controllers.delete(chatId);
};

export const stopGenerationNow = (chatId: VizChatId) => {
const ctrl = controllers.get(chatId);
if (ctrl) ctrl.abort(); // Triggers AbortError in your streaming loop
controllers.delete(chatId);
};
39 changes: 39 additions & 0 deletions src/server/aiChatHandler/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import { ShareDBDoc } from '../../types.js';
import { VizContent } from '@vizhub/viz-types';
import { createSubmitOperation } from '../../submitOperation.js';
import { getGenerationMetadata } from 'editcodewithai';
import { stopGenerationNow } from './generationControl.js';
import { setStopRequested } from './chatStopFlag.js';

const DEBUG = false;

Expand Down Expand Up @@ -148,3 +150,40 @@ const processAIRequestAsync = async ({
handleBackgroundError(shareDBDoc, chatId, error);
}
};

export const handleStopGeneration =
({
shareDBDoc,
}: {
shareDBDoc: ShareDBDoc<VizContent>;
}) =>
async (req: any, res: any) => {
const { chatId } = req.body;

if (DEBUG) {
console.log('[handleStopGeneration] chatId:', chatId);
}

try {
// Validate that chatId is provided
if (!chatId) {
return res
.status(400)
.json({ error: 'Missing chatId' });
}

// Set the stop flag in ShareDB
setStopRequested(shareDBDoc, chatId, true);

// Abort the network request
stopGenerationNow(chatId);

// Return success
res.status(200).json({ success: true });
} catch (error) {
console.error('Stop generation error:', error);
res
.status(500)
.json({ error: 'Failed to stop generation' });
}
};
Loading