Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 255 additions & 5 deletions app/(chat)/api/chat/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,17 @@ import { getUsage } from "tokenlens/helpers";
import { auth, type UserType } from "@/app/(auth)/auth";
import type { VisibilityType } from "@/components/visibility-selector";
import { entitlementsByUserType } from "@/lib/ai/entitlements";
import {
type FileAttachment,
generateCompatibilityErrorMessage,
validateFileCompatibility,
} from "@/lib/ai/file-compatibility";
import type { ChatModel } from "@/lib/ai/models";
import { type RequestHints, systemPrompt } from "@/lib/ai/prompts";
import {
analyzeAttachmentPrompt,
type RequestHints,
systemPrompt,
} from "@/lib/ai/prompts";
import { myProvider } from "@/lib/ai/providers";
import { createDocument } from "@/lib/ai/tools/create-document";
import { getWeather } from "@/lib/ai/tools/get-weather";
Expand All @@ -45,7 +54,11 @@ import {
generateTitleFromUserMessage,
saveChatModelAsCookie,
} from "../../actions";
import { type PostRequestBody, postRequestBodySchema } from "./schema";
import {
type FilePart,
type PostRequestBody,
postRequestBodySchema,
} from "./schema";

export const maxDuration = 60;

Expand Down Expand Up @@ -87,6 +100,173 @@ export function getStreamContext() {
return globalStreamContext;
}

/**
* Fetches text file content from URL with timeout handling
*/
async function fetchTextFileContent(file: {
name: string;
url: string;
mediaType: string;
}): Promise<string> {
try {
const controller = new AbortController();
const timeoutId = setTimeout(() => {
console.log(`Timeout for ${file.name}`);
controller.abort();
}, 30_000);

const response = await fetch(file.url, {
signal: controller.signal,
});
clearTimeout(timeoutId);

if (!response.ok) {
console.error(`Failed: ${response.status}`);
return `\n\n[File Upload: ${file.name} (${file.mediaType}) - Failed: ${response.status}]`;
}

const content = await response.text();

// Check if content is too large (max 0.5MB for text files)
const maxSize = 0.5 * 1024 * 1024;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The max size for text files (0.5 * 1024 * 1024) is hardcoded here. It would be better to define this as a named constant, for example MAX_TEXT_FILE_SIZE_BYTES, at the top of the file or in a shared constants file. This improves maintainability and makes the code easier to understand.

if (new TextEncoder().encode(content).length > maxSize) {
return `\n\n[File Upload: ${file.name} (${file.mediaType}) - File too large (${Math.round(content.length / 1024)}KB)]`;
}

return `\n\n[File Upload: ${file.name} (${file.mediaType})]\n${content}`;
} catch (error) {
console.error(`Error fetching ${file.name}:`, error);
return `\n\n[File Upload: ${file.name} (${file.mediaType}) - Failed to load]`;
}
}

/**
* Processes UI messages to extract text files and convert them to text parts
* Text files (txt, csv, md) are removed from file parts and their content is appended as text
*/
async function processUIMessagesWithTextFiles(
uiMessages: ChatMessage[]
): Promise<ChatMessage[]> {
return await Promise.all(
uiMessages.map(async (msg) => {
if (msg.role !== "user" || !msg.parts) {
return msg;
}

type FilePartWithUrl = {
type: "file";
name: string;
url: string;
mediaType: string;
};

const textFiles: FilePartWithUrl[] = [];
const nonTextParts: ChatMessage["parts"] = [];

// Separate text files from other parts
for (const part of msg.parts) {
if (part.type === "file") {
const filePart = part as FilePartWithUrl;
const mediaType = filePart.mediaType;
const isTextFile =
mediaType === "text/plain" ||
mediaType === "text/csv" ||
mediaType === "text/markdown" ||
mediaType === "application/csv";

if (isTextFile) {
textFiles.push(filePart);
} else {
// Keep images and PDFs as file parts
nonTextParts.push(part);
}
} else {
// Keep all non-file parts (text, tool calls, etc.)
nonTextParts.push(part);
}
}

// Fetch text file contents and append to message
if (textFiles.length > 0) {
const textFileContents = await Promise.all(
textFiles.map((file) => fetchTextFileContent(file))
);

const appendedText = textFileContents.join("");

// Find existing text part or create new one
const textPartIndex = nonTextParts.findIndex((p) => p.type === "text");
if (textPartIndex >= 0) {
const existingPart = nonTextParts[textPartIndex];
if (existingPart.type === "text") {
nonTextParts[textPartIndex] = {
...existingPart,
text: existingPart.text + appendedText,
};
}
} else {
nonTextParts.push({
type: "text",
text: appendedText.trim(),
});
}
}

return {
...msg,
parts: nonTextParts,
};
})
);
}

