// anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"regexp"
 10	"strconv"
 11	"time"
 12
 13	"github.com/anthropics/anthropic-sdk-go"
 14	"github.com/anthropics/anthropic-sdk-go/bedrock"
 15	"github.com/anthropics/anthropic-sdk-go/option"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/fur/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/logging"
 20	"github.com/charmbracelet/crush/internal/message"
 21)
 22
// anthropicClient is the Anthropic implementation of ProviderClient.
// It talks to the Anthropic Messages API, optionally routed through
// AWS Bedrock.
type anthropicClient struct {
	providerOptions   providerClientOptions
	useBedrock        bool // route requests through AWS Bedrock instead of the Anthropic API
	client            anthropic.Client
	adjustedMaxTokens int // Used when context limit is hit
}
 29
// AnthropicClient is the exported handle for an Anthropic-backed ProviderClient.
type AnthropicClient ProviderClient
 31
 32func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 33	return &anthropicClient{
 34		providerOptions: opts,
 35		client:          createAnthropicClient(opts, useBedrock),
 36	}
 37}
 38
 39func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
 40	anthropicClientOptions := []option.RequestOption{}
 41	if opts.apiKey != "" {
 42		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 43	}
 44	if useBedrock {
 45		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 46	}
 47	return anthropic.NewClient(anthropicClientOptions...)
 48}
 49
