diff --git a/src/app/api/assistant-modes/route.ts b/src/app/api/assistant-modes/route.ts
index a05e880..8fdd304 100644
--- a/src/app/api/assistant-modes/route.ts
+++ b/src/app/api/assistant-modes/route.ts
@@ -1,5 +1,6 @@
 import {NextResponse} from "next/server";
 import prisma, {Prisma} from "@/lib/prisma";
+import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";
 
 // eslint-disable-next-line @typescript-eslint/no-empty-object-type
 export interface AssistantMode extends Prisma.AssistantModeGetPayload<{
@@ -14,8 +15,20 @@ export interface AssistantModeListResponse {
 export async function GET() {
     const assistantModes = await prisma.assistantMode.findMany({
         orderBy: {id: 'asc'},
-    })
+    });
+
+    // Add RAG model to assistant modes
+    const ragDocumentParser = new RAGDocumentParser();
+    const ragVisionParser = new RAGVisionParser();
+
+    const ragAssistantMode: AssistantMode = {
+        id: 'rag',
+        name: 'RAG Model',
+        description: 'Retrieval-Augmented Generation model for discussions based on specific medical literature.',
+        systemPrompt: 'Use the RAG model to provide responses based on specific medical literature.'
+    };
+
     return NextResponse.json({
-        assistantModes
-    })
+        assistantModes: [...assistantModes, ragAssistantMode]
+    });
 }
diff --git a/src/app/api/chat-rooms/route.ts b/src/app/api/chat-rooms/route.ts
index 4bd6abe..69578a7 100644
--- a/src/app/api/chat-rooms/route.ts
+++ b/src/app/api/chat-rooms/route.ts
@@ -1,5 +1,6 @@
 import prisma, {Prisma} from "@/lib/prisma";
 import {NextResponse} from "next/server";
