anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"regexp"
 11	"strconv"
 12	"strings"
 13	"time"
 14
 15	"github.com/anthropics/anthropic-sdk-go"
 16	"github.com/anthropics/anthropic-sdk-go/bedrock"
 17	"github.com/anthropics/anthropic-sdk-go/option"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/fur/provider"
 20	"github.com/charmbracelet/crush/internal/llm/tools"
 21	"github.com/charmbracelet/crush/internal/message"
 22)
 23
// anthropicClient is the provider implementation that talks to the Anthropic
// Messages API through the Anthropic Go SDK.
type anthropicClient struct {
	providerOptions   providerClientOptions // options supplied at construction (API key, model selection, cache settings, ...)
	useBedrock        bool                  // passed to createAnthropicClient when the SDK client is rebuilt after a 401
	client            anthropic.Client      // SDK client used for both blocking and streaming requests
	adjustedMaxTokens int                   // Used when context limit is hit
}
 30
// AnthropicClient is a named type whose underlying type is ProviderClient; it
// is the type newAnthropicClient returns its *anthropicClient as.
type AnthropicClient ProviderClient
 32
 33func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 34	return &anthropicClient{
 35		providerOptions: opts,
 36		client:          createAnthropicClient(opts, useBedrock),
 37	}
 38}
 39
 40func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
 41	anthropicClientOptions := []option.RequestOption{}
 42	if opts.apiKey != "" {
 43		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 44	}
 45	if useBedrock {
 46		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 47	}
 48	for _, header := range opts.extraHeaders {
 49		anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(header, opts.extraHeaders[header]))
 50	}
 51	for key, value := range opts.extraBody {
 52		anthropicClientOptions = append(anthropicClientOptions, option.WithJSONSet(key, value))
 53	}
 54	return anthropic.NewClient(anthropicClientOptions...)
 55}
 56
 57func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
 58	for i, msg := range messages {
 59		cache := false
 60		if i > len(messages)-3 {
 61			cache = true
 62		}
 63		switch msg.Role {
 64		case message.User:
 65			content := anthropic.NewTextBlock(msg.Content().String())
 66			if cache && !a.providerOptions.disableCache {
 67				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 68					Type: "ephemeral",
 69				}
 70			}
 71			var contentBlocks []anthropic.ContentBlockParamUnion
 72			contentBlocks = append(contentBlocks, content)
 73			for _, binaryContent := range msg.BinaryContent() {
 74				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
 75				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
 76				contentBlocks = append(contentBlocks, imageBlock)
 77			}
 78			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
 79
 80		case message.Assistant:
 81			blocks := []anthropic.ContentBlockParamUnion{}
 82
 83			// Add thinking blocks first if present (required when thinking is enabled with tool use)
 84			if reasoningContent := msg.ReasoningContent(); reasoningContent.Thinking != "" {
 85				thinkingBlock := anthropic.NewThinkingBlock(reasoningContent.Signature, reasoningContent.Thinking)
 86				blocks = append(blocks, thinkingBlock)
 87			}
 88
 89			if msg.Content().String() != "" {
 90				content := anthropic.NewTextBlock(msg.Content().String())
 91				if cache && !a.providerOptions.disableCache {
 92					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 93						Type: "ephemeral",
 94					}
 95				}
 96				blocks = append(blocks, content)
 97			}
 98
 99			for _, toolCall := range msg.ToolCalls() {
100				var inputMap map[string]any
101				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
102				if err != nil {
103					continue
104				}
105				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
106			}
107
108			if len(blocks) == 0 {
109				slog.Warn("There is a message without content, investigate, this should not happen")
110				continue
111			}
112			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
113
114		case message.Tool:
115			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
116			for i, toolResult := range msg.ToolResults() {
117				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
118			}
119			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
120		}
121	}
122	return
123}
124
125func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
126	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
127
128	for i, tool := range tools {
129		info := tool.Info()
130		toolParam := anthropic.ToolParam{
131			Name:        info.Name,
132			Description: anthropic.String(info.Description),
133			InputSchema: anthropic.ToolInputSchemaParam{
134				Properties: info.Parameters,
135				// TODO: figure out how we can tell claude the required fields?
136			},
137		}
138
139		if i == len(tools)-1 && !a.providerOptions.disableCache {
140			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
141				Type: "ephemeral",
142			}
143		}
144
145		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
146	}
147
148	return anthropicTools
149}
150
151func (a *anthropicClient) finishReason(reason string) message.FinishReason {
152	switch reason {
153	case "end_turn":
154		return message.FinishReasonEndTurn
155	case "max_tokens":
156		return message.FinishReasonMaxTokens
157	case "tool_use":
158		return message.FinishReasonToolUse
159	case "stop_sequence":
160		return message.FinishReasonEndTurn
161	default:
162		return message.FinishReasonUnknown
163	}
164}
165
166func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
167	model := a.providerOptions.model(a.providerOptions.modelType)
168	var thinkingParam anthropic.ThinkingConfigParamUnion
169	cfg := config.Get()
170	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
171	if a.providerOptions.modelType == config.SelectedModelTypeSmall {
172		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
173	}
174	temperature := anthropic.Float(0)
175
176	maxTokens := model.DefaultMaxTokens
177	if modelConfig.MaxTokens > 0 {
178		maxTokens = modelConfig.MaxTokens
179	}
180	if a.Model().CanReason && modelConfig.Think {
181		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8))
182		temperature = anthropic.Float(1)
183	}
184	// Override max tokens if set in provider options
185	if a.providerOptions.maxTokens > 0 {
186		maxTokens = a.providerOptions.maxTokens
187	}
188
189	// Use adjusted max tokens if context limit was hit
190	if a.adjustedMaxTokens > 0 {
191		maxTokens = int64(a.adjustedMaxTokens)
192	}
193
194	return anthropic.MessageNewParams{
195		Model:       anthropic.Model(model.ID),
196		MaxTokens:   maxTokens,
197		Temperature: temperature,
198		Messages:    messages,
199		Tools:       tools,
200		Thinking:    thinkingParam,
201		System: []anthropic.TextBlockParam{
202			{
203				Text: a.providerOptions.systemMessage,
204				CacheControl: anthropic.CacheControlEphemeralParam{
205					Type: "ephemeral",
206				},
207			},
208		},
209	}
210}
211
212func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
213	cfg := config.Get()
214
215	attempts := 0
216	for {
217		attempts++
218		// Prepare messages on each attempt in case max_tokens was adjusted
219		preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
220		if cfg.Options.Debug {
221			jsonData, _ := json.Marshal(preparedMessages)
222			slog.Debug("Prepared messages", "messages", string(jsonData))
223		}
224
225		anthropicResponse, err := a.client.Messages.New(
226			ctx,
227			preparedMessages,
228		)
229		// If there is an error we are going to see if we can retry the call
230		if err != nil {
231			slog.Error("Error in Anthropic API call", "error", err)
232			retry, after, retryErr := a.shouldRetry(attempts, err)
233			if retryErr != nil {
234				return nil, retryErr
235			}
236			if retry {
237				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
238				select {
239				case <-ctx.Done():
240					return nil, ctx.Err()
241				case <-time.After(time.Duration(after) * time.Millisecond):
242					continue
243				}
244			}
245			return nil, retryErr
246		}
247
248		content := ""
249		for _, block := range anthropicResponse.Content {
250			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
251				content += text.Text
252			}
253		}
254
255		return &ProviderResponse{
256			Content:   content,
257			ToolCalls: a.toolCalls(*anthropicResponse),
258			Usage:     a.usage(*anthropicResponse),
259		}, nil
260	}
261}
262
263func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
264	cfg := config.Get()
265	attempts := 0
266	eventChan := make(chan ProviderEvent)
267	go func() {
268		for {
269			attempts++
270			// Prepare messages on each attempt in case max_tokens was adjusted
271			preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
272			if cfg.Options.Debug {
273				jsonData, _ := json.Marshal(preparedMessages)
274				slog.Debug("Prepared messages", "messages", string(jsonData))
275			}
276
277			anthropicStream := a.client.Messages.NewStreaming(
278				ctx,
279				preparedMessages,
280				option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"),
281			)
282			accumulatedMessage := anthropic.Message{}
283
284			currentToolCallID := ""
285			for anthropicStream.Next() {
286				event := anthropicStream.Current()
287				err := accumulatedMessage.Accumulate(event)
288				if err != nil {
289					slog.Warn("Error accumulating message", "error", err)
290					continue
291				}
292
293				switch event := event.AsAny().(type) {
294				case anthropic.ContentBlockStartEvent:
295					switch event.ContentBlock.Type {
296					case "text":
297						eventChan <- ProviderEvent{Type: EventContentStart}
298					case "tool_use":
299						currentToolCallID = event.ContentBlock.ID
300						eventChan <- ProviderEvent{
301							Type: EventToolUseStart,
302							ToolCall: &message.ToolCall{
303								ID:       event.ContentBlock.ID,
304								Name:     event.ContentBlock.Name,
305								Finished: false,
306							},
307						}
308					}
309
310				case anthropic.ContentBlockDeltaEvent:
311					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
312						eventChan <- ProviderEvent{
313							Type:     EventThinkingDelta,
314							Thinking: event.Delta.Thinking,
315						}
316					} else if event.Delta.Type == "signature_delta" && event.Delta.Signature != "" {
317						eventChan <- ProviderEvent{
318							Type:      EventSignatureDelta,
319							Signature: event.Delta.Signature,
320						}
321					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
322						eventChan <- ProviderEvent{
323							Type:    EventContentDelta,
324							Content: event.Delta.Text,
325						}
326					} else if event.Delta.Type == "input_json_delta" {
327						if currentToolCallID != "" {
328							eventChan <- ProviderEvent{
329								Type: EventToolUseDelta,
330								ToolCall: &message.ToolCall{
331									ID:       currentToolCallID,
332									Finished: false,
333									Input:    event.Delta.PartialJSON,
334								},
335							}
336						}
337					}
338				case anthropic.ContentBlockStopEvent:
339					if currentToolCallID != "" {
340						eventChan <- ProviderEvent{
341							Type: EventToolUseStop,
342							ToolCall: &message.ToolCall{
343								ID: currentToolCallID,
344							},
345						}
346						currentToolCallID = ""
347					} else {
348						eventChan <- ProviderEvent{Type: EventContentStop}
349					}
350
351				case anthropic.MessageStopEvent:
352					content := ""
353					for _, block := range accumulatedMessage.Content {
354						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
355							content += text.Text
356						}
357					}
358
359					eventChan <- ProviderEvent{
360						Type: EventComplete,
361						Response: &ProviderResponse{
362							Content:      content,
363							ToolCalls:    a.toolCalls(accumulatedMessage),
364							Usage:        a.usage(accumulatedMessage),
365							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
366						},
367						Content: content,
368					}
369				}
370			}
371
372			err := anthropicStream.Err()
373			if err == nil || errors.Is(err, io.EOF) {
374				close(eventChan)
375				return
376			}
377			// If there is an error we are going to see if we can retry the call
378			retry, after, retryErr := a.shouldRetry(attempts, err)
379			if retryErr != nil {
380				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
381				close(eventChan)
382				return
383			}
384			if retry {
385				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
386				select {
387				case <-ctx.Done():
388					// context cancelled
389					if ctx.Err() != nil {
390						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
391					}
392					close(eventChan)
393					return
394				case <-time.After(time.Duration(after) * time.Millisecond):
395					continue
396				}
397			}
398			if ctx.Err() != nil {
399				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
400			}
401
402			close(eventChan)
403			return
404		}
405	}()
406	return eventChan
407}
408
409func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
410	var apiErr *anthropic.Error
411	if !errors.As(err, &apiErr) {
412		return false, 0, err
413	}
414
415	if attempts > maxRetries {
416		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
417	}
418
419	if apiErr.StatusCode == 401 {
420		a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
421		if err != nil {
422			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
423		}
424		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
425		return true, 0, nil
426	}
427
428	// Handle context limit exceeded error (400 Bad Request)
429	if apiErr.StatusCode == 400 {
430		if adjusted, ok := a.handleContextLimitError(apiErr); ok {
431			a.adjustedMaxTokens = adjusted
432			slog.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted)
433			return true, 0, nil
434		}
435	}
436
437	isOverloaded := strings.Contains(apiErr.Error(), "overloaded") || strings.Contains(apiErr.Error(), "rate limit exceeded")
438	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 && !isOverloaded {
439		return false, 0, err
440	}
441
442	retryMs := 0
443	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
444
445	backoffMs := 2000 * (1 << (attempts - 1))
446	jitterMs := int(float64(backoffMs) * 0.2)
447	retryMs = backoffMs + jitterMs
448	if len(retryAfterValues) > 0 {
449		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
450			retryMs = retryMs * 1000
451		}
452	}
453	return true, int64(retryMs), nil
454}
455
456// handleContextLimitError parses context limit error and returns adjusted max_tokens
457func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
458	// Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
459	errorMsg := apiErr.Error()
460
461	re := regexp.MustCompile("input length and `max_tokens` exceed context limit: (\\d+) \\+ (\\d+) > (\\d+)")
462	matches := re.FindStringSubmatch(errorMsg)
463
464	if len(matches) != 4 {
465		return 0, false
466	}
467
468	inputTokens, err1 := strconv.Atoi(matches[1])
469	contextLimit, err2 := strconv.Atoi(matches[3])
470
471	if err1 != nil || err2 != nil {
472		return 0, false
473	}
474
475	// Calculate safe max_tokens with a buffer of 1000 tokens
476	safeMaxTokens := contextLimit - inputTokens - 1000
477
478	// Ensure we don't go below a minimum threshold
479	safeMaxTokens = max(safeMaxTokens, 1000)
480
481	return safeMaxTokens, true
482}
483
484func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
485	var toolCalls []message.ToolCall
486
487	for _, block := range msg.Content {
488		switch variant := block.AsAny().(type) {
489		case anthropic.ToolUseBlock:
490			toolCall := message.ToolCall{
491				ID:       variant.ID,
492				Name:     variant.Name,
493				Input:    string(variant.Input),
494				Type:     string(variant.Type),
495				Finished: true,
496			}
497			toolCalls = append(toolCalls, toolCall)
498		}
499	}
500
501	return toolCalls
502}
503
504func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
505	return TokenUsage{
506		InputTokens:         msg.Usage.InputTokens,
507		OutputTokens:        msg.Usage.OutputTokens,
508		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
509		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
510	}
511}
512
513func (a *anthropicClient) Model() provider.Model {
514	return a.providerOptions.model(a.providerOptions.modelType)
515}