package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"regexp"
	"strconv"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
)

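// anthropicClient talks to the Anthropic Messages API, optionally routed
// through AWS Bedrock, and implements the provider client interface.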
type anthropicClient struct {
	providerOptions   providerClientOptions
	useBedrock        bool
	client            anthropic.Client
	adjustedMaxTokens int // Used when the context limit is hit
}

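// AnthropicClient is the provider-facing interface satisfied by anthropicClient.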
type AnthropicClient ProviderClient

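// newAnthropicClient builds a client from the given options; useBedrock
// switches transport and authentication to AWS Bedrock.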
func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
	return &anthropicClient{
		providerOptions: opts,
		useBedrock:      useBedrock,
		client:          createAnthropicClient(opts, useBedrock),
	}
}

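// createAnthropicClient assembles the SDK request options: an API key when
// one is configured, and Bedrock's default AWS config when useBedrock is set.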
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
	anthropicClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if useBedrock {
		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
	}
	return anthropic.NewClient(anthropicClientOptions...)
}

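// convertMessages maps internal messages onto Anthropic message params. The
// last two messages are marked with ephemeral cache control so prompt caching
// covers as much of the stable prefix as possible.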
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
	for i, msg := range messages {
		cache := i > len(messages)-3
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cache && !a.providerOptions.disableCache {
				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
			}
			contentBlocks := []anthropic.ContentBlockParamUnion{content}
			for _, binaryContent := range msg.BinaryContent() {
				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
				contentBlocks = append(contentBlocks, imageBlock)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}

			// Add thinking blocks first if present (required when thinking is enabled with tool use).
			if reasoningContent := msg.ReasoningContent(); reasoningContent.Thinking != "" {
				thinkingBlock := anthropic.NewThinkingBlock(reasoningContent.Signature, reasoningContent.Thinking)
				blocks = append(blocks, thinkingBlock)
			}

			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cache && !a.providerOptions.disableCache {
					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				if err := json.Unmarshal([]byte(toolCall.Input), &inputMap); err != nil {
					// Skip tool calls whose input is not valid JSON.
					continue
				}
				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			if len(blocks) == 0 {
				slog.Warn("Skipping assistant message with no content blocks; this should not happen")
				continue
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}
	return
}

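// convertTools maps tool definitions onto Anthropic tool params, marking the
// last tool with ephemeral cache control so the tool list lands in the cached
// prompt prefix.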
func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
				// TODO: figure out how we can tell claude the required fields?
			},
		}

		if i == len(tools)-1 && !a.providerOptions.disableCache {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

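// finishReason translates Anthropic stop reasons into internal finish reasons.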
func (a *anthropicClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "end_turn", "stop_sequence":
		return message.FinishReasonEndTurn
	case "max_tokens":
		return message.FinishReasonMaxTokens
	case "tool_use":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

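// preparedMessages assembles the request parameters: model, token budget,
// optional extended thinking, tools, and a cached system prompt.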
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
	model := a.providerOptions.model(a.providerOptions.modelType)
	var thinkingParam anthropic.ThinkingConfigParamUnion
	cfg := config.Get()
	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if a.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	temperature := anthropic.Float(0)

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	if a.Model().CanReason && modelConfig.Think {
		// Reserve 80% of the token budget for thinking; temperature must be 1
		// when extended thinking is enabled.
		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8))
		temperature = anthropic.Float(1)
	}
	// Override max tokens if set in provider options.
	if a.providerOptions.maxTokens > 0 {
		maxTokens = a.providerOptions.maxTokens
	}

	// Use the adjusted max tokens if the context limit was hit previously.
	if a.adjustedMaxTokens > 0 {
		maxTokens = int64(a.adjustedMaxTokens)
	}

	return anthropic.MessageNewParams{
		Model:       anthropic.Model(model.ID),
		MaxTokens:   maxTokens,
		Temperature: temperature,
		Messages:    messages,
		Tools:       tools,
		Thinking:    thinkingParam,
		System: []anthropic.TextBlockParam{
			{
				Text: a.providerOptions.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	}
}

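// send performs a blocking completion request, retrying on recoverable errors
// (rate limits, expired credentials, context-limit overruns).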
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	cfg := config.Get()

	attempts := 0
	for {
		attempts++
		// Prepare messages on each attempt in case max_tokens was adjusted.
		preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
		if cfg.Options.Debug {
			jsonData, _ := json.Marshal(preparedMessages)
			slog.Debug("Prepared messages", "messages", string(jsonData))
		}

		anthropicResponse, err := a.client.Messages.New(
			ctx,
			preparedMessages,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			slog.Error("Error in Anthropic API call", "error", err)
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying Anthropic call", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := ""
		for _, block := range anthropicResponse.Content {
			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
				content += text.Text
			}
		}

		return &ProviderResponse{
			Content:   content,
			ToolCalls: a.toolCalls(*anthropicResponse),
			Usage:     a.usage(*anthropicResponse),
		}, nil
	}
}

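// stream starts a streaming completion and fans SDK events out as
// ProviderEvents on the returned channel, retrying on recoverable errors.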
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	cfg := config.Get()
	attempts := 0
	eventChan := make(chan ProviderEvent)
	go func() {
		for {
			attempts++
			// Prepare messages on each attempt in case max_tokens was adjusted.
			preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
			if cfg.Options.Debug {
				jsonData, _ := json.Marshal(preparedMessages)
				slog.Debug("Prepared messages", "messages", string(jsonData))
			}

			anthropicStream := a.client.Messages.NewStreaming(
				ctx,
				preparedMessages,
			)
			accumulatedMessage := anthropic.Message{}

			currentToolCallID := ""
			for anthropicStream.Next() {
				event := anthropicStream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					slog.Warn("Error accumulating message", "error", err)
					continue
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					switch event.ContentBlock.Type {
					case "text":
						eventChan <- ProviderEvent{Type: EventContentStart}
					case "tool_use":
						currentToolCallID = event.ContentBlock.ID
						eventChan <- ProviderEvent{
							Type: EventToolUseStart,
							ToolCall: &message.ToolCall{
								ID:       event.ContentBlock.ID,
								Name:     event.ContentBlock.Name,
								Finished: false,
							},
						}
					}

				case anthropic.ContentBlockDeltaEvent:
					switch {
					case event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "":
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					case event.Delta.Type == "signature_delta" && event.Delta.Signature != "":
						eventChan <- ProviderEvent{
							Type:      EventSignatureDelta,
							Signature: event.Delta.Signature,
						}
					case event.Delta.Type == "text_delta" && event.Delta.Text != "":
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					case event.Delta.Type == "input_json_delta":
						if currentToolCallID != "" {
							eventChan <- ProviderEvent{
								Type: EventToolUseDelta,
								ToolCall: &message.ToolCall{
									ID:       currentToolCallID,
									Finished: false,
									Input:    event.Delta.PartialJSON,
								},
							}
						}
					}

				case anthropic.ContentBlockStopEvent:
					if currentToolCallID != "" {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID: currentToolCallID,
							},
						}
						currentToolCallID = ""
					} else {
						eventChan <- ProviderEvent{Type: EventContentStop}
					}

				case anthropic.MessageStopEvent:
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    a.toolCalls(accumulatedMessage),
							Usage:        a.usage(accumulatedMessage),
							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
						},
						Content: content,
					}
				}
			}

			err := anthropicStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				close(eventChan)
				return
			}
			// On error, check whether the call can be retried.
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying Anthropic stream", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// Context cancelled while waiting to retry.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			if ctx.Err() != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
			}

			close(eventChan)
			return
		}
	}()
	return eventChan
}

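// shouldRetry decides whether a failed call is retryable and, for rate
// limits, how long to wait (in milliseconds) before the next attempt. It also
// refreshes credentials on 401s and shrinks max_tokens when a 400 reports a
// context-limit overrun.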
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *anthropic.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries)
	}

	if apiErr.StatusCode == 401 {
		// Credentials may have expired; re-resolve the API key and rebuild the client.
		a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
		return true, 0, nil
	}

	// Handle a context-limit-exceeded error (400 Bad Request).
	if apiErr.StatusCode == 400 {
		if adjusted, ok := a.handleContextLimitError(apiErr); ok {
			a.adjustedMaxTokens = adjusted
			slog.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted)
			return true, 0, nil
		}
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 {
		return false, 0, err
	}

	// Exponential backoff starting at 2s, padded by 20%; a Retry-After header
	// (in seconds) takes precedence when present.
	backoffMs := 2000 * (1 << (attempts - 1))
	retryMs := backoffMs + int(float64(backoffMs)*0.2)
	if retryAfterValues := apiErr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		if seconds, parseErr := strconv.Atoi(retryAfterValues[0]); parseErr == nil {
			retryMs = seconds * 1000
		}
	}
	return true, int64(retryMs), nil
}

// handleContextLimitError parses a context-limit error and returns an adjusted max_tokens value.
func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
	// Parse an error message like: "input length and `max_tokens` exceed context limit: 154978 + 50000 > 200000".
	errorMsg := apiErr.Error()

	re := regexp.MustCompile("input length and `max_tokens` exceed context limit: (\\d+) \\+ (\\d+) > (\\d+)")
	matches := re.FindStringSubmatch(errorMsg)

	if len(matches) != 4 {
		return 0, false
	}

	inputTokens, err1 := strconv.Atoi(matches[1])
	contextLimit, err2 := strconv.Atoi(matches[3])

	if err1 != nil || err2 != nil {
		return 0, false
	}

	// Calculate a safe max_tokens with a buffer of 1000 tokens.
	safeMaxTokens := contextLimit - inputTokens - 1000

	// Ensure we don't go below a minimum threshold.
	safeMaxTokens = max(safeMaxTokens, 1000)

	return safeMaxTokens, true
}

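// toolCalls extracts tool-use blocks from a completed message.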
func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range msg.Content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:       variant.ID,
				Name:     variant.Name,
				Input:    string(variant.Input),
				Type:     string(variant.Type),
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

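// usage converts Anthropic usage counters, including prompt-cache reads and
// writes, into the internal TokenUsage type.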
func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
	return TokenUsage{
		InputTokens:         msg.Usage.InputTokens,
		OutputTokens:        msg.Usage.OutputTokens,
		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
	}
}

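// Model returns the resolved model for the configured model type.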
func (a *anthropicClient) Model() provider.Model {
	return a.providerOptions.model(a.providerOptions.modelType)
}