Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/app/api/assistant-modes/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import {NextResponse} from "next/server";
import prisma, {Prisma} from "@/lib/prisma";
import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";

// eslint-disable-next-line @typescript-eslint/no-empty-object-type
export interface AssistantMode extends Prisma.AssistantModeGetPayload<{
Expand All @@ -14,8 +15,20 @@ export interface AssistantModeListResponse {
export async function GET() {
const assistantModes = await prisma.assistantMode.findMany({
orderBy: {id: 'asc'},
})
});

// Add RAG model to assistant modes
const ragDocumentParser = new RAGDocumentParser();
const ragVisionParser = new RAGVisionParser();

const ragAssistantMode: AssistantMode = {
id: 'rag',
name: 'RAG Model',
description: 'Retrieval-Augmented Generation model for discussions based on specific medical literature.',
systemPrompt: 'Use the RAG model to provide responses based on specific medical literature.'
};

return NextResponse.json<AssistantModeListResponse>({
assistantModes
})
assistantModes: [...assistantModes, ragAssistantMode]
});
}
1 change: 1 addition & 0 deletions src/app/api/chat-rooms/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import prisma, {Prisma} from "@/lib/prisma";
import {NextResponse} from "next/server";
import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";

export interface ChatRoom extends Prisma.ChatRoomGetPayload<{
select: {
Expand Down
1 change: 1 addition & 0 deletions src/components/chat/chat-setting-side-bar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {ChatRoomGetResponse} from "@/app/api/chat-rooms/[id]/route";
import {AssistantModePatchRequest} from "@/app/api/assistant-modes/[id]/route";
import {LLMProvider, LLMProviderListResponse} from "@/app/api/llm-providers/route";
import {LLMProviderModel, LLMProviderModelListResponse} from "@/app/api/llm-providers/[id]/models/route";
import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";

interface ChatSettingSideBarProps {
isRightSidebarOpen: boolean;
Expand Down
1 change: 1 addition & 0 deletions src/components/chat/chat-side-bar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip"
import {RAGDocumentParser, RAGVisionParser} from "@/lib/health-data/parser/rag";

interface ChatSideBarProps {
isLeftSidebarOpen: boolean;
Expand Down
148 changes: 148 additions & 0 deletions src/lib/health-data/parser/rag.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import {BaseDocumentParser, DocumentParseOptions, DocumentParseResult, DocumentParserModel, OCRParseResult} from "@/lib/health-data/parser/document/base-document";
import {BaseVisionParser, VisionParseOptions, VisionParserModel} from "@/lib/health-data/parser/vision/base-vision";
import {ChatPromptTemplate} from "@langchain/core/prompts";
import {HealthCheckupType} from "@/lib/health-data/parser/schema";
import {MessagePayload} from "@/lib/health-data/parser/prompt";
import {processBatchWithConcurrency} from "@/lib/health-data/parser/util";

export class RAGDocumentParser extends BaseDocumentParser {
get name(): string {
return 'RAGDocumentParser';
}

get apiKeyRequired(): boolean {
return true;
}

async models(): Promise<DocumentParserModel[]> {
return [
{id: 'rag-document-parse', name: 'RAG Document Parse'}
];
}

async ocr(options: DocumentParseOptions): Promise<OCRParseResult> {
// Implement OCR logic for RAG
return {ocr: {}};
}

async parse(options: DocumentParseOptions): Promise<DocumentParseResult> {
// Implement document parsing logic for RAG
return {document: {}};
}
}

export class RAGVisionParser extends BaseVisionParser {
get name(): string {
return 'RAGVisionParser';
}

async models(): Promise<VisionParserModel[]> {
return [
{id: 'rag-vision-parse', name: 'RAG Vision Parse'}
];
}

async parse(options: VisionParseOptions): Promise<HealthCheckupType> {
const llm = new ChatPromptTemplate({model: options.model.id, apiKey: options.apiKey});
const messages = options.messages || ChatPromptTemplate.fromMessages([]);
const chain = messages.pipe(llm.withStructuredOutput(HealthCheckupType, {
method: 'functionCalling',
}));
return await chain.withRetry({stopAfterAttempt: 3}).invoke(options.input);
}
}

export async function parseHealthDataWithRAG(options: { file: string, visionParser: VisionParserOptions, documentParser: DocumentParserOptions }) {
const {file, visionParser, documentParser} = options;

// prepare images
const imagePaths = await documentToImages({file});

// prepare ocr results
const ocrResults = await documentOCR({
document: file,
documentParser: documentParser
});

// prepare parse results
await processBatchWithConcurrency(
imagePaths,
async (path) => documentParse({document: path, documentParser: documentParser}),
3
);

// Merge the results
const baseInferenceOptions = {imagePaths, visionParser, documentParser};
const [
{finalHealthCheckup: resultTotal, mergedTestResultPage: resultTotalPages},
{finalHealthCheckup: resultText, mergedTestResultPage: resultTextPages},
{finalHealthCheckup: resultImage, mergedTestResultPage: resultImagePages}
] = await Promise.all([
inference({...baseInferenceOptions, excludeImage: false, excludeText: false}),
inference({...baseInferenceOptions, excludeImage: false, excludeText: true}),
inference({...baseInferenceOptions, excludeImage: true, excludeText: false}),
]);

const resultDictTotal = resultTotal.test_result;
const resultDictText = resultText.test_result;
const resultDictImage = resultImage.test_result;

const mergedTestResult: { [key: string]: any } = {};
const mergedPageResult: { [key: string]: { page: number } | null } = {};

for (const key of HealthCheckupSchema.shape.test_result.keyof().options) {
const valueTotal =
resultDictTotal.hasOwnProperty(key) &&
resultDictTotal[key] !== null &&
resultDictTotal[key]!.value !== null
? resultDictTotal[key]
: null;
const pageTotal = valueTotal !== null ? resultTotalPages[key] : null;

const valueText =
resultDictText.hasOwnProperty(key) &&
resultDictText[key] !== null &&
resultDictText[key]!.value !== null
? resultDictText[key]
: null;
const pageText = valueText !== null ? resultTextPages[key] : null;

const valueImage =
resultDictImage.hasOwnProperty(key) &&
resultDictImage[key] !== null &&
resultDictImage[key]!.value !== null
? resultDictImage[key]
: null;
const pageImage = valueImage !== null ? resultImagePages[key] : null;

if (valueTotal === null) {
if (valueText !== null) {
mergedTestResult[key] = valueText;
mergedPageResult[key] = pageText;
} else if (valueImage !== null) {
mergedTestResult[key] = valueImage;
mergedPageResult[key] = pageImage;
} else {
mergedTestResult[key] = valueText;
mergedPageResult[key] = pageText;
}
} else {
mergedTestResult[key] = valueTotal;
mergedPageResult[key] = pageTotal;
}
}

// remove all null values in mergedTestResult
for (const key in mergedTestResult) {
if (mergedTestResult[key] === null) {
delete mergedTestResult[key];
}
}

const healthCheckup = HealthCheckupSchema.parse({
...resultTotal,
test_result: mergedTestResult
});

return {data: [healthCheckup], pages: [mergedPageResult], ocrResults: [ocrResults]};
}