5 changes: 5 additions & 0 deletions .changeset/crisp-chairs-read.md
@@ -0,0 +1,5 @@
---
'@posthog/ai': patch
---

Fixes cache creation cost for LangChain with Anthropic
26 changes: 22 additions & 4 deletions packages/ai/src/langchain/callbacks.ts
@@ -446,7 +446,10 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {

// Add additional token data to properties
if (additionalTokenData.cacheReadInputTokens) {
eventProperties['$ai_cache_read_tokens'] = additionalTokenData.cacheReadInputTokens
eventProperties['$ai_cache_read_input_tokens'] = additionalTokenData.cacheReadInputTokens
}
if (additionalTokenData.cacheWriteInputTokens) {
eventProperties['$ai_cache_creation_input_tokens'] = additionalTokenData.cacheWriteInputTokens
}
if (additionalTokenData.reasoningTokens) {
eventProperties['$ai_reasoning_tokens'] = additionalTokenData.reasoningTokens
@@ -623,6 +626,15 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
additionalTokenData.cacheReadInputTokens = usage.input_token_details.cache_read
} else if (usage.cachedPromptTokens != null) {
additionalTokenData.cacheReadInputTokens = usage.cachedPromptTokens
} else if (usage.cache_read_input_tokens != null) {
additionalTokenData.cacheReadInputTokens = usage.cache_read_input_tokens
}

// Check for cache write/creation tokens in various formats
if (usage.cache_creation_input_tokens != null) {
additionalTokenData.cacheWriteInputTokens = usage.cache_creation_input_tokens
} else if (usage.input_token_details?.cache_creation != null) {
additionalTokenData.cacheWriteInputTokens = usage.input_token_details.cache_creation
}

// Check for reasoning tokens in various formats
@@ -677,8 +689,10 @@ export class LangChainCallbackHandler extends BaseCallbackHandler {
additionalTokenData.webSearchCount = webSearchCount
}

// For Anthropic providers, LangChain reports input_tokens as the sum of input and cache read tokens.
// For Anthropic providers, LangChain reports input_tokens as the sum of all input tokens.
// Our cost calculation expects them to be separate for Anthropic, so we subtract cache tokens.
// Both cache_read and cache_write tokens should be subtracted since Anthropic's raw API
// reports input_tokens as tokens NOT read from or used to create a cache.
// For other providers (OpenAI, etc.), input_tokens already excludes cache tokens as expected.
// Match logic consistent with plugin-server: exact match on provider OR substring match on model
let isAnthropic = false
@@ -688,8 +702,12 @@
isAnthropic = true
}

if (isAnthropic && parsedUsage.input && additionalTokenData.cacheReadInputTokens) {
parsedUsage.input = Math.max(parsedUsage.input - additionalTokenData.cacheReadInputTokens, 0)
if (isAnthropic && parsedUsage.input) {
const cacheTokens =
(additionalTokenData.cacheReadInputTokens || 0) + (additionalTokenData.cacheWriteInputTokens || 0)
if (cacheTokens > 0) {
parsedUsage.input = Math.max(parsedUsage.input - cacheTokens, 0)
}
}

return [parsedUsage.input, parsedUsage.output, additionalTokenData]
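
To make the accounting in the callbacks.ts change concrete: for Anthropic, LangChain reports input_tokens as the sum of uncached input, cache-read, and cache-creation tokens, while the cost calculation expects only the uncached figure, so both cache counts are subtracted. The standalone TypeScript sketch below mirrors that subtraction; the UsageLike shape and normalizeAnthropicInput name are hypothetical and not part of the SDK.

// Illustrative sketch only - mirrors the Math.max subtraction above, not the actual handler internals
interface UsageLike {
  inputTokens: number // as reported by LangChain (for Anthropic, includes cache tokens)
  cacheReadInputTokens?: number // tokens read from an existing prompt cache
  cacheWriteInputTokens?: number // tokens written while creating a prompt cache
}

// Returns the uncached input token count that the cost calculation expects
function normalizeAnthropicInput(usage: UsageLike): number {
  const cacheTokens = (usage.cacheReadInputTokens || 0) + (usage.cacheWriteInputTokens || 0)
  // Clamp at zero, matching the guard in the handler
  return Math.max(usage.inputTokens - cacheTokens, 0)
}

// Worked example matching the combined test case below: 2000 - 800 - 500 = 700
console.log(normalizeAnthropicInput({ inputTokens: 2000, cacheReadInputTokens: 800, cacheWriteInputTokens: 500 }))
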
173 changes: 167 additions & 6 deletions packages/ai/tests/callbacks.test.ts
@@ -205,7 +205,7 @@ describe('LangChainCallbackHandler', () => {
// Input tokens should NOT be reduced for OpenAI: 150 (no subtraction)
expect(captureCall[0].properties['$ai_input_tokens']).toBe(150)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(40)
expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(100)
expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(100)
})

it('should not subtract for OpenAI even when cache_read_tokens >= input_tokens', async () => {
@@ -259,7 +259,7 @@
// Input tokens should NOT be reduced for OpenAI: 80 (no subtraction)
expect(captureCall[0].properties['$ai_input_tokens']).toBe(80)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(20)
expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(100)
expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(100)
})

it('should not subtract when there are no cache_read_tokens', async () => {
@@ -362,7 +362,7 @@
// Input tokens should remain 0 (no subtraction because input_tokens is falsy)
expect(captureCall[0].properties['$ai_input_tokens']).toBe(0)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(10)
expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(50)
expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(50)
})

it('should subtract cache_read_tokens from input_tokens for Anthropic provider', async () => {
@@ -416,7 +416,7 @@
// Input tokens should be reduced for Anthropic: 1200 - 800 = 400
expect(captureCall[0].properties['$ai_input_tokens']).toBe(400)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(800)
expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(800)
})

it('should subtract cache_read_tokens when model name contains "anthropic"', async () => {
@@ -469,7 +469,7 @@
// Should subtract because model name contains "anthropic": 500 - 200 = 300
expect(captureCall[0].properties['$ai_input_tokens']).toBe(300)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(30)
expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(200)
expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(200)
})

it('should prevent negative input_tokens for Anthropic when cache_read >= input', async () => {
@@ -522,6 +522,167 @@
// Should be max(100 - 150, 0) = 0
expect(captureCall[0].properties['$ai_input_tokens']).toBe(0)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(20)
expect(captureCall[0].properties['$ai_cache_read_tokens']).toBe(150)
expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(150)
})

it('should subtract cache_creation_input_tokens from input_tokens for Anthropic provider', async () => {
const serialized = {
lc: 1,
type: 'constructor' as const,
id: ['langchain', 'chat_models', 'anthropic', 'ChatAnthropic'],
kwargs: {},
}

const prompts = ['Test with Anthropic cache write']
const runId = 'run_anthropic_cache_write_test'
const metadata = { ls_model_name: 'claude-3-sonnet-20240229', ls_provider: 'anthropic' }
const extraParams = {
invocation_params: {
temperature: 0.7,
},
}

handler.handleLLMStart(serialized, prompts, runId, undefined, extraParams, undefined, metadata)

// For Anthropic, LangChain reports input_tokens as sum of input + cache_creation
// input_tokens=1000 includes 800 cache_creation tokens, so actual uncached input is 200
const llmResult = {
generations: [
[
{
text: 'Response from Anthropic with cache creation.',
message: new AIMessage('Response from Anthropic with cache creation.'),
},
],
],
llmOutput: {
tokenUsage: {
promptTokens: 1000, // Sum of actual uncached input (200) + cache creation (800)
completionTokens: 50,
totalTokens: 1050,
cache_creation_input_tokens: 800, // 800 tokens written to cache
},
},
}

handler.handleLLMEnd(llmResult, runId)

expect(mockPostHogClient.capture).toHaveBeenCalledTimes(1)
const [captureCall] = (mockPostHogClient.capture as jest.Mock).mock.calls

expect(captureCall[0].event).toBe('$ai_generation')
// Input tokens should be reduced for Anthropic: 1000 - 800 = 200
expect(captureCall[0].properties['$ai_input_tokens']).toBe(200)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
expect(captureCall[0].properties['$ai_cache_creation_input_tokens']).toBe(800)
})

it('should subtract both cache_read and cache_creation tokens for Anthropic', async () => {
const serialized = {
lc: 1,
type: 'constructor' as const,
id: ['langchain', 'chat_models', 'anthropic', 'ChatAnthropic'],
kwargs: {},
}

const prompts = ['Test with Anthropic cache read and write']
const runId = 'run_anthropic_cache_both_test'
const metadata = { ls_model_name: 'claude-3-sonnet-20240229', ls_provider: 'anthropic' }
const extraParams = {
invocation_params: {
temperature: 0.7,
},
}

handler.handleLLMStart(serialized, prompts, runId, undefined, extraParams, undefined, metadata)

// For Anthropic, LangChain reports input_tokens as sum of all tokens
// input_tokens=2000 includes 800 cache_read + 500 cache_creation, so uncached is 700
const llmResult = {
generations: [
[
{
text: 'Response from Anthropic with both cache operations.',
message: new AIMessage('Response from Anthropic with both cache operations.'),
},
],
],
llmOutput: {
tokenUsage: {
promptTokens: 2000, // Sum of uncached (700) + cache read (800) + cache creation (500)
completionTokens: 50,
totalTokens: 2050,
prompt_tokens_details: {
cached_tokens: 800, // 800 tokens read from cache
},
cache_creation_input_tokens: 500, // 500 tokens written to cache
},
},
}

handler.handleLLMEnd(llmResult, runId)

expect(mockPostHogClient.capture).toHaveBeenCalledTimes(1)
const [captureCall] = (mockPostHogClient.capture as jest.Mock).mock.calls

expect(captureCall[0].event).toBe('$ai_generation')
// Input tokens should be reduced for Anthropic: 2000 - 800 - 500 = 700
expect(captureCall[0].properties['$ai_input_tokens']).toBe(700)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
expect(captureCall[0].properties['$ai_cache_read_input_tokens']).toBe(800)
expect(captureCall[0].properties['$ai_cache_creation_input_tokens']).toBe(500)
})

it('should not subtract cache_creation_input_tokens for non-Anthropic providers', async () => {
const serialized = {
lc: 1,
type: 'constructor' as const,
id: ['langchain', 'chat_models', 'openai', 'ChatOpenAI'],
kwargs: {},
}

const prompts = ['Test with OpenAI cache write']
const runId = 'run_openai_cache_write_test'
const metadata = { ls_model_name: 'gpt-4', ls_provider: 'openai' }
const extraParams = {
invocation_params: {
temperature: 0.7,
},
}

handler.handleLLMStart(serialized, prompts, runId, undefined, extraParams, undefined, metadata)

// For OpenAI, input_tokens is already separate from cache tokens
const llmResult = {
generations: [
[
{
text: 'Response from OpenAI with cache creation.',
message: new AIMessage('Response from OpenAI with cache creation.'),
},
],
],
llmOutput: {
tokenUsage: {
promptTokens: 200, // Just the uncached tokens
completionTokens: 50,
totalTokens: 250,
input_token_details: {
cache_creation: 800, // OpenAI format for cache write
},
},
},
}

handler.handleLLMEnd(llmResult, runId)

expect(mockPostHogClient.capture).toHaveBeenCalledTimes(1)
const [captureCall] = (mockPostHogClient.capture as jest.Mock).mock.calls

expect(captureCall[0].event).toBe('$ai_generation')
// Input tokens should NOT be reduced for OpenAI
expect(captureCall[0].properties['$ai_input_tokens']).toBe(200)
expect(captureCall[0].properties['$ai_output_tokens']).toBe(50)
expect(captureCall[0].properties['$ai_cache_creation_input_tokens']).toBe(800)
})
})
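
To summarize what the new assertions check, here is an illustrative object (derived from the combined cache read-and-write Anthropic test above, not a real captured payload) showing the token properties attached to the $ai_generation event:

// Illustrative only: token properties for 2000 reported prompt tokens with both cache operations
const expectedTokenProperties = {
  '$ai_input_tokens': 700, // 2000 - 800 cache read - 500 cache creation
  '$ai_output_tokens': 50,
  '$ai_cache_read_input_tokens': 800, // renamed from $ai_cache_read_tokens in this change
  '$ai_cache_creation_input_tokens': 500, // newly captured for cache writes
}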