/**
* Converts processed UI messages to model messages and applies Gateway AI format for images
* Images get converted to { type: "file", data: URL, filename, mediaType }
* PDFs and other files stay in standard format
*/
function convertToGatewayModelMessages(
processedUIMessages: ChatMessage[]
): Awaited<ReturnType<typeof convertToModelMessages>> {
const modelMessages = convertToModelMessages(processedUIMessages);

return modelMessages.map((msg) => {
if (msg.role === "user" && msg.content && Array.isArray(msg.content)) {
return {
...msg,
content: msg.content.map((part) => {
// Convert image files to Gateway AI format
if (
part.type === "file" &&
typeof part === "object" &&
"url" in part &&
part.url
) {
const filePart = part as {
type: "file";
url: string;
name?: string;
mimeType?: string;
mediaType?: string;
};
const mediaType = filePart.mimeType || filePart.mediaType;
if (mediaType?.startsWith("image/")) {
return {
type: "file" as const,
data: filePart.url,
filename: filePart.name,
mediaType,
};
}
}
return part;
}),
};
}
return msg;
});
}

export async function POST(request: Request) {
let requestBody: PostRequestBody;

Expand Down Expand Up @@ -121,6 +301,36 @@ export async function POST(request: Request) {
// Update chat model cookie with the current model ID
await saveChatModelAsCookie(selectedChatModel);

// Validate file compatibility with selected model
const fileParts = message.parts.filter(
(part): part is FilePart => part.type === "file"
);

if (fileParts.length > 0) {
const fileAttachments: FileAttachment[] = fileParts.map((part) => ({
name: part.name,
url: part.url,
mediaType: part.mediaType as FileAttachment["mediaType"],
}));

const incompatibleFiles = validateFileCompatibility(
fileAttachments,
selectedChatModel
);

if (incompatibleFiles.length > 0) {
const errorMessage =
generateCompatibilityErrorMessage(incompatibleFiles);
return Response.json(
{
error: errorMessage,
incompatibleFiles,
},
{ status: 400 }
);
}
}

// TODO: credit based limit per month
const messageCount = await getMessageCountByUserId({
id: session.user.id,
Expand Down Expand Up @@ -151,7 +361,28 @@ export async function POST(request: Request) {
}

const messagesFromDb = await getMessagesByChatId({ id });
const uiMessages = [...convertToUIMessages(messagesFromDb), message];

const hasTextPart = message.parts.some((part) => part.type === "text");
const hasFilePart = message.parts.some((part) => part.type === "file");

const messageForModel =
!hasTextPart && hasFilePart
? {
...message,
parts: [
...message.parts,
{
type: "text" as const,
text: analyzeAttachmentPrompt,
},
],
}
: message;

const uiMessages = [
...convertToUIMessages(messagesFromDb),
messageForModel,
];

const { longitude, latitude, city, country } = geolocation(request);

Expand All @@ -168,7 +399,7 @@ export async function POST(request: Request) {
chatId: id,
id: message.id,
role: "user",
parts: message.parts,
parts: message.parts, // Save original message without injected prompt
attachments: [],
createdAt: new Date(),
},
Expand All @@ -180,12 +411,31 @@ export async function POST(request: Request) {

let finalMergedUsage: AppUsage | undefined;

// Process text files
let processedUIMessages: ChatMessage[];
try {
processedUIMessages = await processUIMessagesWithTextFiles(uiMessages);
} catch (error) {
console.error("Error processing text files:", error);
processedUIMessages = uiMessages;
}

// Convert to model messages with Gateway AI format for images
let modelMessages: Awaited<ReturnType<typeof convertToModelMessages>>;
try {
modelMessages = convertToGatewayModelMessages(processedUIMessages);
} catch (error) {
console.error("Error converting model messages:", error);
modelMessages = convertToModelMessages(processedUIMessages);
}

// Create and execute the stream
const stream = createUIMessageStream({
execute: ({ writer: dataStream }) => {
const result = streamText({
model: myProvider.languageModel(selectedChatModel),
system: systemPrompt({ selectedChatModel, requestHints }),
messages: convertToModelMessages(uiMessages),
messages: modelMessages,
stopWhen: stepCountIs(5),
experimental_activeTools:
selectedChatModel === "chat-model-reasoning"
Expand Down
16 changes: 13 additions & 3 deletions app/(chat)/api/chat/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,21 @@ import { ALL_MODEL_IDS } from "@/lib/ai/models";

const textPartSchema = z.object({
type: z.enum(["text"]),
text: z.string().min(1).max(5000),
text: z.string().min(1).max(100_000),
});

const filePartSchema = z.object({
type: z.enum(["file"]),
mediaType: z.enum(["image/jpeg", "image/png", "text"]),
mediaType: z.enum([
"image/jpeg",
"image/png",
"image/heic",
"application/pdf",
"text/plain",
"text/csv",
"text/markdown",
"application/csv",
]),
name: z.string().min(1).max(100),
url: z.string().url(),
});
Expand All @@ -20,10 +29,11 @@ export const postRequestBodySchema = z.object({
message: z.object({
id: z.string().uuid(),
role: z.enum(["user"]),
parts: z.array(partSchema),
parts: z.array(partSchema).min(1, "Message must contain at least one part"),
}),
selectedChatModel: z.enum([...ALL_MODEL_IDS] as [string, ...string[]]),
selectedVisibilityType: z.enum(["public", "private"]),
});

export type PostRequestBody = z.infer<typeof postRequestBodySchema>;
export type FilePart = z.infer<typeof filePartSchema>;
Loading
Loading