diff --git a/bun.lockb b/bun.lockb
index 7126c40..e6c5e5b 100755
Binary files a/bun.lockb and b/bun.lockb differ
diff --git a/index.ts b/index.ts
index 52b2aa7..ac7e171 100644
--- a/index.ts
+++ b/index.ts
@@ -4,3 +4,4 @@ export * from "./src/database";
 export * from "./src/ratelimit";
 export * from "./src/error";
 export * from "./src/types";
+export { MODEL_NAME_WITH_PROVIDER_SPLITTER } from "./src/constants";
diff --git a/package.json b/package.json
index 4cd157e..3a7153c 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@upstash/rag-chat",
-  "version": "0.0.25-alpha",
+  "version": "0.0.27-alpha",
   "main": "./dist/index.js",
   "module": "./dist/index.mjs",
   "types": "./dist/index.d.ts",
@@ -49,10 +49,6 @@
   "dependencies": {
     "@langchain/community": "^0.2.1",
     "@langchain/core": "^0.1.58",
-    "@langchain/openai": "^0.0.28",
-    "@upstash/ratelimit": "^1.1.3",
-    "@upstash/redis": "^1.31.1",
-    "@upstash/vector": "^1.1.1",
     "ai": "^3.1.1",
     "cheerio": "^1.0.0-rc.12",
     "d3-dsv": "^3.0.1",
@@ -60,5 +56,11 @@
     "langchain": "^0.2.0",
     "nanoid": "^5.0.7",
     "pdf-parse": "^1.1.1"
+  },
+  "peerDependencies": {
+    "@upstash/redis": "^1.31.3",
+    "@upstash/vector": "^1.1.1",
+    "@upstash/ratelimit": "^1.1.3",
+    "@langchain/openai": "^0.0.34"
   }
 }
diff --git a/src/constants.ts b/src/constants.ts
index 1ae7f91..17ac760 100644
--- a/src/constants.ts
+++ b/src/constants.ts
@@ -14,3 +14,6 @@ export const DEFAULT_METADATA_KEY = "text";
 //History related default options
 export const DEFAULT_HISTORY_TTL = 86_400;
 export const DEFAULT_HISTORY_LENGTH = 5;
+
+//We need this constant to split the creator LLM tag such as `ChatOpenAI_gpt-3.5-turbo`. Format is `provider_modelName`.
+export const MODEL_NAME_WITH_PROVIDER_SPLITTER = "_";
diff --git a/src/history/in-memory-custom-history.test.ts b/src/history/in-memory-custom-history.test.ts
index f77b26f..9b4272f 100644
--- a/src/history/in-memory-custom-history.test.ts
+++ b/src/history/in-memory-custom-history.test.ts
@@ -3,7 +3,11 @@ import { CustomInMemoryChatMessageHistory } from "./in-memory-custom-history";
 
 test("should give last 3 messages from in-memory", async () => {
   const messageHistoryLength = 3;
-  const history = new CustomInMemoryChatMessageHistory([], messageHistoryLength);
+  const history = new CustomInMemoryChatMessageHistory({
+    messages: [],
+    topLevelChatHistoryLength: messageHistoryLength,
+    modelNameWithProvider: "",
+  });
   await history.addUserMessage("Hello!");
   await history.addAIMessage("Hello, human.");
   await history.addUserMessage("Whats your name?");
@@ -16,7 +20,10 @@
 });
 
 test("should give all the messages", async () => {
-  const history = new CustomInMemoryChatMessageHistory();
+  const history = new CustomInMemoryChatMessageHistory({
+    messages: [],
+    modelNameWithProvider: "",
+  });
   await history.addUserMessage("Hello!");
   await history.addAIMessage("Hello, human.");
   await history.addUserMessage("Whats your name?");
diff --git a/src/history/in-memory-custom-history.ts b/src/history/in-memory-custom-history.ts
index d6e235c..e75ff51 100644
--- a/src/history/in-memory-custom-history.ts
+++ b/src/history/in-memory-custom-history.ts
@@ -2,17 +2,26 @@ import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
 import type { BaseMessage } from "@langchain/core/messages";
 
+export type CustomInMemoryChatMessageHistoryInput = {
+  messages?: BaseMessage[];
+  topLevelChatHistoryLength?: number;
+  modelNameWithProvider: string;
+};
+
 export class CustomInMemoryChatMessageHistory extends BaseListChatMessageHistory {
   lc_namespace = ["langchain", "stores", "message", "in_memory"];
 
   private messages: BaseMessage[] = [];
   private topLevelChatHistoryLength?: number;
+  private modelNameWithProvider: string;
 
-  constructor(messages?: BaseMessage[], topLevelChatHistoryLength?: number) {
+  constructor(fields: CustomInMemoryChatMessageHistoryInput) {
+    const { modelNameWithProvider, messages, topLevelChatHistoryLength } = fields;
     // eslint-disable-next-line prefer-rest-params
     super(...arguments);
     this.messages = messages ?? [];
     this.topLevelChatHistoryLength = topLevelChatHistoryLength;
+    this.modelNameWithProvider = modelNameWithProvider;
   }
 
   /**
@@ -32,7 +41,8 @@ export class CustomInMemoryChatMessageHistory extends BaseListChatMessageHistory
    * @returns A promise that resolves when the message has been added.
    */
   async addMessage(message: BaseMessage) {
-    this.messages.push(message);
+    //@ts-expect-error This is our way of mutating the Message object to store the model name with its provider.
+    this.messages.push({ ...message, modelNameWithProvider: this.modelNameWithProvider });
   }
 
   /**
diff --git a/src/history/index.ts b/src/history/index.ts
index b9c49c7..e0e1017 100644
--- a/src/history/index.ts
+++ b/src/history/index.ts
@@ -3,16 +3,25 @@ import { CustomInMemoryChatMessageHistory } from "./in-memory-custom-history";
 import { CustomUpstashRedisChatMessageHistory } from "./redis-custom-history";
 import { InternalUpstashError } from "../error";
 
+type HistoryConfig = {
+  redis?: Redis;
+  modelNameWithProvider: string;
+};
 type GetHistory = { sessionId: string; length?: number; sessionTTL?: number };
 
 export class History {
   private redis?: Redis;
+  private modelNameWithProvider: string;
   private inMemoryChatHistory?: CustomInMemoryChatMessageHistory;
 
-  constructor(redis?: Redis) {
+  constructor(fields: HistoryConfig) {
+    const { modelNameWithProvider, redis } = fields;
+    this.redis = redis;
+    this.modelNameWithProvider = modelNameWithProvider;
+
     if (!redis) {
-      this.inMemoryChatHistory = new CustomInMemoryChatMessageHistory();
+      this.inMemoryChatHistory = new CustomInMemoryChatMessageHistory({ modelNameWithProvider });
     }
   }
@@ -24,6 +33,7 @@ export class History {
          sessionTTL,
          topLevelChatHistoryLength: length,
          client: this.redis,
+         modelNameWithProvider: this.modelNameWithProvider,
        });
      }
    } catch (error) {
diff --git a/src/history/redis-custom-history.ts b/src/history/redis-custom-history.ts
index 6369704..b4f6ff7 100644
--- a/src/history/redis-custom-history.ts
+++ b/src/history/redis-custom-history.ts
@@ -19,6 +19,7 @@ export type CustomUpstashRedisChatMessageHistoryInput = {
   config?: RedisConfigNodejs;
   client?: Redis;
   topLevelChatHistoryLength?: number;
+  modelNameWithProvider: string;
 };
 
 /**
@@ -38,13 +39,21 @@ export class CustomUpstashRedisChatMessageHistory extends BaseListChatMessageHis
   public client: Redis;
   private sessionId: string;
+  private modelNameWithProvider: string;
   private sessionTTL?: number;
   private topLevelChatHistoryLength?: number;
 
   constructor(fields: CustomUpstashRedisChatMessageHistoryInput) {
     super(fields);
-    const { sessionId, sessionTTL, config, client, topLevelChatHistoryLength } = fields;
+    const {
+      sessionId,
+      sessionTTL,
+      config,
+      client,
+      topLevelChatHistoryLength,
+      modelNameWithProvider,
+    } = fields;
     if (client) {
       this.client = client;
     } else if (config) {
@@ -54,7 +63,9 @@ export class CustomUpstashRedisChatMessageHistory extends BaseListChatMessageHis
        `Upstash Redis message stores require either a config object or a pre-configured client.`
      );
    }
+
     this.sessionId = sessionId;
+    this.modelNameWithProvider = modelNameWithProvider;
     this.sessionTTL = sessionTTL;
     this.topLevelChatHistoryLength = topLevelChatHistoryLength;
   }
@@ -86,7 +97,10 @@ export class CustomUpstashRedisChatMessageHistory extends BaseListChatMessageHis
    */
   async addMessage(message: BaseMessage): Promise<void> {
     const messageToAdd = mapChatMessagesToStoredMessages([message]);
-    await this.client.lpush(this.sessionId, JSON.stringify(messageToAdd[0]));
+    await this.client.lpush(
+      this.sessionId,
+      JSON.stringify({ ...messageToAdd[0], modelNameWithProvider: this.modelNameWithProvider })
+    );
     if (this.sessionTTL) {
       await this.client.expire(this.sessionId, this.sessionTTL);
     }
diff --git a/src/rag-chat.ts b/src/rag-chat.ts
index 1450aa4..7a97c63 100644
--- a/src/rag-chat.ts
+++ b/src/rag-chat.ts
@@ -14,6 +14,7 @@ import { appendDefaultsIfNeeded } from "./utils";
 import type { AddContextOptions, AddContextPayload } from "./database";
 import { Database } from "./database";
 import { History } from "./history";
+import { MODEL_NAME_WITH_PROVIDER_SPLITTER } from "./constants.ts";
 
 export class RAGChat extends RAGChatBase {
   #ratelimitService: RateLimitService;
@@ -21,7 +22,11 @@ export class RAGChat extends RAGChatBase {
   constructor(config: RAGChatConfig) {
     const { vector: index, redis } = new Config(config);
 
-    const historyService = new History(redis);
+    const historyService = new History({
+      redis,
+      //@ts-expect-error We need this private field to track the message creator LLM such as `ChatOpenAI_gpt-3.5-turbo`. Format is `provider_modelName`.
+      modelNameWithProvider: `${config.model?.getName()}${MODEL_NAME_WITH_PROVIDER_SPLITTER}${config.model?.modelName}`,
+    });
     const vectorService = new Database(index);
     const ratelimitService = new RateLimitService(config.ratelimit);
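
As context for the change above (not part of the diff): each stored history message is now tagged with a `modelNameWithProvider` string in the `provider_modelName` format built in `rag-chat.ts`, e.g. `ChatOpenAI_gpt-3.5-turbo`. Below is a minimal TypeScript sketch of how a consumer could split that tag back apart with the newly exported `MODEL_NAME_WITH_PROVIDER_SPLITTER`; the `parseModelNameWithProvider` helper is hypothetical and not part of this change.

// Hypothetical helper: split a stored `provider_modelName` tag using the exported constant.
import { MODEL_NAME_WITH_PROVIDER_SPLITTER } from "@upstash/rag-chat";

export const parseModelNameWithProvider = (tag: string) => {
  // Split only on the first separator, since model names can contain further underscores.
  const splitterIndex = tag.indexOf(MODEL_NAME_WITH_PROVIDER_SPLITTER);
  return splitterIndex === -1
    ? { provider: tag, modelName: undefined }
    : {
        provider: tag.slice(0, splitterIndex),
        modelName: tag.slice(splitterIndex + MODEL_NAME_WITH_PROVIDER_SPLITTER.length),
      };
};

// parseModelNameWithProvider("ChatOpenAI_gpt-3.5-turbo")
// // => { provider: "ChatOpenAI", modelName: "gpt-3.5-turbo" }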