+import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";
 
 export interface ChatRoom extends Prisma.ChatRoomGetPayload<{
     select: {
diff --git a/src/components/chat/chat-setting-side-bar.tsx b/src/components/chat/chat-setting-side-bar.tsx
index e4cd9da..44ad977 100644
--- a/src/components/chat/chat-setting-side-bar.tsx
+++ b/src/components/chat/chat-setting-side-bar.tsx
@@ -10,6 +10,7 @@ import {ChatRoomGetResponse} from "@/app/api/chat-rooms/[id]/route";
 import {AssistantModePatchRequest} from "@/app/api/assistant-modes/[id]/route";
 import {LLMProvider, LLMProviderListResponse} from "@/app/api/llm-providers/route";
 import {LLMProviderModel, LLMProviderModelListResponse} from "@/app/api/llm-providers/[id]/models/route";
+import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";
 
 interface ChatSettingSideBarProps {
     isRightSidebarOpen: boolean;
diff --git a/src/components/chat/chat-side-bar.tsx b/src/components/chat/chat-side-bar.tsx
index bb7e580..4495287 100644
--- a/src/components/chat/chat-side-bar.tsx
+++ b/src/components/chat/chat-side-bar.tsx
@@ -18,6 +18,7 @@ import {
     TooltipProvider,
     TooltipTrigger,
 } from "@/components/ui/tooltip"
+import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";
 
 interface ChatSideBarProps {
     isLeftSidebarOpen: boolean;
diff --git a/src/lib/health-data/parser/rag.ts b/src/lib/health-data/parser/rag.ts
new file mode 100644
index 0000000..36c44c7
--- /dev/null
+++ b/src/lib/health-data/parser/rag.ts
@@ -0,0 +1,182 @@
+import {
+    BaseDocumentParser,
+    DocumentParseOptions,
+    DocumentParseResult,
+    DocumentParserModel,
+    OCRParseResult
+} from "@/lib/health-data/parser/document/base-document";
+import {BaseVisionParser, VisionParseOptions, VisionParserModel} from "@/lib/health-data/parser/vision/base-vision";
+import {ChatPromptTemplate} from "@langchain/core/prompts";
+import {HealthCheckupType} from "@/lib/health-data/parser/schema";
+import {MessagePayload} from "@/lib/health-data/parser/prompt";
+import {processBatchWithConcurrency} from "@/lib/health-data/parser/util";
+
+/**
+ * Document parser registered under the name `RAGDocumentParser`.
+ *
+ * `ocr` and `parse` are placeholders: each ignores its `options` argument and
+ * returns an empty payload.
+ */
+export class RAGDocumentParser extends BaseDocumentParser {
+    /** Identifier reported for this parser. */
+    get name(): string {
+        return 'RAGDocumentParser';
+    }
+
+    /** Always `true` for this parser. */
+    get apiKeyRequired(): boolean {
+        return true;
+    }
+
+    /** Lists the single model entry this parser exposes. */
+    async models(): Promise<DocumentParserModel[]> {
+        return [
+            {id: 'rag-document-parse', name: 'RAG Document Parse'}
+        ];
+    }
+
+    /**
+     * Placeholder OCR step.
+     *
+     * @param options Parse options; currently unused.
+     * @returns An empty OCR payload.
+     */
+    async ocr(options: DocumentParseOptions): Promise<OCRParseResult> {
+        // Implement OCR logic for RAG
+        return {ocr: {}};
+    }
+
+    /**
+     * Placeholder parse step.
+     *
+     * @param options Parse options; currently unused.
+     * @returns An empty document payload.
+     */
+    async parse(options: DocumentParseOptions): Promise<DocumentParseResult> {
+        // Implement document parsing logic for RAG
+        return {document: {}};
+    }
+}
+
+/**
+ * Vision parser registered under the name `RAGVisionParser`.
+ */
+export class RAGVisionParser extends BaseVisionParser {
+    /** Identifier reported for this parser. */
+    get name(): string {
+        return 'RAGVisionParser';
+    }
+
+    /** Lists the single model entry this parser exposes. */
+    async models(): Promise<VisionParserModel[]> {
+        return [
+            {id: 'rag-vision-parse', name: 'RAG Vision Parse'}
+        ];
+    }
+
+    /**
+     * Pipes the prompt into a structured-output runnable and invokes it with
+     * `options.input`, allowing up to three attempts.
+     *
+     * NOTE(review): `ChatPromptTemplate` is a prompt template, not a chat model;
+     * constructing it with `{model, apiKey}` and calling `withStructuredOutput`
+     * on it looks wrong. A chat-model class is presumably intended - confirm
+     * which provider this parser should use before relying on this method.
+     * NOTE(review): `HealthCheckupType` is passed as a runtime value here; if the
+     * schema module exports it as a type only, the schema object is needed
+     * instead - confirm. The declared return type assumes it names the parsed
+     * result type - verify against `BaseVisionParser.parse`.
+     *
+     * @param options Model id, API key, optional prompt messages and chain input.
+     * @returns The structured result produced by the chain.
+     */
+    async parse(options: VisionParseOptions): Promise<HealthCheckupType> {
+        const llm = new ChatPromptTemplate({model: options.model.id, apiKey: options.apiKey});
+        const messages = options.messages || ChatPromptTemplate.fromMessages([]);
+        const chain = messages.pipe(llm.withStructuredOutput(HealthCheckupType, {
+            method: 'functionCalling',
+        }));
+        return await chain.withRetry({stopAfterAttempt: 3}).invoke(options.input);
+    }
+}
+
+/**
+ * Converts a document to images, runs OCR and per-image parsing, performs three
+ * inference passes and merges their per-key test results.
+ *
+ * For every key of the schema's `test_result`, the first pass (in the order:
+ * nothing excluded, `excludeText: true`, `excludeImage: true`) holding a
+ * non-null entry with a non-null `value` supplies both the entry and its page.
+ * Keys without a usable entry are omitted from the merged test result and
+ * recorded as `null` in the page map.
+ *
+ * NOTE(review): `VisionParserOptions`, `DocumentParserOptions`,
+ * `documentToImages`, `documentOCR`, `documentParse`, `inference` and
+ * `HealthCheckupSchema` are neither imported nor declared in this file; they
+ * must be brought into scope before this module compiles.
+ *
+ * @param options.file Source document, forwarded to image conversion and OCR.
+ * @param options.visionParser Vision parser settings forwarded to inference.
+ * @param options.documentParser Document parser settings for OCR, parsing and inference.
+ * @returns Single-element arrays with the validated checkup, its page map and the OCR results.
+ */
+export async function parseHealthDataWithRAG(options: {
+    file: string,
+    visionParser: VisionParserOptions,
+    documentParser: DocumentParserOptions
+}) {
+    const {file, visionParser, documentParser} = options;
+
+    // prepare images
+    const imagePaths = await documentToImages({file});
+
+    // prepare ocr results
+    const ocrResults = await documentOCR({document: file, documentParser});
+
+    // prepare parse results (the returned values are not used below)
+    await processBatchWithConcurrency(
+        imagePaths,
+        async (path) => documentParse({document: path, documentParser}),
+        3
+    );
+
+    // Run the three inference passes concurrently.
+    const baseInferenceOptions = {imagePaths, visionParser, documentParser};
+    const [fullRun, excludeTextRun, excludeImageRun] = await Promise.all([
+        inference({...baseInferenceOptions, excludeImage: false, excludeText: false}),
+        inference({...baseInferenceOptions, excludeImage: false, excludeText: true}),
+        inference({...baseInferenceOptions, excludeImage: true, excludeText: false}),
+    ]);
+
+    // Passes in precedence order: the first one with a usable entry for a key wins.
+    const passes = [fullRun, excludeTextRun, excludeImageRun].map((run) => ({
+        results: run.finalHealthCheckup.test_result,
+        pages: run.mergedTestResultPage,
+    }));
+
+    const mergedTestResult: Record<string, unknown> = {};
+    const mergedPageResult: { [key: string]: { page: number } | null } = {};
+
+    for (const key of HealthCheckupSchema.shape.test_result.keyof().options) {
+        const winner = passes.find(({results}) =>
+            Object.prototype.hasOwnProperty.call(results, key) &&
+            results[key] != null &&
+            results[key].value !== null
+        );
+
+        if (winner) {
+            mergedTestResult[key] = winner.results[key];
+            mergedPageResult[key] = winner.pages[key];
+        } else {
+            // No pass produced a value: leave the key out of the test result
+            // and keep an explicit null page marker.
+            mergedPageResult[key] = null;
+        }
+    }
+
+    const healthCheckup = HealthCheckupSchema.parse({
+        ...fullRun.finalHealthCheckup,
+        test_result: mergedTestResult
+    });
+
+    return {data: [healthCheckup], pages: [mergedPageResult], ocrResults: [ocrResults]};
+}