From 8f5430d5035cf8b5e1f9c86c54116a81cebfd194 Mon Sep 17 00:00:00 2001
From: Haouari haitam Kouider <57036855+haouarihk@users.noreply.github.com>
Date: Sat, 4 Jan 2025 12:24:32 +0100
Subject: [PATCH] added support for o1 and o1-mini

---
 esbuild.config.mjs                         |   8 +-
 package.json                               |   7 +-
 src/LLMProviders/index.ts                  |   2 +
 src/LLMProviders/langchain/openaiAgent.tsx | 250 +++++++++++++++++++++
 4 files changed, 263 insertions(+), 4 deletions(-)
 create mode 100644 src/LLMProviders/langchain/openaiAgent.tsx

diff --git a/esbuild.config.mjs b/esbuild.config.mjs
index 7c5d9f3e..5599244e 100644
--- a/esbuild.config.mjs
+++ b/esbuild.config.mjs
@@ -4,7 +4,9 @@ import builtins from 'builtin-modules'
 import path from "path";
 import fs from "fs";
 import obsidianAliasPlugin from "./obsidian-alias/index.js";
+import processFallback from './obsidian-alias/process.js';
 import { exec } from 'child_process';
+import { definePlugin } from 'esbuild-plugin-define';
 
 const banner =
 `/*
@@ -66,8 +68,12 @@ const esbuildConfig = {
     wasmPlugin({
       mode: "embed",
     }),
-    obsidianAliasPlugin()
+    obsidianAliasPlugin(),
+    definePlugin({
+      'process': processFallback
+    })
   ],
+
   external: [
     'obsidian',
     'electron',
diff --git a/package.json b/package.json
index a904a21e..c82848c0 100644
--- a/package.json
+++ b/package.json
@@ -68,9 +68,9 @@
     "@koa/cors": "^5.0.0",
     "@langchain/anthropic": "^0.3.8",
     "@langchain/community": "^0.2.25",
-    "@langchain/core": "^0.3.1",
+    "@langchain/core": "^0.3.27",
     "@langchain/google-genai": "^0.1.4",
-    "@langchain/openai": "^0.3.14",
+    "@langchain/openai": "^0.3.16",
     "@mozilla/readability": "^0.4.4",
     "@react-icons/all-files": "^4.1.0",
     "@rjsf/core": "^5.20.0",
@@ -86,6 +86,7 @@
     "clsx": "^2.1.1",
     "debounce": "^2.1.0",
     "debug": "^2.6.9",
+    "esbuild-plugin-define": "^0.5.0",
     "esbuild-wasm": "^0.21.5",
     "fs-extra": "^11.2.0",
     "func-cache": "^2.2.62",
@@ -95,7 +96,7 @@
     "json5": "^2.2.3",
     "koa": "^2.15.3",
     "koa-proxies": "^0.12.4",
-    "langchain": "^0.3.6",
+    "langchain": "^0.3.9",
     "live-plugin-manager": "^0.18.1",
     "lockfile": "^1.0.4",
     "lodash.get": "^4.4.2",
diff --git a/src/LLMProviders/index.ts b/src/LLMProviders/index.ts
index 44eddaba..1b006a3b 100644
--- a/src/LLMProviders/index.ts
+++ b/src/LLMProviders/index.ts
@@ -11,6 +11,7 @@ import LangchainAzureOpenAIChatProvider from "./langchain/azureOpenAIChat";
 import LangchainAzureOpenAIInstructProvider from "./langchain/azureOpenAIInstruct";
 import LangchainPalmProvider from "./langchain/palm";
 import LangchainChatGoogleGenerativeAIProvider from "./langchain/googleGenerativeAI";
+import LangchainOpenAIAgentProvider from "./langchain/openaiAgent";
 
 // import LangchainReplicaProvider from "./langchain/replica"
 // import { LOCClone1, LOCClone2 } from "./langchain/clones";
@@ -19,6 +20,7 @@ export const defaultProviders = [
   // openai
   LangchainOpenAIChatProvider,
   LangchainOpenAIInstructProvider,
+  LangchainOpenAIAgentProvider,
 
   // google
   LangchainChatGoogleGenerativeAIProvider,
diff --git a/src/LLMProviders/langchain/openaiAgent.tsx b/src/LLMProviders/langchain/openaiAgent.tsx
new file mode 100644
index 00000000..f266cdc9
--- /dev/null
+++ b/src/LLMProviders/langchain/openaiAgent.tsx
@@ -0,0 +1,250 @@
+import React from "react";
+import LangchainBase from "./base";
+
+import LLMProviderInterface, { LLMConfig } from "../interface";
+import { IconExternalLink } from "@tabler/icons-react";
+import { HeaderEditor, ModelsHandler } from "../utils";
+import debug from "debug";
+
+import { AI_MODELS, Input, Message, SettingItem, useGlobal } from "../refs";
+import { OpenAIChatInput } from "@langchain/openai";
+
+const logger = debug("textgenerator:llmProvider:openaiChat");
+
+const default_values = {
+  basePath: "https://api.openai.com/v1",
+  model: "o1-mini",
+};
+
+export default class LangchainOpenAIChatProvider
+  extends LangchainBase
+  implements LLMProviderInterface {
+  /** for models to know which provider this is, for example if this class is being extended and the id changes. */
+
+  static provider = "Langchain";
+  static id = "OpenAI Agent (Langchain)" as const;
+  static slug = "openAIAgent" as const;
+  static displayName = "OpenAI Agent";
+
+  id = LangchainOpenAIChatProvider.id;
+  provider = LangchainOpenAIChatProvider.provider;
+  originalId = LangchainOpenAIChatProvider.id;
+
+  async load() {
+    const { ChatOpenAI } = await import("@langchain/openai");
+    this.llmClass = ChatOpenAI;
+  }
+
+  getConfig(options: LLMConfig) {
+    return this.cleanConfig({
+      openAIApiKey: options.api_key,
+
+      // ------------Necessary stuff--------------
+      modelKwargs: options.modelKwargs,
+      modelName: options.model,
+      // frequencyPenalty: +options.frequency_penalty || 0,
+      presencePenalty: +options.presence_penalty || 0,
+      n: options.n || 1,
+      stop: options.stop || undefined,
+      streaming: options.stream || false,
+      maxRetries: 3,
+      headers: options.headers || undefined,
+
+      bodyParams: {
+        max_completion_tokens: +options.max_tokens,
+      },
+    } as Partial<OpenAIChatInput>);
+  }
+
+  // getLLM(options: LLMConfig) {
+  //   return new ChatOpenAI({
+  //     ...this.getConfig(options),
+  //   });
+  // }
+
+  RenderSettings(props: Parameters<LLMProviderInterface["RenderSettings"]>[0]) {
+    const global = useGlobal();
+
+    const id = props.self.id;
+    const config = (global.plugin.settings.LLMProviderOptions[id] ??= {
+      ...default_values,
+    });
+
+    return (
+      <>
+        <SettingItem
+          name="API Key"
+          register={props.register}
+          sectionId={props.sectionId}
+        >
+          <Input
+            type="password"
+            value={config.api_key || global.plugin.settings.api_key}
+            setValue={async (value) => {
+              if (props.self.originalId == id)
+                global.plugin.settings.api_key = value;
+              config.api_key = value;
+
+              global.triggerReload();
+              global.plugin.encryptAllKeys();
+              // TODO: it could use a debounce here
+              await global.plugin.saveSettings();
+            }}
+          />
+        </SettingItem>
+
+        <SettingItem
+          name="Base Path"
+          register={props.register}
+          sectionId={props.sectionId}
+        >
+          <Input
+            value={config.basePath || default_values.basePath}
+            setValue={async (value) => {
+              config.basePath = value || default_values.basePath;
+              global.plugin.settings.endpoint =
+                value || default_values.basePath;
+              global.triggerReload();
+              // TODO: it could use a debounce here
+              await global.plugin.saveSettings();
+            }}
+          />
+        </SettingItem>
+
+        <ModelsHandler
+          register={props.register}
+          sectionId={props.sectionId}
+          llmProviderId={props.self.originalId || id}
+          default_values={default_values}
+        />
+
+        <HeaderEditor
+          enabled={!!config.headers}
+          setEnabled={async (value) => {
+            if (!value) config.headers = undefined;
+            else config.headers = "{}";
+            global.triggerReload();
+            await global.plugin.saveSettings();
+          }}
+          headers={config.headers}
+          setHeaders={async (value) => {
+            config.headers = value;
+            global.triggerReload();
+            await global.plugin.saveSettings();
+          }}
+        />
+      </>
+    );
+  }
+
+  async calcPrice(
+    tokens: number,
+    reqParams: Partial<LLMConfig>
+  ): Promise<number> {
+    const model = reqParams.model;
+    const modelInfo =
+      AI_MODELS[model as keyof typeof AI_MODELS] || AI_MODELS["gpt-3.5-turbo"];
+
+    console.log(reqParams.max_tokens, modelInfo.prices.completion);
+    return (
+      (tokens * modelInfo.prices.prompt +
+        (reqParams.max_tokens || 100) * modelInfo.prices.completion) /
+      1000
+    );
+  }
+
+  async calcTokens(
+    messages: Message[],
+    reqParams: Partial<LLMConfig>
+  ): ReturnType<LLMProviderInterface["calcTokens"]> {
+    const model = reqParams.model;
+    const modelInfo =
+      AI_MODELS[model as keyof typeof AI_MODELS] || AI_MODELS["gpt-3.5-turbo"];
+
+    if (!modelInfo)
+      return {
+        tokens: 0,
+        maxTokens: 0,
+      };
+    const encoder = this.plugin.tokensScope.getEncoderFromEncoding(
+      modelInfo.encoding
+    );
+
+    let tokensPerMessage, tokensPerName;
+    if (model && ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"].includes(model)) {
+      tokensPerMessage = 4;
+      tokensPerName = -1;
+    } else if (model && ["gpt-4", "gpt-4-0314"].includes(model)) {
"gpt-4-0314"].includes(model)) { + tokensPerMessage = 3; + tokensPerName = 1; + } else { + tokensPerMessage = 3; + tokensPerName = 1; + } + + let numTokens = 0; + for (const message of messages) { + numTokens += tokensPerMessage; + for (const [key, value] of Object.entries(message)) { + numTokens += encoder.encode(value).length; + if (key === "name") { + numTokens += tokensPerName; + } + } + } + + numTokens += 3; // every reply is primed with assistant + + return { + tokens: numTokens, + maxTokens: modelInfo.maxTokens, + }; + } +}