diff --git a/apps/postgres-new/components/chat.tsx b/apps/postgres-new/components/chat.tsx
index 9b375864..71747ada 100644
--- a/apps/postgres-new/components/chat.tsx
+++ b/apps/postgres-new/components/chat.tsx
@@ -77,28 +77,43 @@ export default function Chat() {
const sendCsv = useCallback(
async (file: File) => {
- if (file.type !== 'text/csv') {
- // Add an artificial tool call requesting the CSV
- // with an error indicating the file wasn't a CSV
- appendMessage({
- role: 'assistant',
- content: '',
- toolInvocations: [
- {
- state: 'result',
- toolCallId: generateId(),
- toolName: 'requestCsv',
- args: {},
- result: {
- success: false,
- error: `The file has type '${file.type}'. Let the user know that only CSV imports are currently supported.`,
+ const fileId = generateId()
+
+ await saveFile(fileId, file)
+
+ const text = await file.text()
+
+ // Add an artificial tool call requesting the CSV
+ // with the file result all in one operation.
+ appendMessage({
+ role: 'assistant',
+ content: '',
+ toolInvocations: [
+ {
+ state: 'result',
+ toolCallId: generateId(),
+ toolName: 'requestCsv',
+ args: {},
+ result: {
+ success: true,
+ fileId: fileId,
+ file: {
+ name: file.name,
+ size: file.size,
+ type: file.type,
+ lastModified: file.lastModified,
},
+ preview: text.split('\n').slice(0, 4).join('\n').trim(),
},
- ],
- })
- return
- }
+ },
+ ],
+ })
+ },
+ [appendMessage]
+ )
+ const sendSql = useCallback(
+ async (file: File) => {
const fileId = generateId()
await saveFile(fileId, file)
@@ -114,7 +129,7 @@ export default function Chat() {
{
state: 'result',
toolCallId: generateId(),
- toolName: 'requestCsv',
+ toolName: 'requestSql',
args: {},
result: {
success: true,
@@ -125,7 +140,7 @@ export default function Chat() {
type: file.type,
lastModified: file.lastModified,
},
- preview: text.split('\n').slice(0, 4).join('\n').trim(),
+ preview: text.split('\n').slice(0, 10).join('\n').trim(),
},
},
],
@@ -147,7 +162,16 @@ export default function Chat() {
const [file] = files
if (file) {
- await sendCsv(file)
+ if (file.type === 'text/csv' || file.name.endsWith('.csv')) {
+ await sendCsv(file)
+ } else if (file.type === 'application/sql' || file.name.endsWith('.sql')) {
+ await sendSql(file)
+ } else {
+ appendMessage({
+ role: 'assistant',
+ content: `Only CSV and SQL files are currently supported.`,
+ })
+ }
}
},
cursorElement: (
diff --git a/apps/postgres-new/components/ide.tsx b/apps/postgres-new/components/ide.tsx
index c943e82d..d9e6767c 100644
--- a/apps/postgres-new/components/ide.tsx
+++ b/apps/postgres-new/components/ide.tsx
@@ -51,7 +51,9 @@ export default function IDE({ children, className }: IDEProps) {
return toolInvocations
.map((tool) =>
// Only include SQL that successfully executed against the DB
- tool.toolName === 'executeSql' && 'result' in tool && tool.result.success === true
+ (tool.toolName === 'executeSql' || tool.toolName === 'importSql') &&
+ 'result' in tool &&
+ tool.result.success === true
? tool.args.sql
: undefined
)
diff --git a/apps/postgres-new/components/tools/index.tsx b/apps/postgres-new/components/tools/index.tsx
index c258acbf..1d30d76a 100644
--- a/apps/postgres-new/components/tools/index.tsx
+++ b/apps/postgres-new/components/tools/index.tsx
@@ -6,6 +6,8 @@ import CsvRequest from './csv-request'
import ExecutedSql from './executed-sql'
import GeneratedChart from './generated-chart'
import GeneratedEmbedding from './generated-embedding'
+import SqlImport from './sql-import'
+import SqlRequest from './sql-request'
export type ToolUiProps = {
toolInvocation: ToolInvocation
@@ -23,6 +25,10 @@ export function ToolUi({ toolInvocation }: ToolUiProps) {
return
case 'exportCsv':
return
+ case 'requestSql':
+ return
+ case 'importSql':
+ return
case 'renameConversation':
return
case 'embed':
diff --git a/apps/postgres-new/components/tools/sql-import.tsx b/apps/postgres-new/components/tools/sql-import.tsx
new file mode 100644
index 00000000..b548cb05
--- /dev/null
+++ b/apps/postgres-new/components/tools/sql-import.tsx
@@ -0,0 +1,33 @@
+import { useMemo } from 'react'
+import { formatSql } from '~/lib/sql-util'
+import { ToolInvocation } from '~/lib/tools'
+import CodeAccordion from '../code-accordion'
+
+export type SqlImportProps = {
+ toolInvocation: ToolInvocation<'importSql'>
+}
+
+export default function SqlImport({ toolInvocation }: SqlImportProps) {
+ const { fileId, sql } = toolInvocation.args
+
+ const formattedSql = useMemo(() => formatSql(sql), [sql])
+
+ if (!('result' in toolInvocation)) {
+ return null
+ }
+
+ const { result } = toolInvocation
+
+ if (!result.success) {
+ return (
+
+ )
+ }
+
+ return
+}
diff --git a/apps/postgres-new/components/tools/sql-request.tsx b/apps/postgres-new/components/tools/sql-request.tsx
new file mode 100644
index 00000000..218b5ce1
--- /dev/null
+++ b/apps/postgres-new/components/tools/sql-request.tsx
@@ -0,0 +1,114 @@
+import { generateId } from 'ai'
+import { useChat } from 'ai/react'
+import { m } from 'framer-motion'
+import { Paperclip } from 'lucide-react'
+import { loadFile, saveFile } from '~/lib/files'
+import { ToolInvocation } from '~/lib/tools'
+import { downloadFile } from '~/lib/util'
+import { useWorkspace } from '../workspace'
+
+export type SqlRequestProps = {
+ toolInvocation: ToolInvocation<'requestSql'>
+}
+
+export default function SqlRequest({ toolInvocation }: SqlRequestProps) {
+ const { databaseId } = useWorkspace()
+
+ const { addToolResult } = useChat({
+ id: databaseId,
+ api: '/api/chat',
+ })
+
+ if ('result' in toolInvocation) {
+ const { result } = toolInvocation
+
+ if (!result.success) {
+ return (
+
+ No SQL file selected
+
+ )
+ }
+
+ return (
+
+
+ {
+ const file = await loadFile(result.fileId)
+ downloadFile(file)
+ }}
+ >
+ {result.file.name}
+
+
+ )
+ }
+
+ return (
+
+ {
+ if (e.target.files) {
+ try {
+ const [file] = Array.from(e.target.files)
+
+ if (!file) {
+ throw new Error('No file found')
+ }
+
+ if (file.type !== 'text/sql') {
+ throw new Error('File is not a SQL file')
+ }
+
+ const fileId = generateId()
+
+ await saveFile(fileId, file)
+
+ const text = await file.text()
+
+ addToolResult({
+ toolCallId: toolInvocation.toolCallId,
+ result: {
+ success: true,
+ fileId: fileId,
+ file: {
+ name: file.name,
+ size: file.size,
+ type: file.type,
+ lastModified: file.lastModified,
+ },
+ preview: text.split('\n').slice(0, 10).join('\n').trim(),
+ },
+ })
+ } catch (error) {
+ addToolResult({
+ toolCallId: toolInvocation.toolCallId,
+ result: {
+ success: false,
+ error: error instanceof Error ? error.message : 'An unknown error occurred',
+ },
+ })
+ }
+ }
+ }}
+ />
+
+ )
+}
diff --git a/apps/postgres-new/lib/hooks.ts b/apps/postgres-new/lib/hooks.ts
index befcd25b..a412836c 100644
--- a/apps/postgres-new/lib/hooks.ts
+++ b/apps/postgres-new/lib/hooks.ts
@@ -435,6 +435,25 @@ export function useOnToolCall(databaseId: string) {
}
}
}
+ case 'importSql': {
+ const { fileId } = toolCall.args
+
+ try {
+ const file = await loadFile(fileId)
+ await db.exec(await file.text())
+ await refetchTables()
+
+ return {
+ success: true,
+ message: 'The SQL file has been executed successfully.',
+ }
+ } catch (error) {
+ return {
+ success: false,
+ error: error instanceof Error ? error.message : 'An unknown error has occurred',
+ }
+ }
+ }
}
},
[dbManager, refetchTables, updateDatabase, databaseId, vectorDataTypeId]
diff --git a/apps/postgres-new/lib/tools.ts b/apps/postgres-new/lib/tools.ts
index 46a301b9..8b51b02c 100644
--- a/apps/postgres-new/lib/tools.ts
+++ b/apps/postgres-new/lib/tools.ts
@@ -162,6 +162,40 @@ export const tools = {
})
),
},
+ requestSql: {
+ description: codeBlock`
+ Requests a SQL file upload from the user.
+ `,
+ args: z.object({}),
+ result: result(
+ z.object({
+ fileId: z.string(),
+ file: z.object({
+ name: z.string(),
+ size: z.number(),
+ type: z.string(),
+ lastModified: z.number(),
+ }),
+ preview: z.string(),
+ })
+ ),
+ },
+ importSql: {
+ description: codeBlock`
+ Executes a Postgres SQL file with the specified ID against the user's database. Call \`requestSql\` first.
+ `,
+ args: z.object({
+ fileId: z.string().describe('The ID of the SQL file to execute'),
+ sql: z.string().describe(codeBlock`
+ The Postgres SQL file content to execute against the user's database.
+ `),
+ }),
+ result: result(
+ z.object({
+ message: z.string(),
+ })
+ ),
+ },
embed: {
description: codeBlock`
Generates vector embeddings for texts. Use with pgvector extension.