Skip to content

GH-1403: Implements Anthropic's prompt caching feature to improve tok… #4199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@
import org.springframework.ai.content.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
Expand Down Expand Up @@ -482,12 +482,25 @@ private Map<String, String> mergeHttpHeaders(Map<String, String> runtimeHttpHead

ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {

// Get cache control from options
AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions();
AnthropicApi.ChatCompletionRequest.CacheControl cacheControl = (requestOptions != null)
? requestOptions.getCacheControl() : null;

List<AnthropicMessage> userMessages = prompt.getInstructions()
.stream()
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
.map(message -> {
if (message.getMessageType() == MessageType.USER) {
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getText())));
List<ContentBlock> contents = new ArrayList<>();

// Apply cache control if enabled for user messages
if (cacheControl != null) {
contents.add(new ContentBlock(message.getText(), cacheControl));
}
else {
contents.add(new ContentBlock(message.getText()));
}
if (message instanceof UserMessage userMessage) {
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {
List<ContentBlock> mediaContent = userMessage.getMedia().stream().map(media -> {
Expand Down Expand Up @@ -537,7 +550,6 @@ else if (message.getMessageType() == MessageType.TOOL) {
ChatCompletionRequest request = new ChatCompletionRequest(this.defaultOptions.getModel(), userMessages,
systemPrompt, this.defaultOptions.getMaxTokens(), this.defaultOptions.getTemperature(), stream);

AnthropicChatOptions requestOptions = (AnthropicChatOptions) prompt.getOptions();
request = ModelOptionsUtils.merge(requestOptions, request, ChatCompletionRequest.class);

// Add the tool definitions to the request's tools parameter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
* @author Thomas Vitale
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
* @author Soby Chacko
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
Expand All @@ -59,6 +60,20 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
private @JsonProperty("top_k") Integer topK;
private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking;

/**
* Cache control for user messages. When set, enables caching for user messages.
* Uses the existing CacheControl record from AnthropicApi.ChatCompletionRequest.
*/
private @JsonProperty("cache_control") ChatCompletionRequest.CacheControl cacheControl;

/**
 * Returns the cache control configured for user messages.
 * @return the configured cache control, or {@code null} when none has been set
 */
public ChatCompletionRequest.CacheControl getCacheControl() {
	return this.cacheControl;
}

/**
* Sets the cache control to apply to user messages.
* @param cacheControl the cache control to store on these options; may be
* {@code null}, in which case no cache control is configured
*/
public void setCacheControl(ChatCompletionRequest.CacheControl cacheControl) {
this.cacheControl = cacheControl;
}

/**
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
* completion requests.
Expand Down Expand Up @@ -111,6 +126,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.internalToolExecutionEnabled(fromOptions.getInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.httpHeaders(fromOptions.getHttpHeaders() != null ? new HashMap<>(fromOptions.getHttpHeaders()) : null)
.cacheControl(fromOptions.getCacheControl())
.build();
}

Expand Down Expand Up @@ -267,12 +283,10 @@ public AnthropicChatOptions copy() {

@Override
public boolean equals(Object o) {
if (this == o) {
if (this == o)
return true;
}
if (!(o instanceof AnthropicChatOptions that)) {
if (!(o instanceof AnthropicChatOptions that))
return false;
}
return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens)
&& Objects.equals(this.metadata, that.metadata)
&& Objects.equals(this.stopSequences, that.stopSequences)
Expand All @@ -282,14 +296,15 @@ public boolean equals(Object o) {
&& Objects.equals(this.toolNames, that.toolNames)
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
&& Objects.equals(this.toolContext, that.toolContext)
&& Objects.equals(this.httpHeaders, that.httpHeaders);
&& Objects.equals(this.httpHeaders, that.httpHeaders)
&& Objects.equals(this.cacheControl, that.cacheControl);
}

@Override
public int hashCode() {
return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP,
this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled,
this.toolContext, this.httpHeaders);
this.toolContext, this.httpHeaders, this.cacheControl);
}

public static class Builder {
Expand Down Expand Up @@ -389,6 +404,14 @@ public Builder httpHeaders(Map<String, String> httpHeaders) {
return this;
}

/**
 * Configures the cache control that is applied to user messages.
 * @param cacheControl the cache control to store on the options being built
 * @return this builder, to allow call chaining
 */
public Builder cacheControl(ChatCompletionRequest.CacheControl cacheControl) {
	this.options.setCacheControl(cacheControl);
	return this;
}

/**
* Finishes the build and hands back the configured options.
* @return the options instance held by this builder (the same object the builder
* methods have been mutating; no copy is made here)
*/
public AnthropicChatOptions build() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;
import org.springframework.ai.anthropic.api.StreamHelper.ChatCompletionResponseBuilder;
import org.springframework.ai.model.ApiKey;
import org.springframework.ai.model.ChatModelDescription;
Expand Down Expand Up @@ -66,6 +67,7 @@
* @author Jonghoon Park
* @author Claudio Silva Junior
* @author Filip Hrisafov
* @author Soby Chacko
* @since 1.0.0
*/
public final class AnthropicApi {
Expand Down Expand Up @@ -559,6 +561,14 @@ public record Metadata(@JsonProperty("user_id") String userId) {

}

/**
* Cache control marker that can be attached to request content; content blocks
* carry it under the {@code cache_control} JSON property.
*
* @param type the cache type supported by Anthropic, for example
* {@code "ephemeral"}. <a href=
* "https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#cache-limitations">Doc</a>
*/
@JsonInclude(Include.NON_NULL)
public record CacheControl(String type) {
}

/**
* Configuration for the model's thinking mode.
*
Expand Down Expand Up @@ -765,8 +775,11 @@ public record ContentBlock(
@JsonProperty("thinking") String thinking,

// Redacted Thinking only
@JsonProperty("data") String data
) {
@JsonProperty("data") String data,

// cache object
@JsonProperty("cache_control") CacheControl cacheControl
) {
// @formatter:on

/**
Expand All @@ -784,23 +797,27 @@ public ContentBlock(String mediaType, String data) {
* @param source The source of the content.
*/
public ContentBlock(Type type, Source source) {
this(type, source, null, null, null, null, null, null, null, null, null, null);
this(type, source, null, null, null, null, null, null, null, null, null, null, null);
}

/**
* Create content block
* @param source The source of the content.
*/
public ContentBlock(Source source) {
this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null);
this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null, null);
}

/**
* Create content block
* @param text The text of the content.
*/
public ContentBlock(String text) {
this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null);
this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, null);
}

/**
* Create a text content block that carries a cache control marker.
* @param text The text of the content.
* @param cache The cache control to attach to this block; it is passed as the
* last record component, which is serialized as {@code cache_control}.
*/
public ContentBlock(String text, CacheControl cache) {
this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null, cache);
}

// Tool result
Expand All @@ -811,7 +828,7 @@ public ContentBlock(String text) {
* @param content The content of the tool result.
*/
public ContentBlock(Type type, String toolUseId, String content) {
this(type, null, null, null, null, null, null, toolUseId, content, null, null, null);
this(type, null, null, null, null, null, null, toolUseId, content, null, null, null, null);
}

/**
Expand All @@ -822,7 +839,7 @@ public ContentBlock(Type type, String toolUseId, String content) {
* @param index The index of the content block.
*/
public ContentBlock(Type type, Source source, String text, Integer index) {
this(type, source, text, index, null, null, null, null, null, null, null, null);
this(type, source, text, index, null, null, null, null, null, null, null, null, null);
}

// Tool use input JSON delta streaming
Expand All @@ -834,7 +851,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) {
* @param input The input of the tool use.
*/
public ContentBlock(Type type, String id, String name, Map<String, Object> input) {
this(type, null, null, null, id, name, input, null, null, null, null, null);
this(type, null, null, null, id, name, input, null, null, null, null, null, null);
}

/**
Expand Down Expand Up @@ -1028,7 +1045,9 @@ public record ChatCompletionResponse(
public record Usage(
// @formatter:off
@JsonProperty("input_tokens") Integer inputTokens,
@JsonProperty("output_tokens") Integer outputTokens) {
@JsonProperty("output_tokens") Integer outputTokens,
@JsonProperty("cache_creation_input_tokens") Integer cacheCreationInputTokens,
@JsonProperty("cache_read_input_tokens") Integer cacheReadInputTokens) {
// @formatter:on
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.anthropic.api;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.CacheControl;

import java.util.function.Supplier;

/**
* Cache types supported by Anthropic's prompt caching feature.
*
* <p>
* Prompt caching allows reusing frequently used prompts to reduce costs and improve
* response times for repeated interactions.
*
* @see <a href=
* "https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching">Anthropic Prompt
* Caching</a>
* @author Claudio Silva Junior
* @author Soby Chacko
*/
public enum AnthropicCacheType {

	/**
	 * Ephemeral cache entry: 5-minute lifetime that is refreshed each time the entry
	 * is used.
	 */
	EPHEMERAL(() -> new CacheControl("ephemeral"));

	/** Produces the {@link CacheControl} payload for this cache type on demand. */
	private final Supplier<CacheControl> factory;

	AnthropicCacheType(Supplier<CacheControl> factory) {
		this.factory = factory;
	}

	/**
	 * Builds the cache control payload for this cache type. Every call asks the
	 * underlying supplier for an instance, so callers receive a newly created
	 * {@link CacheControl}.
	 * @return a {@link CacheControl} configured for this cache type
	 */
	public CacheControl cacheControl() {
		return this.factory.get();
	}

}
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
* @author Christian Tzolov
* @author Jihoon Kim
* @author Alexandros Pappas
* @author Claudio Silva Junior
* @author Soby Chacko
* @since 1.0.0
*/
public class StreamHelper {
Expand Down Expand Up @@ -159,7 +161,7 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_START)) {
}
else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) {
ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null,
null, null, null, null, null, thinkingBlock.thinking(), null);
null, null, null, null, null, thinkingBlock.thinking(), null, null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else {
Expand All @@ -176,12 +178,12 @@ else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) {
}
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) {
ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(),
null, null, null, null, null, null, thinking.thinking(), null);
null, null, null, null, null, null, thinking.thinking(), null, null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) {
ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(),
null, null, null, null, null, sig.signature(), null, null);
null, null, null, null, null, sig.signature(), null, null, null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else {
Expand All @@ -205,7 +207,9 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {

if (messageDeltaEvent.usage() != null) {
Usage totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
messageDeltaEvent.usage().outputTokens());
messageDeltaEvent.usage().outputTokens(),
contentBlockReference.get().usage.cacheCreationInputTokens(),
contentBlockReference.get().usage.cacheReadInputTokens());
contentBlockReference.get().withUsage(totalUsage);
}
}
Expand Down
Loading