diff --git a/.changeset/crisp-chairs-read.md b/.changeset/crisp-chairs-read.md
new file mode 100644
index 0000000000..45783c41b1
--- /dev/null
+++ b/.changeset/crisp-chairs-read.md
@@ -0,0 +1,5 @@
+---
+'@posthog/ai': patch
+---
+
+Fixes cache creation cost for Langchain with Anthropic
diff --git a/packages/ai/src/langchain/callbacks.ts b/packages/ai/src/langchain/callbacks.ts
index 1910031ac5..42710648cc 100644
--- a/packages/ai/src/langchain/callbacks.ts
+++ b/packages/ai/src/langchain/callbacks.ts
@@ -446,7 +446,10 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
 
         // Add additional token data to properties
         if (additionalTokenData.cacheReadInputTokens) {
-            eventProperties['$ai_cache_read_tokens'] = additionalTokenData.cacheReadInputTokens
+            eventProperties['$ai_cache_read_input_tokens'] = additionalTokenData.cacheReadInputTokens
+        }
+        if (additionalTokenData.cacheWriteInputTokens) {
+            eventProperties['$ai_cache_creation_input_tokens'] = additionalTokenData.cacheWriteInputTokens
         }
         if (additionalTokenData.reasoningTokens) {
             eventProperties['$ai_reasoning_tokens'] = additionalTokenData.reasoningTokens
@@ -623,6 +626,15 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
             additionalTokenData.cacheReadInputTokens = usage.input_token_details.cache_read
         } else if (usage.cachedPromptTokens != null) {
             additionalTokenData.cacheReadInputTokens = usage.cachedPromptTokens
+        } else if (usage.cache_read_input_tokens != null) {
+            additionalTokenData.cacheReadInputTokens = usage.cache_read_input_tokens
+        }
+
+        // Check for cache write/creation tokens in various formats
+        if (usage.cache_creation_input_tokens != null) {
+            additionalTokenData.cacheWriteInputTokens = usage.cache_creation_input_tokens
+        } else if (usage.input_token_details?.cache_creation != null) {
+            additionalTokenData.cacheWriteInputTokens = usage.input_token_details.cache_creation
         }
 
         // Check for reasoning tokens in various formats
@@ -677,8 +689,10 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
             additionalTokenData.webSearchCount = webSearchCount
         }
 
-        // For Anthropic providers, LangChain reports input_tokens as the sum of input and cache read tokens.
+        // For Anthropic providers, LangChain reports input_tokens as the sum of all input tokens.
         // Our cost calculation expects them to be separate for Anthropic, so we subtract cache tokens.
+        // Both cache_read and cache_write tokens should be subtracted since Anthropic's raw API
+        // reports input_tokens as tokens NOT read from or used to create a cache.
         // For other providers (OpenAI, etc.), input_tokens already excludes cache tokens as expected.
         // Match logic consistent with plugin-server: exact match on provider OR substring match on model
         let isAnthropic = false
@@ -688,8 +702,12 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
             isAnthropic = true
         }
 
-        if (isAnthropic && parsedUsage.input && additionalTokenData.cacheReadInputTokens) {
-            parsedUsage.input = Math.max(parsedUsage.input - additionalTokenData.cacheReadInputTokens, 0)
+        if (isAnthropic && parsedUsage.input) {
+            const cacheTokens =
+                (additionalTokenData.cacheReadInputTokens || 0) + (additionalTokenData.cacheWriteInputTokens || 0)
+            if (cacheTokens > 0) {
+                parsedUsage.input = Math.max(parsedUsage.input - cacheTokens, 0)
+            }
         }
 
         return [parsedUsage.input, parsedUsage.output, additionalTokenData]
diff --git a/packages/ai/tests/callbacks.test.ts b/packages/ai/tests/callbacks.test.ts
index 5210e9f434..a30d691d38 100644
--- a/packages/ai/tests/callbacks.test.ts
+++ b/packages/ai/tests/callbacks.test.ts
@@ -205,7 +205,7 @@ describe('LangChainCallbackHandler', () => {
         // Input tokens should NOT be reduced for OpenAI: 150 (no subtraction)
         expect(captureCall[0].properties['$ai_input_tokens']).toBe(150)
         expect(captureCall[0].properties['$ai_output_tokens']).toBe(40)
-        expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(100)
+        expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(100)
     })
 
     it('should not subtract for OpenAI even when cache_read_tokens >= input_tokens', async () => {
@@ -259,7 +259,7 @@ describe('LangChainCallbackHandler', () => {
         // Input tokens should NOT be reduced for OpenAI: 80 (no subtraction)
         expect(captureCall[0].properties['$ai_input_tokens']).toBe(80)
         expect(captureCall[0].properties['$ai_output_tokens']).toBe(20)
-        expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(100)
+        expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(100)
     })
 
     it('should not subtract when there are no cache_read_tokens', async () => {
@@ -362,7 +362,7 @@ describe('LangChainCallbackHandler', () => {
         // Input tokens should remain 0 (no subtraction because input_tokens is falsy)
         expect(captureCall[0].properties['$ai_input_tokens']).toBe(0)
         expect(captureCall[0].properties['$ai_output_tokens']).toBe(10)
-        expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(50)
+        expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(50)
     })
 
     it('should subtract cache_read_tokens from input_tokens for Anthropic provider', async () => {
@@ -416,7 +416,7 @@ describe('LangChainCallbackHandler', () => {
         // Input tokens should be reduced for Anthropic: 1200 - 800 = 400
         expect(captureCall[0].properties['$ai_input_tokens']).toBe(400)
         expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
-        expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(800)
+        expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(800)
     })
 
     it('should subtract cache_read_tokens when model name contains "anthropic"', async () => {
@@ -469,7 +469,7 @@ describe('LangChainCallbackHandler', () => {
         // Should subtract because model name contains "anthropic": 500 - 200 = 300
         expect(captureCall[0].properties['$ai_input_tokens']).toBe(300)
         expect(captureCall[0].properties['$ai_output_tokens']).toBe(30)
-        expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(200)
+        expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(200)
     })
 
     it('should prevent negative input_tokens for Anthropic when cache_read >= input', async () => {
@@ -522,6 +522,167 @@ describe('LangChainCallbackHandler', () => {
         // Should be max(100 - 150, 0) = 0
         expect(captureCall[0].properties['$ai_input_tokens']).toBe(0)
         expect(captureCall[0].properties['$ai_output_tokens']).toBe(20)
-        expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(150)
+        expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(150)
+    })
+
+    it('should subtract cache_creation_input_tokens from input_tokens for Anthropic provider', async () => {
+        const serialized = {
+            lc: 1,
+            type: 'constructor' as const,
+            id: ['langchain', 'chat_models', 'anthropic', 'ChatAnthropic'],
+            kwargs: {},
+        }
+
+        const prompts = ['Test with Anthropic cache write']
+        const runId = 'run_anthropic_cache_write_test'
+        const metadata = { ls_model_name: 'claude-3-sonnet-20240229', ls_provider: 'anthropic' }
+        const extraParams = {
+            invocation_params: {
+                temperature: 0.7,
+            },
+        }
+
+        handler.handleLLMStart(serialized, prompts, runId, undefined, extraParams, undefined, metadata)
+
+        // For Anthropic, LangChain reports input_tokens as sum of input + cache_creation
+        // input_tokens=1000 includes 800 cache_creation tokens, so actual uncached input is 200
+        const llmResult = {
+            generations: [
+                [
+                    {
+                        text: 'Response from Anthropic with cache creation.',
+                        message: new AIMessage('Response from Anthropic with cache creation.'),
+                    },
+                ],
+            ],
+            llmOutput: {
+                tokenUsage: {
+                    promptTokens: 1000, // Sum of actual uncached input (200) + cache creation (800)
+                    completionTokens: 50,
+                    totalTokens: 1050,
+                    cache_creation_input_tokens: 800, // 800 tokens written to cache
+                },
+            },
+        }
+
+        handler.handleLLMEnd(llmResult, runId)
+
+        expect(mockPostHogClient.capture).toHaveBeenCalledTimes(1)
+        const [captureCall] = (mockPostHogClient.capture as jest.Mock).mock.calls
+
+        expect(captureCall[0].event).toBe('$ai_generation')
+        // Input tokens should be reduced for Anthropic: 1000 - 800 = 200
+        expect(captureCall[0].properties['$ai_input_tokens']).toBe(200)
+        expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
+        expect(captureCall[0].properties['$ai_cache_creation_input_tokens']).toBe(800)
+    })
+
+    it('should subtract both cache_read and cache_creation tokens for Anthropic', async () => {
+        const serialized = {
+            lc: 1,
+            type: 'constructor' as const,
+            id: ['langchain', 'chat_models', 'anthropic', 'ChatAnthropic'],
+            kwargs: {},
+        }
+
+        const prompts = ['Test with Anthropic cache read and write']
+        const runId = 'run_anthropic_cache_both_test'
+        const metadata = { ls_model_name: 'claude-3-sonnet-20240229', ls_provider: 'anthropic' }
+        const extraParams = {
+            invocation_params: {
+                temperature: 0.7,
+            },
+        }
+
+        handler.handleLLMStart(serialized, prompts, runId, undefined, extraParams, undefined, metadata)
+
+        // For Anthropic, LangChain reports input_tokens as sum of all tokens
+        // input_tokens=2000 includes 800 cache_read + 500 cache_creation, so uncached is 700
+        const llmResult = {
+            generations: [
+                [
+                    {
+                        text: 'Response from Anthropic with both cache operations.',
+                        message: new AIMessage('Response from Anthropic with both cache operations.'),
+                    },
+                ],
+            ],
+            llmOutput: {
+                tokenUsage: {
+                    promptTokens: 2000, // Sum of uncached (700) + cache read (800) + cache creation (500)
+                    completionTokens: 50,
+                    totalTokens: 2050,
+                    prompt_tokens_details: {
+                        cached_tokens: 800, // 800 tokens read from cache
+                    },
+                    cache_creation_input_tokens: 500, // 500 tokens written to cache
+                },
+            },
+        }
+
+        handler.handleLLMEnd(llmResult, runId)
+
+        expect(mockPostHogClient.capture).toHaveBeenCalledTimes(1)
+        const [captureCall] = (mockPostHogClient.capture as jest.Mock).mock.calls
+
+        expect(captureCall[0].event).toBe('$ai_generation')
+        // Input tokens should be reduced for Anthropic: 2000 - 800 - 500 = 700
+        expect(captureCall[0].properties['$ai_input_tokens']).toBe(700)
+        expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
+        expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(800)
+        expect(captureCall[0].properties['$ai_cache_creation_input_tokens']).toBe(500)
+    })
+
+    it('should not subtract cache_creation_input_tokens for non-Anthropic providers', async () => {
+        const serialized = {
+            lc: 1,
+            type: 'constructor' as const,
+            id: ['langchain', 'chat_models', 'openai', 'ChatOpenAI'],
+            kwargs: {},
+        }
+
+        const prompts = ['Test with OpenAI cache write']
+        const runId = 'run_openai_cache_write_test'
+        const metadata = { ls_model_name: 'gpt-4', ls_provider: 'openai' }
+        const extraParams = {
+            invocation_params: {
+                temperature: 0.7,
+            },
+        }
+
+        handler.handleLLMStart(serialized, prompts, runId, undefined, extraParams, undefined, metadata)
+
+        // For OpenAI, input_tokens is already separate from cache tokens
+        const llmResult = {
+            generations: [
+                [
+                    {
+                        text: 'Response from OpenAI with cache creation.',
+                        message: new AIMessage('Response from OpenAI with cache creation.'),
+                    },
+                ],
+            ],
+            llmOutput: {
+                tokenUsage: {
+                    promptTokens: 200, // Just the uncached tokens
+                    completionTokens: 50,
+                    totalTokens: 250,
+                    input_token_details: {
+                        cache_creation: 800, // OpenAI format for cache write
+                    },
+                },
+            },
+        }
+
+        handler.handleLLMEnd(llmResult, runId)
+
+        expect(mockPostHogClient.capture).toHaveBeenCalledTimes(1)
+        const [captureCall] = (mockPostHogClient.capture as jest.Mock).mock.calls
+
+        expect(captureCall[0].event).toBe('$ai_generation')
+        // Input tokens should NOT be reduced for OpenAI
+        expect(captureCall[0].properties['$ai_input_tokens']).toBe(200)
+        expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
+        expect(captureCall[0].properties['$ai_cache_creation_input_tokens']).toBe(800)
     })
 })
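
Note (not part of the patch): the Anthropic normalization applied in callbacks.ts above boils down to the standalone TypeScript sketch below. The helper name, the UsageTotals shape, and the isAnthropic flag are illustrative assumptions for this sketch, not APIs exported by @posthog/ai.

// Illustrative sketch of the input-token normalization: for Anthropic, LangChain's reported
// input total includes cache read and cache creation tokens, so both are subtracted before
// cost calculation; other providers already report uncached input separately.
interface UsageTotals {
    inputTokens: number
    cacheReadInputTokens?: number
    cacheWriteInputTokens?: number
}

function normalizeAnthropicInput(usage: UsageTotals, isAnthropic: boolean): number {
    // Non-Anthropic providers (e.g. OpenAI) and falsy input counts pass through unchanged.
    if (!isAnthropic || !usage.inputTokens) {
        return usage.inputTokens
    }
    // Subtract both cache read and cache creation tokens, clamping at zero.
    const cacheTokens = (usage.cacheReadInputTokens ?? 0) + (usage.cacheWriteInputTokens ?? 0)
    return Math.max(usage.inputTokens - cacheTokens, 0)
}

// Mirrors the combined test case: 2000 reported = 700 uncached + 800 read + 500 written.
console.log(normalizeAnthropicInput({ inputTokens: 2000, cacheReadInputTokens: 800, cacheWriteInputTokens: 500 }, true)) // 700

The clamp at zero matches the patched code's Math.max guard, which prevents negative input counts when a provider reports more cache tokens than the summed input total.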