anthropic.go

  1package provider
  2
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"math/rand"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
)
 23
 24var (
 25	// Pre-compiled regex for parsing context limit errors.
 26	contextLimitRegex = regexp.MustCompile(`input length and ` + "`max_tokens`" + ` exceed context limit: (\d+) \+ (\d+) > (\d+)`)
 27)
 28
 29type anthropicClient struct {
 30	providerOptions   providerClientOptions
 31	useBedrock        bool
 32	client            anthropic.Client
 33	adjustedMaxTokens int // Used when context limit is hit
 34}
 35
 36type AnthropicClient ProviderClient
 37
 38func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 39	return &anthropicClient{
 40		providerOptions: opts,
 41		client:          createAnthropicClient(opts, useBedrock),
 42	}
 43}
 44
 45func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
 46	anthropicClientOptions := []option.RequestOption{}
 47
 48	// Check if Authorization header is provided in extra headers
 49	hasBearerAuth := false
 50	if opts.extraHeaders != nil {
 51		for key := range opts.extraHeaders {
 52			if strings.ToLower(key) == "authorization" {
 53				hasBearerAuth = true
 54				break
 55			}
 56		}
 57	}
 58
 59	isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ")
 60
 61	if opts.apiKey != "" && !hasBearerAuth {
 62		if isBearerToken {
 63			slog.Debug("API key starts with 'Bearer ', using as Authorization header")
 64			anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey))
 65		} else {
 66			// Use standard X-Api-Key header
 67			anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 68		}
 69	} else if hasBearerAuth {
 70		slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
 71	}
 72	if useBedrock {
 73		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 74	}
 75	for _, header := range opts.extraHeaders {
 76		anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(header, opts.extraHeaders[header]))
 77	}
 78	for key, value := range opts.extraBody {
 79		anthropicClientOptions = append(anthropicClientOptions, option.WithJSONSet(key, value))
 80	}
 81	return anthropic.NewClient(anthropicClientOptions...)
 82}
 83
 84func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
 85	for i, msg := range messages {
 86		cache := false
 87		if i > len(messages)-3 {
 88			cache = true
 89		}
 90		switch msg.Role {
 91		case message.User:
 92			content := anthropic.NewTextBlock(msg.Content().String())
 93			if cache && !a.providerOptions.disableCache {
 94				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 95					Type: "ephemeral",
 96				}
 97			}
 98			var contentBlocks []anthropic.ContentBlockParamUnion
 99			contentBlocks = append(contentBlocks, content)
100			for _, binaryContent := range msg.BinaryContent() {
101				base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic)
102				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
103				contentBlocks = append(contentBlocks, imageBlock)
104			}
105			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
106
107		case message.Assistant:
108			blocks := []anthropic.ContentBlockParamUnion{}
109
110			// Add thinking blocks first if present (required when thinking is enabled with tool use)
111			if reasoningContent := msg.ReasoningContent(); reasoningContent.Thinking != "" {
112				thinkingBlock := anthropic.NewThinkingBlock(reasoningContent.Signature, reasoningContent.Thinking)
113				blocks = append(blocks, thinkingBlock)
114			}
115
116			if msg.Content().String() != "" {
117				content := anthropic.NewTextBlock(msg.Content().String())
118				if cache && !a.providerOptions.disableCache {
119					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
120						Type: "ephemeral",
121					}
122				}
123				blocks = append(blocks, content)
124			}
125
126			for _, toolCall := range msg.ToolCalls() {
127				var inputMap map[string]any
128				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
129				if err != nil {
130					continue
131				}
132				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
133			}
134
135			if len(blocks) == 0 {
136				slog.Warn("There is a message without content, investigate, this should not happen")
137				continue
138			}
139			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
140
141		case message.Tool:
142			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
143			for i, toolResult := range msg.ToolResults() {
144				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
145			}
146			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
147		}
148	}
149	return
150}
151
152func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
153	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
154
155	for i, tool := range tools {
156		info := tool.Info()
157		toolParam := anthropic.ToolParam{
158			Name:        info.Name,
159			Description: anthropic.String(info.Description),
160			InputSchema: anthropic.ToolInputSchemaParam{
161				Properties: info.Parameters,
162				// TODO: figure out how we can tell claude the required fields?
163			},
164		}
165
166		if i == len(tools)-1 && !a.providerOptions.disableCache {
167			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
168				Type: "ephemeral",
169			}
170		}
171
172		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
173	}
174
175	return anthropicTools
176}
177
178func (a *anthropicClient) finishReason(reason string) message.FinishReason {
179	switch reason {
180	case "end_turn":
181		return message.FinishReasonEndTurn
182	case "max_tokens":
183		return message.FinishReasonMaxTokens
184	case "tool_use":
185		return message.FinishReasonToolUse
186	case "stop_sequence":
187		return message.FinishReasonEndTurn
188	default:
189		return message.FinishReasonUnknown
190	}
191}
192
193func (a *anthropicClient) isThinkingEnabled() bool {
194	cfg := config.Get()
195	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
196	if a.providerOptions.modelType == config.SelectedModelTypeSmall {
197		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
198	}
199	return a.Model().CanReason && modelConfig.Think
200}
201
202func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
203	model := a.providerOptions.model(a.providerOptions.modelType)
204	var thinkingParam anthropic.ThinkingConfigParamUnion
205	cfg := config.Get()
206	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
207	if a.providerOptions.modelType == config.SelectedModelTypeSmall {
208		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
209	}
210	temperature := anthropic.Float(0)
211
212	maxTokens := model.DefaultMaxTokens
213	if modelConfig.MaxTokens > 0 {
214		maxTokens = modelConfig.MaxTokens
215	}
216	if a.isThinkingEnabled() {
217		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8))
218		temperature = anthropic.Float(1)
219	}
220	// Override max tokens if set in provider options
221	if a.providerOptions.maxTokens > 0 {
222		maxTokens = a.providerOptions.maxTokens
223	}
224
225	// Use adjusted max tokens if context limit was hit
226	if a.adjustedMaxTokens > 0 {
227		maxTokens = int64(a.adjustedMaxTokens)
228	}
229
230	systemBlocks := []anthropic.TextBlockParam{}
231
232	// Add custom system prompt prefix if configured
233	if a.providerOptions.systemPromptPrefix != "" {
234		systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
235			Text: a.providerOptions.systemPromptPrefix,
236			CacheControl: anthropic.CacheControlEphemeralParam{
237				Type: "ephemeral",
238			},
239		})
240	}
241
242	systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
243		Text: a.providerOptions.systemMessage,
244		CacheControl: anthropic.CacheControlEphemeralParam{
245			Type: "ephemeral",
246		},
247	})
248
249	return anthropic.MessageNewParams{
250		Model:       anthropic.Model(model.ID),
251		MaxTokens:   maxTokens,
252		Temperature: temperature,
253		Messages:    messages,
254		Tools:       tools,
255		Thinking:    thinkingParam,
256		System:      systemBlocks,
257	}
258}
259
260func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
261	cfg := config.Get()
262
263	attempts := 0
264	for {
265		attempts++
266		// Prepare messages on each attempt in case max_tokens was adjusted
267		preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
268		if cfg.Options.Debug {
269			jsonData, _ := json.Marshal(preparedMessages)
270			slog.Debug("Prepared messages", "messages", string(jsonData))
271		}
272
273		var opts []option.RequestOption
274		if a.isThinkingEnabled() {
275			opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
276		}
277		anthropicResponse, err := a.client.Messages.New(
278			ctx,
279			preparedMessages,
280			opts...,
281		)
282		// If there is an error we are going to see if we can retry the call
283		if err != nil {
284			slog.Error("Error in Anthropic API call", "error", err)
285			retry, after, retryErr := a.shouldRetry(attempts, err)
286			if retryErr != nil {
287				return nil, retryErr
288			}
289			if retry {
290				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
291				select {
292				case <-ctx.Done():
293					return nil, ctx.Err()
294				case <-time.After(time.Duration(after) * time.Millisecond):
295					continue
296				}
297			}
298			return nil, retryErr
299		}
300
301		content := ""
302		for _, block := range anthropicResponse.Content {
303			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
304				content += text.Text
305			}
306		}
307
308		return &ProviderResponse{
309			Content:   content,
310			ToolCalls: a.toolCalls(*anthropicResponse),
311			Usage:     a.usage(*anthropicResponse),
312		}, nil
313	}
314}
315
316func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
317	cfg := config.Get()
318	attempts := 0
319	eventChan := make(chan ProviderEvent)
320	go func() {
321		for {
322			attempts++
323			// Prepare messages on each attempt in case max_tokens was adjusted
324			preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
325			if cfg.Options.Debug {
326				jsonData, _ := json.Marshal(preparedMessages)
327				slog.Debug("Prepared messages", "messages", string(jsonData))
328			}
329
330			var opts []option.RequestOption
331			if a.isThinkingEnabled() {
332				opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
333			}
334
335			anthropicStream := a.client.Messages.NewStreaming(
336				ctx,
337				preparedMessages,
338				opts...,
339			)
340			accumulatedMessage := anthropic.Message{}
341
342			currentToolCallID := ""
343			for anthropicStream.Next() {
344				event := anthropicStream.Current()
345				err := accumulatedMessage.Accumulate(event)
346				if err != nil {
347					slog.Warn("Error accumulating message", "error", err)
348					continue
349				}
350
351				switch event := event.AsAny().(type) {
352				case anthropic.ContentBlockStartEvent:
353					switch event.ContentBlock.Type {
354					case "text":
355						eventChan <- ProviderEvent{Type: EventContentStart}
356					case "tool_use":
357						currentToolCallID = event.ContentBlock.ID
358						eventChan <- ProviderEvent{
359							Type: EventToolUseStart,
360							ToolCall: &message.ToolCall{
361								ID:       event.ContentBlock.ID,
362								Name:     event.ContentBlock.Name,
363								Finished: false,
364							},
365						}
366					}
367
368				case anthropic.ContentBlockDeltaEvent:
369					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
370						eventChan <- ProviderEvent{
371							Type:     EventThinkingDelta,
372							Thinking: event.Delta.Thinking,
373						}
374					} else if event.Delta.Type == "signature_delta" && event.Delta.Signature != "" {
375						eventChan <- ProviderEvent{
376							Type:      EventSignatureDelta,
377							Signature: event.Delta.Signature,
378						}
379					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
380						eventChan <- ProviderEvent{
381							Type:    EventContentDelta,
382							Content: event.Delta.Text,
383						}
384					} else if event.Delta.Type == "input_json_delta" {
385						if currentToolCallID != "" {
386							eventChan <- ProviderEvent{
387								Type: EventToolUseDelta,
388								ToolCall: &message.ToolCall{
389									ID:       currentToolCallID,
390									Finished: false,
391									Input:    event.Delta.PartialJSON,
392								},
393							}
394						}
395					}
396				case anthropic.ContentBlockStopEvent:
397					if currentToolCallID != "" {
398						eventChan <- ProviderEvent{
399							Type: EventToolUseStop,
400							ToolCall: &message.ToolCall{
401								ID: currentToolCallID,
402							},
403						}
404						currentToolCallID = ""
405					} else {
406						eventChan <- ProviderEvent{Type: EventContentStop}
407					}
408
409				case anthropic.MessageStopEvent:
410					content := ""
411					for _, block := range accumulatedMessage.Content {
412						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
413							content += text.Text
414						}
415					}
416
417					eventChan <- ProviderEvent{
418						Type: EventComplete,
419						Response: &ProviderResponse{
420							Content:      content,
421							ToolCalls:    a.toolCalls(accumulatedMessage),
422							Usage:        a.usage(accumulatedMessage),
423							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
424						},
425						Content: content,
426					}
427				}
428			}
429
430			err := anthropicStream.Err()
431			if err == nil || errors.Is(err, io.EOF) {
432				close(eventChan)
433				return
434			}
435
436			// If there is an error we are going to see if we can retry the call
437			retry, after, retryErr := a.shouldRetry(attempts, err)
438			if retryErr != nil {
439				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
440				close(eventChan)
441				return
442			}
443			if retry {
444				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
445				select {
446				case <-ctx.Done():
447					// context cancelled
448					if ctx.Err() != nil {
449						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
450					}
451					close(eventChan)
452					return
453				case <-time.After(time.Duration(after) * time.Millisecond):
454					continue
455				}
456			}
457			if ctx.Err() != nil {
458				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
459			}
460
461			close(eventChan)
462			return
463		}
464	}()
465	return eventChan
466}
467
468func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
469	var apiErr *anthropic.Error
470	if !errors.As(err, &apiErr) {
471		return false, 0, err
472	}
473
474	if attempts > maxRetries {
475		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
476	}
477
478	if apiErr.StatusCode == 401 {
479		a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
480		if err != nil {
481			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
482		}
483		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
484		return true, 0, nil
485	}
486
487	// Handle context limit exceeded error (400 Bad Request)
488	if apiErr.StatusCode == 400 {
489		if adjusted, ok := a.handleContextLimitError(apiErr); ok {
490			a.adjustedMaxTokens = adjusted
491			slog.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted)
492			return true, 0, nil
493		}
494	}
495
496	isOverloaded := strings.Contains(apiErr.Error(), "overloaded") || strings.Contains(apiErr.Error(), "rate limit exceeded")
497	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 && !isOverloaded {
498		return false, 0, err
499	}
500
501	retryMs := 0
502	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
503
504	backoffMs := 2000 * (1 << (attempts - 1))
505	jitterMs := int(float64(backoffMs) * 0.2)
506	retryMs = backoffMs + jitterMs
507	if len(retryAfterValues) > 0 {
508		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
509			retryMs = retryMs * 1000
510		}
511	}
512	return true, int64(retryMs), nil
513}
514
515// handleContextLimitError parses context limit error and returns adjusted max_tokens
516func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
517	// Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
518	errorMsg := apiErr.Error()
519
520	matches := contextLimitRegex.FindStringSubmatch(errorMsg)
521
522	if len(matches) != 4 {
523		return 0, false
524	}
525
526	inputTokens, err1 := strconv.Atoi(matches[1])
527	contextLimit, err2 := strconv.Atoi(matches[3])
528
529	if err1 != nil || err2 != nil {
530		return 0, false
531	}
532
533	// Calculate safe max_tokens with a buffer of 1000 tokens
534	safeMaxTokens := contextLimit - inputTokens - 1000
535
536	// Ensure we don't go below a minimum threshold
537	safeMaxTokens = max(safeMaxTokens, 1000)
538
539	return safeMaxTokens, true
540}
541
542func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
543	var toolCalls []message.ToolCall
544
545	for _, block := range msg.Content {
546		switch variant := block.AsAny().(type) {
547		case anthropic.ToolUseBlock:
548			toolCall := message.ToolCall{
549				ID:       variant.ID,
550				Name:     variant.Name,
551				Input:    string(variant.Input),
552				Type:     string(variant.Type),
553				Finished: true,
554			}
555			toolCalls = append(toolCalls, toolCall)
556		}
557	}
558
559	return toolCalls
560}
561
562func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
563	return TokenUsage{
564		InputTokens:         msg.Usage.InputTokens,
565		OutputTokens:        msg.Usage.OutputTokens,
566		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
567		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
568	}
569}
570
571func (a *anthropicClient) Model() catwalk.Model {
572	return a.providerOptions.model(a.providerOptions.modelType)
573}