diff --git a/bun.lockb b/bun.lockb index 98a9e9d..d0edba9 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/index.ts b/index.ts index 3844a68..9e09b8a 100644 --- a/index.ts +++ b/index.ts @@ -1,2 +1,3 @@ export * from "./src/rag-chat"; export * from "./src/services/history"; +export * from "./src/error"; diff --git a/package.json b/package.json index 2d1d403..6f2d9db 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@upstash/rag-chat", - "version": "0.0.11-alpha", + "version": "0.0.14-alpha", "main": "./dist/index.js", "module": "./dist/index.mjs", "types": "./dist/index.d.ts", @@ -51,7 +51,7 @@ "@langchain/community": "^0.0.50", "@langchain/core": "^0.1.58", "@langchain/openai": "^0.0.28", - "@upstash/sdk": "0.0.25-alpha", + "@upstash/sdk": "0.0.26-alpha", "ai": "^3.0.35" } } diff --git a/src/clients/redis/index.test.ts b/src/clients/redis/index.test.ts index d6f3542..394a239 100644 --- a/src/clients/redis/index.test.ts +++ b/src/clients/redis/index.test.ts @@ -1,7 +1,8 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { Upstash } from "@upstash/sdk"; import { describe, expect, test } from "bun:test"; -import { DEFAULT_REDIS_CONFIG, DEFAULT_REDIS_DB_NAME, RedisClient } from "."; +import { DEFAULT_REDIS_CONFIG, RedisClient } from "."; +import { DEFAULT_REDIS_DB_NAME } from "../../constants"; const upstashSDK = new Upstash({ email: process.env.UPSTASH_EMAIL!, diff --git a/src/clients/redis/index.ts b/src/clients/redis/index.ts index d78dc39..1a2cabf 100644 --- a/src/clients/redis/index.ts +++ b/src/clients/redis/index.ts @@ -2,8 +2,7 @@ import type { CreateCommandPayload, Upstash } from "@upstash/sdk"; import { Redis } from "@upstash/sdk"; import type { PreferredRegions } from "../../types"; - -export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis"; +import { DEFAULT_REDIS_DB_NAME } from "../../constants"; export const DEFAULT_REDIS_CONFIG: CreateCommandPayload = { name: DEFAULT_REDIS_DB_NAME, diff --git a/src/clients/vector/index.test.ts b/src/clients/vector/index.test.ts index 4454c49..9819c3c 100644 --- a/src/clients/vector/index.test.ts +++ b/src/clients/vector/index.test.ts @@ -1,7 +1,8 @@ /* eslint-disable @typescript-eslint/no-non-null-assertion */ import { Upstash } from "@upstash/sdk"; import { describe, expect, test } from "bun:test"; -import { DEFAULT_VECTOR_DB_NAME, VectorClient, DEFAULT_VECTOR_CONFIG } from "."; +import { VectorClient, DEFAULT_VECTOR_CONFIG } from "."; +import { DEFAULT_VECTOR_DB_NAME } from "../../constants"; const upstashSDK = new Upstash({ email: process.env.UPSTASH_EMAIL!, diff --git a/src/clients/vector/index.ts b/src/clients/vector/index.ts index 7c475bf..31d88e6 100644 --- a/src/clients/vector/index.ts +++ b/src/clients/vector/index.ts @@ -2,8 +2,7 @@ import type { CreateIndexPayload, Upstash } from "@upstash/sdk"; import { Index } from "@upstash/sdk"; import type { PreferredRegions } from "../../types"; - -export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector"; +import { DEFAULT_VECTOR_DB_NAME } from "../../constants"; export const DEFAULT_VECTOR_CONFIG: CreateIndexPayload = { name: DEFAULT_VECTOR_DB_NAME, diff --git a/src/config.test.ts b/src/config.test.ts index 7fc3493..05a4c35 100644 --- a/src/config.test.ts +++ b/src/config.test.ts @@ -3,7 +3,8 @@ import { PromptTemplate } from "@langchain/core/prompts"; import { ChatOpenAI } from "@langchain/openai"; import { Index, Redis } from "@upstash/sdk"; import { expect, test } from "bun:test"; -import { Config, DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME } from "./config"; +import { Config } from "./config"; +import { DEFAULT_VECTOR_DB_NAME, DEFAULT_REDIS_DB_NAME } from "./constants"; const mockRedis = new Redis({ token: "hey", diff --git a/src/config.ts b/src/config.ts index d145472..d0b1175 100644 --- a/src/config.ts +++ b/src/config.ts @@ -1,32 +1,20 @@ import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base"; import type { PromptTemplate } from "@langchain/core/prompts"; +import type { Ratelimit } from "@upstash/sdk"; import { Redis } from "@upstash/sdk"; import { Index } from "@upstash/sdk"; -import type { PreferredRegions } from "./types"; - -type RAGChatConfigCommon = { - model?: BaseLanguageModelInterface; - template?: PromptTemplate; - region?: PreferredRegions; -}; - -const PREFERRED_REGION: PreferredRegions = "us-east-1"; -export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector"; -export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis"; - -export type RAGChatConfig = { - vector?: string | Index; - redis?: string | Redis; -} & RAGChatConfigCommon; +import type { PreferredRegions, RAGChatConfig } from "./types"; +import { DEFAULT_REDIS_DB_NAME, DEFAULT_VECTOR_DB_NAME, PREFERRED_REGION } from "./constants"; export class Config { public readonly token: string; public readonly email: string; - public readonly region: PreferredRegions; public readonly vector?: string | Index; public readonly redis?: string | Redis; + public readonly ratelimit?: Ratelimit; + public readonly region: PreferredRegions; public readonly model?: BaseLanguageModelInterface; public readonly template?: PromptTemplate; @@ -45,6 +33,8 @@ export class Config { ? config.redis : DEFAULT_REDIS_DB_NAME; + this.ratelimit = config?.ratelimit; + this.model = config?.model; this.template = config?.template; } diff --git a/src/constants.ts b/src/constants.ts new file mode 100644 index 0000000..cd9c7b9 --- /dev/null +++ b/src/constants.ts @@ -0,0 +1,10 @@ +import type { PreferredRegions } from "./types"; + +export const DEFAULT_CHAT_SESSION_ID = "upstash-rag-chat-session"; +export const DEFAULT_CHAT_RATELIMIT_SESSION_ID = "upstash-rag-chat-ratelimit-session"; + +export const RATELIMIT_ERROR_MESSAGE = "ERR:USER_RATELIMITED"; + +export const DEFAULT_VECTOR_DB_NAME = "upstash-rag-chat-vector"; +export const DEFAULT_REDIS_DB_NAME = "upstash-rag-chat-redis"; +export const PREFERRED_REGION: PreferredRegions = "us-east-1"; diff --git a/src/error/index.ts b/src/error/index.ts new file mode 100644 index 0000000..c42f9bc --- /dev/null +++ b/src/error/index.ts @@ -0,0 +1 @@ +export * from "./ratelimit"; diff --git a/src/error/internal.ts b/src/error/internal.ts index b115f3c..7a7f338 100644 --- a/src/error/internal.ts +++ b/src/error/internal.ts @@ -1,6 +1,6 @@ export class InternalUpstashError extends Error { constructor(message: string) { super(message); - this.name = "InternalUpstashError"; + this.name = "InternalError"; } } diff --git a/src/error/model.ts b/src/error/model.ts index 179966f..aa19ed1 100644 --- a/src/error/model.ts +++ b/src/error/model.ts @@ -1,6 +1,6 @@ export class UpstashModelError extends Error { constructor(message: string) { super(message); - this.name = "UpstashModelError"; + this.name = "ModelError"; } } diff --git a/src/error/ratelimit.ts b/src/error/ratelimit.ts new file mode 100644 index 0000000..567f7e7 --- /dev/null +++ b/src/error/ratelimit.ts @@ -0,0 +1,14 @@ +import type { RATELIMIT_ERROR_MESSAGE } from "../constants"; + +type RatelimitResponse = { + error: typeof RATELIMIT_ERROR_MESSAGE; + resetTime?: number; +}; + +export class RatelimitUpstashError extends Error { + constructor(message: string, cause: RatelimitResponse) { + super(message); + this.name = "RatelimitError"; + this.cause = cause; + } +} diff --git a/src/rag-chat.ts b/src/rag-chat.ts index d7a6ec3..6efcd62 100644 --- a/src/rag-chat.ts +++ b/src/rag-chat.ts @@ -3,7 +3,7 @@ import type { BaseMessage } from "@langchain/core/messages"; import { RunnableSequence, RunnableWithMessageHistory } from "@langchain/core/runnables"; import { LangChainStream, StreamingTextResponse } from "ai"; -import { formatChatHistory, sanitizeQuestion } from "./utils"; +import { appendDefaultsIfNeeded, formatChatHistory, sanitizeQuestion } from "./utils"; import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base"; import type { PromptTemplate } from "@langchain/core/prompts"; @@ -13,24 +13,16 @@ import { HistoryService } from "./services/history"; import { RetrievalService } from "./services/retrieval"; import { QA_TEMPLATE } from "./prompts"; import { UpstashModelError } from "./error/model"; +import { RateLimitService } from "./services/ratelimit"; +import type { ChatOptions, PrepareChatResult, RAGChatConfig } from "./types"; +import { RatelimitUpstashError } from "./error/ratelimit"; type CustomInputValues = { chat_history?: BaseMessage[]; question: string; context: string }; -type ChatOptions = { - stream: boolean; - sessionId: string; - includeHistory?: number; - similarityThreshold?: number; -}; - -type PrepareChatResult = { - question: string; - facts: string; -}; - export class RAGChat { private retrievalService: RetrievalService; private historyService: HistoryService; + private ratelimitService: RateLimitService; private model: BaseLanguageModelInterface; private template: PromptTemplate; @@ -38,10 +30,12 @@ export class RAGChat { constructor( retrievalService: RetrievalService, historyService: HistoryService, + ratelimitService: RateLimitService, config: { model: BaseLanguageModelInterface; template: PromptTemplate } ) { this.retrievalService = retrievalService; this.historyService = historyService; + this.ratelimitService = ratelimitService; this.model = config.model; this.template = config.template; @@ -56,18 +50,33 @@ export class RAGChat { return { question, facts }; } - async chat(input: string, options: ChatOptions) { + async chat( + input: string, + options: ChatOptions + ): Promise> { + const options_ = appendDefaultsIfNeeded(options); + const { success, resetTime } = await this.ratelimitService.checkLimit( + options_.ratelimitSessionId + ); + + if (!success) { + throw new RatelimitUpstashError("Couldn't process chat due to ratelimit.", { + error: "ERR:USER_RATELIMITED", + resetTime: resetTime, + }); + } + const { question, facts } = await this.prepareChat(input, options.similarityThreshold); return options.stream - ? this.streamingChainCall(question, facts, options) - : this.chainCall(options, question, facts); + ? this.streamingChainCall(options_, question, facts) + : this.chainCall(options_, question, facts); } private streamingChainCall = ( + chatOptions: ChatOptions, question: string, - facts: string, - chatOptions: ChatOptions + facts: string ): StreamingTextResponse => { const { stream, handlers } = LangChainStream(); void this.chainCall(chatOptions, question, facts, [handlers]); @@ -75,7 +84,7 @@ export class RAGChat { }; private chainCall( - chatOptions: { sessionId: string; includeHistory?: number }, + chatOptions: ChatOptions, question: string, facts: string, handlers?: Callbacks @@ -113,7 +122,9 @@ export class RAGChat { ); } - static async initialize(config: Config): Promise { + static async initialize( + config: RAGChatConfig & { email: string; token: string } + ): Promise { const clientFactory = new ClientFactory( new Config(config.email, config.token, { redis: config.redis, @@ -125,12 +136,13 @@ export class RAGChat { const historyService = new HistoryService(redis); const retrievalService = new RetrievalService(index); + const ratelimitService = new RateLimitService(config.ratelimit); if (!config.model) { throw new UpstashModelError("Model can not be undefined!"); } - return new RAGChat(retrievalService, historyService, { + return new RAGChat(retrievalService, historyService, ratelimitService, { model: config.model, template: config.template ?? QA_TEMPLATE, }); diff --git a/src/services/history.ts b/src/services/history.ts index c541591..d876eab 100644 --- a/src/services/history.ts +++ b/src/services/history.ts @@ -1,8 +1,8 @@ import type { Redis } from "@upstash/sdk"; import { CustomUpstashRedisChatMessageHistory } from "./redis-custom-history"; -import type { RAGChatConfig } from "../config"; import { Config } from "../config"; import { ClientFactory } from "../client-factory"; +import type { RAGChatConfig } from "../types"; const DAY_IN_SECONDS = 86_400; const TOP_6 = 5; diff --git a/src/services/ratelimit.ts b/src/services/ratelimit.ts new file mode 100644 index 0000000..ca131b3 --- /dev/null +++ b/src/services/ratelimit.ts @@ -0,0 +1,22 @@ +import type { Ratelimit } from "@upstash/sdk"; + +export class RateLimitService { + private ratelimit?: Ratelimit; + + constructor(ratelimit?: Ratelimit) { + this.ratelimit = ratelimit; + } + + async checkLimit(sessionId: string): Promise<{ success: boolean; resetTime?: number }> { + if (!this.ratelimit) { + // If no ratelimit object is provided, always allow the operation. + return { success: true }; + } + + const result = await this.ratelimit.limit(sessionId); + if (!result.success) { + return { success: false, resetTime: result.reset }; + } + return { success: true }; + } +} diff --git a/src/types.ts b/src/types.ts index c0e75f8..bc4a451 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1 +1,30 @@ +import type { BaseLanguageModelInterface } from "@langchain/core/language_models/base"; +import type { PromptTemplate } from "@langchain/core/prompts"; +import type { Index, Ratelimit, Redis } from "@upstash/sdk"; + export type PreferredRegions = "eu-west-1" | "us-east-1"; + +export type ChatOptions = { + stream: boolean; + sessionId?: string; + includeHistory?: number; + similarityThreshold?: number; + ratelimitSessionId?: string; +}; + +export type PrepareChatResult = { + question: string; + facts: string; +}; + +type RAGChatConfigCommon = { + model?: BaseLanguageModelInterface; + template?: PromptTemplate; + region?: PreferredRegions; + ratelimit?: Ratelimit; +}; + +export type RAGChatConfig = { + vector?: string | Index; + redis?: string | Redis; +} & RAGChatConfigCommon; diff --git a/src/utils.ts b/src/utils.ts index 3c5ac68..2a3a624 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,4 +1,6 @@ import type { BaseMessage } from "@langchain/core/messages"; +import type { ChatOptions } from "./types"; +import { DEFAULT_CHAT_SESSION_ID, DEFAULT_CHAT_RATELIMIT_SESSION_ID } from "./constants"; export const sanitizeQuestion = (question: string) => { return question.trim().replaceAll("\n", " "); @@ -17,3 +19,11 @@ export const formatChatHistory = (chatHistory: BaseMessage[]) => { return formatFacts(formattedDialogueTurns); }; + +export function appendDefaultsIfNeeded(options: ChatOptions) { + return { + ...options, + sessionId: options.sessionId ?? DEFAULT_CHAT_SESSION_ID, + ratelimitSessionId: options.ratelimitSessionId ?? DEFAULT_CHAT_RATELIMIT_SESSION_ID, + } satisfies ChatOptions; +}