chore: proper caching for Anthropic models with OpenRouter

Kujtim Hoxha created

Change summary

internal/llm/provider/openai.go | 45 ++++++++++++++++++++++++++++++++--
1 file changed, 42 insertions(+), 3 deletions(-)

Detailed changes

internal/llm/provider/openai.go 🔗

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"strings"
 	"time"
 
 	"github.com/charmbracelet/catwalk/pkg/catwalk"
@@ -56,14 +57,33 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
 }
 
 func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
+	isAnthropicModel := o.providerOptions.config.ID == "openrouter" && strings.HasPrefix(o.Model().ID, "anthropic/")
 	// Add system message first
 	systemMessage := o.providerOptions.systemMessage
 	if o.providerOptions.systemPromptPrefix != "" {
 		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
 	}
-	openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
 
-	for _, msg := range messages {
+	systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
+	if isAnthropicModel && !o.providerOptions.disableCache {
+		systemTextBlock.SetExtraFields(
+			map[string]any{
+				"cache_control": map[string]string{
+					"type": "ephemeral",
+				},
+			},
+		)
+	}
+	var content []openai.ChatCompletionContentPartTextParam
+	content = append(content, systemTextBlock)
+	system := openai.SystemMessage(content)
+	openaiMessages = append(openaiMessages, system)
+
+	for i, msg := range messages {
+		cache := false
+		if i > len(messages)-3 {
+			cache = true
+		}
 		switch msg.Role {
 		case message.User:
 			var content []openai.ChatCompletionContentPartUnionParam
@@ -75,6 +95,13 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 
 				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 			}
+			if cache && !o.providerOptions.disableCache && isAnthropicModel {
+				textBlock.SetExtraFields(map[string]any{
+					"cache_control": map[string]string{
+						"type": "ephemeral",
+					},
+				})
+			}
 
 			openaiMessages = append(openaiMessages, openai.UserMessage(content))
 
@@ -86,8 +113,20 @@ func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessag
 			hasContent := false
 			if msg.Content().String() != "" {
 				hasContent = true
+				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
+				if cache && !o.providerOptions.disableCache && isAnthropicModel {
+					textBlock.SetExtraFields(map[string]any{
+						"cache_control": map[string]string{
+							"type": "ephemeral",
+						},
+					})
+				}
 				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
-					OfString: openai.String(msg.Content().String()),
+					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
+						{
+							OfText: &textBlock,
+						},
+					},
 				}
 			}