anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"regexp"
 11	"strconv"
 12	"strings"
 13	"time"
 14
 15	"github.com/anthropics/anthropic-sdk-go"
 16	"github.com/anthropics/anthropic-sdk-go/bedrock"
 17	"github.com/anthropics/anthropic-sdk-go/option"
 18	"github.com/charmbracelet/catwalk/pkg/catwalk"
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/llm/tools"
 21	"github.com/charmbracelet/crush/internal/message"
 22)
 23
 24type anthropicClient struct {
 25	providerOptions   providerClientOptions
 26	useBedrock        bool
 27	client            anthropic.Client
 28	adjustedMaxTokens int // Used when context limit is hit
 29}
 30
 31type AnthropicClient ProviderClient
 32
 33func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 34	return &anthropicClient{
 35		providerOptions: opts,
 36		client:          createAnthropicClient(opts, useBedrock),
 37	}
 38}
 39
 40func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
 41	anthropicClientOptions := []option.RequestOption{}
 42	if opts.apiKey != "" {
 43		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 44	}
 45	if useBedrock {
 46		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 47	}
 48	for _, header := range opts.extraHeaders {
 49		anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(header, opts.extraHeaders[header]))
 50	}
 51	for key, value := range opts.extraBody {
 52		anthropicClientOptions = append(anthropicClientOptions, option.WithJSONSet(key, value))
 53	}
 54	return anthropic.NewClient(anthropicClientOptions...)
 55}
 56
 57func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
 58	for i, msg := range messages {
 59		cache := false
 60		if i > len(messages)-3 {
 61			cache = true
 62		}
 63		switch msg.Role {
 64		case message.User:
 65			content := anthropic.NewTextBlock(msg.Content().String())
 66			if cache && !a.providerOptions.disableCache {
 67				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 68					Type: "ephemeral",
 69				}
 70			}
 71			var contentBlocks []anthropic.ContentBlockParamUnion
 72			contentBlocks = append(contentBlocks, content)
 73			for _, binaryContent := range msg.BinaryContent() {
 74				base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic)
 75				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
 76				contentBlocks = append(contentBlocks, imageBlock)
 77			}
 78			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
 79
 80		case message.Assistant:
 81			blocks := []anthropic.ContentBlockParamUnion{}
 82
 83			// Add thinking blocks first if present (required when thinking is enabled with tool use)
 84			if reasoningContent := msg.ReasoningContent(); reasoningContent.Thinking != "" {
 85				thinkingBlock := anthropic.NewThinkingBlock(reasoningContent.Signature, reasoningContent.Thinking)
 86				blocks = append(blocks, thinkingBlock)
 87			}
 88
 89			if msg.Content().String() != "" {
 90				content := anthropic.NewTextBlock(msg.Content().String())
 91				if cache && !a.providerOptions.disableCache {
 92					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 93						Type: "ephemeral",
 94					}
 95				}
 96				blocks = append(blocks, content)
 97			}
 98
 99			for _, toolCall := range msg.ToolCalls() {
100				var inputMap map[string]any
101				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
102				if err != nil {
103					continue
104				}
105				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
106			}
107
108			if len(blocks) == 0 {
109				slog.Warn("There is a message without content, investigate, this should not happen")
110				continue
111			}
112			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
113
114		case message.Tool:
115			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
116			for i, toolResult := range msg.ToolResults() {
117				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
118			}
119			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
120		}
121	}
122	return
123}
124
125func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
126	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
127
128	for i, tool := range tools {
129		info := tool.Info()
130		toolParam := anthropic.ToolParam{
131			Name:        info.Name,
132			Description: anthropic.String(info.Description),
133			InputSchema: anthropic.ToolInputSchemaParam{
134				Properties: info.Parameters,
135				// TODO: figure out how we can tell claude the required fields?
136			},
137		}
138
139		if i == len(tools)-1 && !a.providerOptions.disableCache {
140			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
141				Type: "ephemeral",
142			}
143		}
144
145		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
146	}
147
148	return anthropicTools
149}
150
151func (a *anthropicClient) finishReason(reason string) message.FinishReason {
152	switch reason {
153	case "end_turn":
154		return message.FinishReasonEndTurn
155	case "max_tokens":
156		return message.FinishReasonMaxTokens
157	case "tool_use":
158		return message.FinishReasonToolUse
159	case "stop_sequence":
160		return message.FinishReasonEndTurn
161	default:
162		return message.FinishReasonUnknown
163	}
164}
165
166func (a *anthropicClient) isThinkingEnabled() bool {
167	cfg := config.Get()
168	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
169	if a.providerOptions.modelType == config.SelectedModelTypeSmall {
170		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
171	}
172	return a.Model().CanReason && modelConfig.Think
173}
174
175func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
176	model := a.providerOptions.model(a.providerOptions.modelType)
177	var thinkingParam anthropic.ThinkingConfigParamUnion
178	cfg := config.Get()
179	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
180	if a.providerOptions.modelType == config.SelectedModelTypeSmall {
181		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
182	}
183	temperature := anthropic.Float(0)
184
185	maxTokens := model.DefaultMaxTokens
186	if modelConfig.MaxTokens > 0 {
187		maxTokens = modelConfig.MaxTokens
188	}
189	if a.isThinkingEnabled() {
190		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8))
191		temperature = anthropic.Float(1)
192	}
193	// Override max tokens if set in provider options
194	if a.providerOptions.maxTokens > 0 {
195		maxTokens = a.providerOptions.maxTokens
196	}
197
198	// Use adjusted max tokens if context limit was hit
199	if a.adjustedMaxTokens > 0 {
200		maxTokens = int64(a.adjustedMaxTokens)
201	}
202
203	return anthropic.MessageNewParams{
204		Model:       anthropic.Model(model.ID),
205		MaxTokens:   maxTokens,
206		Temperature: temperature,
207		Messages:    messages,
208		Tools:       tools,
209		Thinking:    thinkingParam,
210		System: []anthropic.TextBlockParam{
211			{
212				Text: a.providerOptions.systemMessage,
213				CacheControl: anthropic.CacheControlEphemeralParam{
214					Type: "ephemeral",
215				},
216			},
217		},
218	}
219}
220
221func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
222	cfg := config.Get()
223
224	attempts := 0
225	for {
226		attempts++
227		// Prepare messages on each attempt in case max_tokens was adjusted
228		preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
229		if cfg.Options.Debug {
230			jsonData, _ := json.Marshal(preparedMessages)
231			slog.Debug("Prepared messages", "messages", string(jsonData))
232		}
233
234		var opts []option.RequestOption
235		if a.isThinkingEnabled() {
236			opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
237		}
238		anthropicResponse, err := a.client.Messages.New(
239			ctx,
240			preparedMessages,
241			opts...,
242		)
243		// If there is an error we are going to see if we can retry the call
244		if err != nil {
245			slog.Error("Error in Anthropic API call", "error", err)
246			retry, after, retryErr := a.shouldRetry(attempts, err)
247			if retryErr != nil {
248				return nil, retryErr
249			}
250			if retry {
251				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
252				select {
253				case <-ctx.Done():
254					return nil, ctx.Err()
255				case <-time.After(time.Duration(after) * time.Millisecond):
256					continue
257				}
258			}
259			return nil, retryErr
260		}
261
262		content := ""
263		for _, block := range anthropicResponse.Content {
264			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
265				content += text.Text
266			}
267		}
268
269		return &ProviderResponse{
270			Content:   content,
271			ToolCalls: a.toolCalls(*anthropicResponse),
272			Usage:     a.usage(*anthropicResponse),
273		}, nil
274	}
275}
276
277func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
278	cfg := config.Get()
279	attempts := 0
280	eventChan := make(chan ProviderEvent)
281	go func() {
282		for {
283			attempts++
284			// Prepare messages on each attempt in case max_tokens was adjusted
285			preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
286			if cfg.Options.Debug {
287				jsonData, _ := json.Marshal(preparedMessages)
288				slog.Debug("Prepared messages", "messages", string(jsonData))
289			}
290
291			var opts []option.RequestOption
292			if a.isThinkingEnabled() {
293				opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
294			}
295
296			anthropicStream := a.client.Messages.NewStreaming(
297				ctx,
298				preparedMessages,
299				opts...,
300			)
301			accumulatedMessage := anthropic.Message{}
302
303			currentToolCallID := ""
304			for anthropicStream.Next() {
305				event := anthropicStream.Current()
306				err := accumulatedMessage.Accumulate(event)
307				if err != nil {
308					slog.Warn("Error accumulating message", "error", err)
309					continue
310				}
311
312				switch event := event.AsAny().(type) {
313				case anthropic.ContentBlockStartEvent:
314					switch event.ContentBlock.Type {
315					case "text":
316						eventChan <- ProviderEvent{Type: EventContentStart}
317					case "tool_use":
318						currentToolCallID = event.ContentBlock.ID
319						eventChan <- ProviderEvent{
320							Type: EventToolUseStart,
321							ToolCall: &message.ToolCall{
322								ID:       event.ContentBlock.ID,
323								Name:     event.ContentBlock.Name,
324								Finished: false,
325							},
326						}
327					}
328
329				case anthropic.ContentBlockDeltaEvent:
330					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
331						eventChan <- ProviderEvent{
332							Type:     EventThinkingDelta,
333							Thinking: event.Delta.Thinking,
334						}
335					} else if event.Delta.Type == "signature_delta" && event.Delta.Signature != "" {
336						eventChan <- ProviderEvent{
337							Type:      EventSignatureDelta,
338							Signature: event.Delta.Signature,
339						}
340					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
341						eventChan <- ProviderEvent{
342							Type:    EventContentDelta,
343							Content: event.Delta.Text,
344						}
345					} else if event.Delta.Type == "input_json_delta" {
346						if currentToolCallID != "" {
347							eventChan <- ProviderEvent{
348								Type: EventToolUseDelta,
349								ToolCall: &message.ToolCall{
350									ID:       currentToolCallID,
351									Finished: false,
352									Input:    event.Delta.PartialJSON,
353								},
354							}
355						}
356					}
357				case anthropic.ContentBlockStopEvent:
358					if currentToolCallID != "" {
359						eventChan <- ProviderEvent{
360							Type: EventToolUseStop,
361							ToolCall: &message.ToolCall{
362								ID: currentToolCallID,
363							},
364						}
365						currentToolCallID = ""
366					} else {
367						eventChan <- ProviderEvent{Type: EventContentStop}
368					}
369
370				case anthropic.MessageStopEvent:
371					content := ""
372					for _, block := range accumulatedMessage.Content {
373						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
374							content += text.Text
375						}
376					}
377
378					eventChan <- ProviderEvent{
379						Type: EventComplete,
380						Response: &ProviderResponse{
381							Content:      content,
382							ToolCalls:    a.toolCalls(accumulatedMessage),
383							Usage:        a.usage(accumulatedMessage),
384							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
385						},
386						Content: content,
387					}
388				}
389			}
390
391			err := anthropicStream.Err()
392			if err == nil || errors.Is(err, io.EOF) {
393				close(eventChan)
394				return
395			}
396			// If there is an error we are going to see if we can retry the call
397			retry, after, retryErr := a.shouldRetry(attempts, err)
398			if retryErr != nil {
399				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
400				close(eventChan)
401				return
402			}
403			if retry {
404				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
405				select {
406				case <-ctx.Done():
407					// context cancelled
408					if ctx.Err() != nil {
409						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
410					}
411					close(eventChan)
412					return
413				case <-time.After(time.Duration(after) * time.Millisecond):
414					continue
415				}
416			}
417			if ctx.Err() != nil {
418				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
419			}
420
421			close(eventChan)
422			return
423		}
424	}()
425	return eventChan
426}
427
428func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
429	var apiErr *anthropic.Error
430	if !errors.As(err, &apiErr) {
431		return false, 0, err
432	}
433
434	if attempts > maxRetries {
435		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
436	}
437
438	if apiErr.StatusCode == 401 {
439		a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
440		if err != nil {
441			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
442		}
443		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
444		return true, 0, nil
445	}
446
447	// Handle context limit exceeded error (400 Bad Request)
448	if apiErr.StatusCode == 400 {
449		if adjusted, ok := a.handleContextLimitError(apiErr); ok {
450			a.adjustedMaxTokens = adjusted
451			slog.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted)
452			return true, 0, nil
453		}
454	}
455
456	isOverloaded := strings.Contains(apiErr.Error(), "overloaded") || strings.Contains(apiErr.Error(), "rate limit exceeded")
457	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 && !isOverloaded {
458		return false, 0, err
459	}
460
461	retryMs := 0
462	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
463
464	backoffMs := 2000 * (1 << (attempts - 1))
465	jitterMs := int(float64(backoffMs) * 0.2)
466	retryMs = backoffMs + jitterMs
467	if len(retryAfterValues) > 0 {
468		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
469			retryMs = retryMs * 1000
470		}
471	}
472	return true, int64(retryMs), nil
473}
474
475// handleContextLimitError parses context limit error and returns adjusted max_tokens
476func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
477	// Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
478	errorMsg := apiErr.Error()
479
480	re := regexp.MustCompile("input length and `max_tokens` exceed context limit: (\\d+) \\+ (\\d+) > (\\d+)")
481	matches := re.FindStringSubmatch(errorMsg)
482
483	if len(matches) != 4 {
484		return 0, false
485	}
486
487	inputTokens, err1 := strconv.Atoi(matches[1])
488	contextLimit, err2 := strconv.Atoi(matches[3])
489
490	if err1 != nil || err2 != nil {
491		return 0, false
492	}
493
494	// Calculate safe max_tokens with a buffer of 1000 tokens
495	safeMaxTokens := contextLimit - inputTokens - 1000
496
497	// Ensure we don't go below a minimum threshold
498	safeMaxTokens = max(safeMaxTokens, 1000)
499
500	return safeMaxTokens, true
501}
502
503func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
504	var toolCalls []message.ToolCall
505
506	for _, block := range msg.Content {
507		switch variant := block.AsAny().(type) {
508		case anthropic.ToolUseBlock:
509			toolCall := message.ToolCall{
510				ID:       variant.ID,
511				Name:     variant.Name,
512				Input:    string(variant.Input),
513				Type:     string(variant.Type),
514				Finished: true,
515			}
516			toolCalls = append(toolCalls, toolCall)
517		}
518	}
519
520	return toolCalls
521}
522
523func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
524	return TokenUsage{
525		InputTokens:         msg.Usage.InputTokens,
526		OutputTokens:        msg.Usage.OutputTokens,
527		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
528		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
529	}
530}
531
532func (a *anthropicClient) Model() catwalk.Model {
533	return a.providerOptions.model(a.providerOptions.modelType)
534}