diff --git a/apps/postgres-new/components/chat.tsx b/apps/postgres-new/components/chat.tsx index 9b375864..71747ada 100644 --- a/apps/postgres-new/components/chat.tsx +++ b/apps/postgres-new/components/chat.tsx @@ -77,28 +77,43 @@ export default function Chat() { const sendCsv = useCallback( async (file: File) => { - if (file.type !== 'text/csv') { - // Add an artificial tool call requesting the CSV - // with an error indicating the file wasn't a CSV - appendMessage({ - role: 'assistant', - content: '', - toolInvocations: [ - { - state: 'result', - toolCallId: generateId(), - toolName: 'requestCsv', - args: {}, - result: { - success: false, - error: `The file has type '${file.type}'. Let the user know that only CSV imports are currently supported.`, + const fileId = generateId() + + await saveFile(fileId, file) + + const text = await file.text() + + // Add an artificial tool call requesting the CSV + // with the file result all in one operation. + appendMessage({ + role: 'assistant', + content: '', + toolInvocations: [ + { + state: 'result', + toolCallId: generateId(), + toolName: 'requestCsv', + args: {}, + result: { + success: true, + fileId: fileId, + file: { + name: file.name, + size: file.size, + type: file.type, + lastModified: file.lastModified, }, + preview: text.split('\n').slice(0, 4).join('\n').trim(), }, - ], - }) - return - } + }, + ], + }) + }, + [appendMessage] + ) + const sendSql = useCallback( + async (file: File) => { const fileId = generateId() await saveFile(fileId, file) @@ -114,7 +129,7 @@ export default function Chat() { { state: 'result', toolCallId: generateId(), - toolName: 'requestCsv', + toolName: 'requestSql', args: {}, result: { success: true, @@ -125,7 +140,7 @@ export default function Chat() { type: file.type, lastModified: file.lastModified, }, - preview: text.split('\n').slice(0, 4).join('\n').trim(), + preview: text.split('\n').slice(0, 10).join('\n').trim(), }, }, ], @@ -147,7 +162,16 @@ export default function Chat() { const [file] = files if (file) { - await sendCsv(file) + if (file.type === 'text/csv' || file.name.endsWith('.csv')) { + await sendCsv(file) + } else if (file.type === 'application/sql' || file.name.endsWith('.sql')) { + await sendSql(file) + } else { + appendMessage({ + role: 'assistant', + content: `Only CSV and SQL files are currently supported.`, + }) + } } }, cursorElement: ( diff --git a/apps/postgres-new/components/ide.tsx b/apps/postgres-new/components/ide.tsx index c943e82d..d9e6767c 100644 --- a/apps/postgres-new/components/ide.tsx +++ b/apps/postgres-new/components/ide.tsx @@ -51,7 +51,9 @@ export default function IDE({ children, className }: IDEProps) { return toolInvocations .map((tool) => // Only include SQL that successfully executed against the DB - tool.toolName === 'executeSql' && 'result' in tool && tool.result.success === true + (tool.toolName === 'executeSql' || tool.toolName === 'importSql') && + 'result' in tool && + tool.result.success === true ? tool.args.sql : undefined ) diff --git a/apps/postgres-new/components/tools/index.tsx b/apps/postgres-new/components/tools/index.tsx index c258acbf..1d30d76a 100644 --- a/apps/postgres-new/components/tools/index.tsx +++ b/apps/postgres-new/components/tools/index.tsx @@ -6,6 +6,8 @@ import CsvRequest from './csv-request' import ExecutedSql from './executed-sql' import GeneratedChart from './generated-chart' import GeneratedEmbedding from './generated-embedding' +import SqlImport from './sql-import' +import SqlRequest from './sql-request' export type ToolUiProps = { toolInvocation: ToolInvocation @@ -23,6 +25,10 @@ export function ToolUi({ toolInvocation }: ToolUiProps) { return case 'exportCsv': return + case 'requestSql': + return + case 'importSql': + return case 'renameConversation': return case 'embed': diff --git a/apps/postgres-new/components/tools/sql-import.tsx b/apps/postgres-new/components/tools/sql-import.tsx new file mode 100644 index 00000000..b548cb05 --- /dev/null +++ b/apps/postgres-new/components/tools/sql-import.tsx @@ -0,0 +1,33 @@ +import { useMemo } from 'react' +import { formatSql } from '~/lib/sql-util' +import { ToolInvocation } from '~/lib/tools' +import CodeAccordion from '../code-accordion' + +export type SqlImportProps = { + toolInvocation: ToolInvocation<'importSql'> +} + +export default function SqlImport({ toolInvocation }: SqlImportProps) { + const { fileId, sql } = toolInvocation.args + + const formattedSql = useMemo(() => formatSql(sql), [sql]) + + if (!('result' in toolInvocation)) { + return null + } + + const { result } = toolInvocation + + if (!result.success) { + return ( + + ) + } + + return +} diff --git a/apps/postgres-new/components/tools/sql-request.tsx b/apps/postgres-new/components/tools/sql-request.tsx new file mode 100644 index 00000000..218b5ce1 --- /dev/null +++ b/apps/postgres-new/components/tools/sql-request.tsx @@ -0,0 +1,114 @@ +import { generateId } from 'ai' +import { useChat } from 'ai/react' +import { m } from 'framer-motion' +import { Paperclip } from 'lucide-react' +import { loadFile, saveFile } from '~/lib/files' +import { ToolInvocation } from '~/lib/tools' +import { downloadFile } from '~/lib/util' +import { useWorkspace } from '../workspace' + +export type SqlRequestProps = { + toolInvocation: ToolInvocation<'requestSql'> +} + +export default function SqlRequest({ toolInvocation }: SqlRequestProps) { + const { databaseId } = useWorkspace() + + const { addToolResult } = useChat({ + id: databaseId, + api: '/api/chat', + }) + + if ('result' in toolInvocation) { + const { result } = toolInvocation + + if (!result.success) { + return ( + + No SQL file selected + + ) + } + + return ( + + + { + const file = await loadFile(result.fileId) + downloadFile(file) + }} + > + {result.file.name} + + + ) + } + + return ( + + { + if (e.target.files) { + try { + const [file] = Array.from(e.target.files) + + if (!file) { + throw new Error('No file found') + } + + if (file.type !== 'text/sql') { + throw new Error('File is not a SQL file') + } + + const fileId = generateId() + + await saveFile(fileId, file) + + const text = await file.text() + + addToolResult({ + toolCallId: toolInvocation.toolCallId, + result: { + success: true, + fileId: fileId, + file: { + name: file.name, + size: file.size, + type: file.type, + lastModified: file.lastModified, + }, + preview: text.split('\n').slice(0, 10).join('\n').trim(), + }, + }) + } catch (error) { + addToolResult({ + toolCallId: toolInvocation.toolCallId, + result: { + success: false, + error: error instanceof Error ? error.message : 'An unknown error occurred', + }, + }) + } + } + }} + /> + + ) +} diff --git a/apps/postgres-new/lib/hooks.ts b/apps/postgres-new/lib/hooks.ts index befcd25b..a412836c 100644 --- a/apps/postgres-new/lib/hooks.ts +++ b/apps/postgres-new/lib/hooks.ts @@ -435,6 +435,25 @@ export function useOnToolCall(databaseId: string) { } } } + case 'importSql': { + const { fileId } = toolCall.args + + try { + const file = await loadFile(fileId) + await db.exec(await file.text()) + await refetchTables() + + return { + success: true, + message: 'The SQL file has been executed successfully.', + } + } catch (error) { + return { + success: false, + error: error instanceof Error ? error.message : 'An unknown error has occurred', + } + } + } } }, [dbManager, refetchTables, updateDatabase, databaseId, vectorDataTypeId] diff --git a/apps/postgres-new/lib/tools.ts b/apps/postgres-new/lib/tools.ts index 46a301b9..8b51b02c 100644 --- a/apps/postgres-new/lib/tools.ts +++ b/apps/postgres-new/lib/tools.ts @@ -162,6 +162,40 @@ export const tools = { }) ), }, + requestSql: { + description: codeBlock` + Requests a SQL file upload from the user. + `, + args: z.object({}), + result: result( + z.object({ + fileId: z.string(), + file: z.object({ + name: z.string(), + size: z.number(), + type: z.string(), + lastModified: z.number(), + }), + preview: z.string(), + }) + ), + }, + importSql: { + description: codeBlock` + Executes a Postgres SQL file with the specified ID against the user's database. Call \`requestSql\` first. + `, + args: z.object({ + fileId: z.string().describe('The ID of the SQL file to execute'), + sql: z.string().describe(codeBlock` + The Postgres SQL file content to execute against the user's database. + `), + }), + result: result( + z.object({ + message: z.string(), + }) + ), + }, embed: { description: codeBlock` Generates vector embeddings for texts. Use with pgvector extension.