// convertMessages translates the internal message history into Anthropic
// MessageParam values. The last two messages are flagged for ephemeral
// prompt caching (unless caching is disabled) so Anthropic can reuse the
// prompt prefix across requests.
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
	for i, msg := range messages {
		// Only the final two messages are cache candidates.
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cache && !a.providerOptions.disableCache {
				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
			}
			var contentBlocks []anthropic.ContentBlockParamUnion
			contentBlocks = append(contentBlocks, content)
			// Attach any binary attachments (e.g. images) as base64 image blocks.
			for _, binaryContent := range msg.BinaryContent() {
				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
				contentBlocks = append(contentBlocks, imageBlock)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cache && !a.providerOptions.disableCache {
					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					// Skip tool calls whose recorded input is not valid JSON.
					continue
				}
				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			// An assistant message with neither text nor tool calls would be
			// rejected by the API; drop it and log for investigation.
			if len(blocks) == 0 {
				logging.Warn("There is a message without content, investigate, this should not happen")
				continue
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

		case message.Tool:
			// The Anthropic API expects tool results to be delivered inside a
			// user-role message.
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}
	return
}
110
111func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
112	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
113
114	for i, tool := range tools {
115		info := tool.Info()
116		toolParam := anthropic.ToolParam{
117			Name:        info.Name,
118			Description: anthropic.String(info.Description),
119			InputSchema: anthropic.ToolInputSchemaParam{
120				Properties: info.Parameters,
121				// TODO: figure out how we can tell claude the required fields?
122			},
123		}
124
125		if i == len(tools)-1 && !a.providerOptions.disableCache {
126			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
127				Type: "ephemeral",
128			}
129		}
130
131		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
132	}
133
134	return anthropicTools
135}
136
137func (a *anthropicClient) finishReason(reason string) message.FinishReason {
138	switch reason {
139	case "end_turn":
140		return message.FinishReasonEndTurn
141	case "max_tokens":
142		return message.FinishReasonMaxTokens
143	case "tool_use":
144		return message.FinishReasonToolUse
145	case "stop_sequence":
146		return message.FinishReasonEndTurn
147	default:
148		return message.FinishReasonUnknown
149	}
150}
151
152func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
153	model := a.providerOptions.model(a.providerOptions.modelType)
154	var thinkingParam anthropic.ThinkingConfigParamUnion
155	cfg := config.Get()
156	modelConfig := cfg.Models.Large
157	if a.providerOptions.modelType == config.SmallModel {
158		modelConfig = cfg.Models.Small
159	}
160	temperature := anthropic.Float(0)
161
162	if a.Model().CanReason && modelConfig.Think {
163		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
164		temperature = anthropic.Float(1)
165	}
166
167	maxTokens := model.DefaultMaxTokens
168	if modelConfig.MaxTokens > 0 {
169		maxTokens = modelConfig.MaxTokens
170	}
171
172	// Override max tokens if set in provider options
173	if a.providerOptions.maxTokens > 0 {
174		maxTokens = a.providerOptions.maxTokens
175	}
176
177	// Use adjusted max tokens if context limit was hit
178	if a.adjustedMaxTokens > 0 {
179		maxTokens = int64(a.adjustedMaxTokens)
180	}
181
182	return anthropic.MessageNewParams{
183		Model:       anthropic.Model(model.ID),
184		MaxTokens:   maxTokens,
185		Temperature: temperature,
186		Messages:    messages,
187		Tools:       tools,
188		Thinking:    thinkingParam,
189		System: []anthropic.TextBlockParam{
190			{
191				Text: a.providerOptions.systemMessage,
192				CacheControl: anthropic.CacheControlEphemeralParam{
193					Type: "ephemeral",
194				},
195			},
196		},
197	}
198}
199
200func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
201	cfg := config.Get()
202
203	attempts := 0
204	for {
205		attempts++
206		// Prepare messages on each attempt in case max_tokens was adjusted
207		preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
208		if cfg.Options.Debug {
209			jsonData, _ := json.Marshal(preparedMessages)
210			logging.Debug("Prepared messages", "messages", string(jsonData))
211		}
212
213		anthropicResponse, err := a.client.Messages.New(
214			ctx,
215			preparedMessages,
216		)
217		// If there is an error we are going to see if we can retry the call
218		if err != nil {
219			logging.Error("Error in Anthropic API call", "error", err)
220			retry, after, retryErr := a.shouldRetry(attempts, err)
221			if retryErr != nil {
222				return nil, retryErr
223			}
224			if retry {
225				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
226				select {
227				case <-ctx.Done():
228					return nil, ctx.Err()
229				case <-time.After(time.Duration(after) * time.Millisecond):
230					continue
231				}
232			}
233			return nil, retryErr
234		}
235
236		content := ""
237		for _, block := range anthropicResponse.Content {
238			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
239				content += text.Text
240			}
241		}
242
243		return &ProviderResponse{
244			Content:   content,
245			ToolCalls: a.toolCalls(*anthropicResponse),
246			Usage:     a.usage(*anthropicResponse),
247		}, nil
248	}
249}
250
// stream issues a streaming Messages API request and forwards SDK stream
// events as ProviderEvents on the returned channel. The channel is closed
// when the stream completes, a non-retryable error occurs, or the context
// is cancelled. Retryable failures re-enter the loop with the delay chosen
// by shouldRetry.
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	cfg := config.Get()
	attempts := 0
	eventChan := make(chan ProviderEvent)
	go func() {
		for {
			attempts++
			// Prepare messages on each attempt in case max_tokens was adjusted
			preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
			if cfg.Options.Debug {
				jsonData, _ := json.Marshal(preparedMessages)
				logging.Debug("Prepared messages", "messages", string(jsonData))
			}

			anthropicStream := a.client.Messages.NewStreaming(
				ctx,
				preparedMessages,
			)
			// Accumulates deltas so the final message (content, tool calls,
			// usage, stop reason) is available at MessageStopEvent time.
			accumulatedMessage := anthropic.Message{}

			// Tracks the tool_use block currently being streamed, "" when the
			// active block is text.
			currentToolCallID := ""
			for anthropicStream.Next() {
				event := anthropicStream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					logging.Warn("Error accumulating message", "error", err)
					continue
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					switch event.ContentBlock.Type {
					case "text":
						eventChan <- ProviderEvent{Type: EventContentStart}
					case "tool_use":
						currentToolCallID = event.ContentBlock.ID
						eventChan <- ProviderEvent{
							Type: EventToolUseStart,
							ToolCall: &message.ToolCall{
								ID:       event.ContentBlock.ID,
								Name:     event.ContentBlock.Name,
								Finished: false,
							},
						}
					}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					} else if event.Delta.Type == "input_json_delta" {
						// Partial JSON for the in-flight tool call's input.
						if currentToolCallID != "" {
							eventChan <- ProviderEvent{
								Type: EventToolUseDelta,
								ToolCall: &message.ToolCall{
									ID:       currentToolCallID,
									Finished: false,
									Input:    event.Delta.PartialJSON,
								},
							}
						}
					}
				case anthropic.ContentBlockStopEvent:
					// A stop event closes either the active tool call or the
					// active text block, whichever is in progress.
					if currentToolCallID != "" {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID: currentToolCallID,
							},
						}
						currentToolCallID = ""
					} else {
						eventChan <- ProviderEvent{Type: EventContentStop}
					}

				case anthropic.MessageStopEvent:
					// Final event: assemble the complete response from the
					// accumulated message.
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    a.toolCalls(accumulatedMessage),
							Usage:        a.usage(accumulatedMessage),
							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
						},
						Content: content,
					}
				}
			}

			err := anthropicStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				// Clean end of stream.
				close(eventChan)
				return
			}
			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: report context cancellation if that is what
			// terminated the stream, then close.
			if ctx.Err() != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
			}

			close(eventChan)
			return
		}
	}()
	return eventChan
}
390
391func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
392	var apiErr *anthropic.Error
393	if !errors.As(err, &apiErr) {
394		return false, 0, err
395	}
396
397	if attempts > maxRetries {
398		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
399	}
400
401	if apiErr.StatusCode == 401 {
402		a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
403		if err != nil {
404			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
405		}
406		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
407		return true, 0, nil
408	}
409
410	// Handle context limit exceeded error (400 Bad Request)
411	if apiErr.StatusCode == 400 {
412		if adjusted, ok := a.handleContextLimitError(apiErr); ok {
413			a.adjustedMaxTokens = adjusted
414			logging.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted)
415			return true, 0, nil
416		}
417	}
418
419	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 {
420		return false, 0, err
421	}
422
423	retryMs := 0
424	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
425
426	backoffMs := 2000 * (1 << (attempts - 1))
427	jitterMs := int(float64(backoffMs) * 0.2)
428	retryMs = backoffMs + jitterMs
429	if len(retryAfterValues) > 0 {
430		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
431			retryMs = retryMs * 1000
432		}
433	}
434	return true, int64(retryMs), nil
435}
436
437// handleContextLimitError parses context limit error and returns adjusted max_tokens
438func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
439	// Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
440	errorMsg := apiErr.Error()
441	re := regexp.MustCompile(`input length and max_tokens exceed context limit: (\d+) \+ (\d+) > (\d+)`)
442	matches := re.FindStringSubmatch(errorMsg)
443
444	if len(matches) != 4 {
445		return 0, false
446	}
447
448	inputTokens, err1 := strconv.Atoi(matches[1])
449	contextLimit, err2 := strconv.Atoi(matches[3])
450
451	if err1 != nil || err2 != nil {
452		return 0, false
453	}
454
455	// Calculate safe max_tokens with a buffer of 1000 tokens
456	safeMaxTokens := contextLimit - inputTokens - 1000
457
458	// Ensure we don't go below a minimum threshold
459	safeMaxTokens = max(safeMaxTokens, 1000)
460
461	return safeMaxTokens, true
462}
463
464func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
465	var toolCalls []message.ToolCall
466
467	for _, block := range msg.Content {
468		switch variant := block.AsAny().(type) {
469		case anthropic.ToolUseBlock:
470			toolCall := message.ToolCall{
471				ID:       variant.ID,
472				Name:     variant.Name,
473				Input:    string(variant.Input),
474				Type:     string(variant.Type),
475				Finished: true,
476			}
477			toolCalls = append(toolCalls, toolCall)
478		}
479	}
480
481	return toolCalls
482}
483
484func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
485	return TokenUsage{
486		InputTokens:         msg.Usage.InputTokens,
487		OutputTokens:        msg.Usage.OutputTokens,
488		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
489		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
490	}
491}
492
// Model returns the model currently selected for this client (large or
// small, per providerOptions.modelType).
func (a *anthropicClient) Model() config.Model {
	return a.providerOptions.model(a.providerOptions.modelType)
}