1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "regexp"
11 "strconv"
12 "strings"
13 "time"
14
15 "github.com/anthropics/anthropic-sdk-go"
16 "github.com/anthropics/anthropic-sdk-go/bedrock"
17 "github.com/anthropics/anthropic-sdk-go/option"
18 "github.com/charmbracelet/catwalk/pkg/catwalk"
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/message"
22)
23
var (
	// Pre-compiled regex for parsing context limit errors.
	// Matches messages like:
	//   "input length and `max_tokens` exceed context limit: 154978 + 50000 > 200000"
	// capturing (1) input tokens, (2) requested max_tokens, (3) the model's context limit.
	contextLimitRegex = regexp.MustCompile(`input length and ` + "`max_tokens`" + ` exceed context limit: (\d+) \+ (\d+) > (\d+)`)
)
28
// anthropicClient is a ProviderClient implementation backed by the Anthropic
// SDK, optionally routed through AWS Bedrock.
type anthropicClient struct {
	providerOptions   providerClientOptions
	useBedrock        bool // route requests through AWS Bedrock instead of the Anthropic API
	client            anthropic.Client
	adjustedMaxTokens int // Used when context limit is hit
}
35
// AnthropicClient is the exported alias for the Anthropic-backed ProviderClient.
type AnthropicClient ProviderClient
37
38func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
39 return &anthropicClient{
40 providerOptions: opts,
41 client: createAnthropicClient(opts, useBedrock),
42 }
43}
44
45func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
46 anthropicClientOptions := []option.RequestOption{}
47
48 // Check if Authorization header is provided in extra headers
49 hasBearerAuth := false
50 if opts.extraHeaders != nil {
51 for key := range opts.extraHeaders {
52 if strings.ToLower(key) == "authorization" {
53 hasBearerAuth = true
54 break
55 }
56 }
57 }
58
59 isBearerToken := strings.HasPrefix(opts.apiKey, "Bearer ")
60
61 if opts.apiKey != "" && !hasBearerAuth {
62 if isBearerToken {
63 slog.Debug("API key starts with 'Bearer ', using as Authorization header")
64 anthropicClientOptions = append(anthropicClientOptions, option.WithHeader("Authorization", opts.apiKey))
65 } else {
66 // Use standard X-Api-Key header
67 anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
68 }
69 } else if hasBearerAuth {
70 slog.Debug("Skipping X-Api-Key header because Authorization header is provided")
71 }
72 if useBedrock {
73 anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
74 }
75 for _, header := range opts.extraHeaders {
76 anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(header, opts.extraHeaders[header]))
77 }
78 for key, value := range opts.extraBody {
79 anthropicClientOptions = append(anthropicClientOptions, option.WithJSONSet(key, value))
80 }
81 return anthropic.NewClient(anthropicClientOptions...)
82}
83
// convertMessages translates the internal message history into Anthropic
// MessageParam values. The last two messages are marked with ephemeral
// cache-control breakpoints (unless caching is disabled) so the stable
// prefix of the conversation can be served from the prompt cache.
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
	for i, msg := range messages {
		// Only the final two messages of the conversation get cache breakpoints.
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cache && !a.providerOptions.disableCache {
				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
			}
			var contentBlocks []anthropic.ContentBlockParamUnion
			contentBlocks = append(contentBlocks, content)
			// Attach any binary attachments (e.g. images) as base64 image blocks.
			for _, binaryContent := range msg.BinaryContent() {
				base64Image := binaryContent.String(catwalk.InferenceProviderAnthropic)
				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
				contentBlocks = append(contentBlocks, imageBlock)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}

			// Add thinking blocks first if present (required when thinking is enabled with tool use)
			if reasoningContent := msg.ReasoningContent(); reasoningContent.Thinking != "" {
				thinkingBlock := anthropic.NewThinkingBlock(reasoningContent.Signature, reasoningContent.Thinking)
				blocks = append(blocks, thinkingBlock)
			}

			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cache && !a.providerOptions.disableCache {
					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					// Skip tool calls whose recorded input is not valid JSON.
					continue
				}
				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			// Anthropic rejects messages with no content blocks; drop them.
			if len(blocks) == 0 {
				slog.Warn("There is a message without content, investigate, this should not happen")
				continue
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

		case message.Tool:
			// Tool results are sent back to the API as a user message made of
			// tool_result blocks keyed by the originating tool-call ID.
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}
	return
}
151
152func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
153 anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
154
155 for i, tool := range tools {
156 info := tool.Info()
157 toolParam := anthropic.ToolParam{
158 Name: info.Name,
159 Description: anthropic.String(info.Description),
160 InputSchema: anthropic.ToolInputSchemaParam{
161 Properties: info.Parameters,
162 // TODO: figure out how we can tell claude the required fields?
163 },
164 }
165
166 if i == len(tools)-1 && !a.providerOptions.disableCache {
167 toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
168 Type: "ephemeral",
169 }
170 }
171
172 anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
173 }
174
175 return anthropicTools
176}
177
178func (a *anthropicClient) finishReason(reason string) message.FinishReason {
179 switch reason {
180 case "end_turn":
181 return message.FinishReasonEndTurn
182 case "max_tokens":
183 return message.FinishReasonMaxTokens
184 case "tool_use":
185 return message.FinishReasonToolUse
186 case "stop_sequence":
187 return message.FinishReasonEndTurn
188 default:
189 return message.FinishReasonUnknown
190 }
191}
192
193func (a *anthropicClient) isThinkingEnabled() bool {
194 cfg := config.Get()
195 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
196 if a.providerOptions.modelType == config.SelectedModelTypeSmall {
197 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
198 }
199 return a.Model().CanReason && modelConfig.Think
200}
201
// preparedMessages assembles the full MessageNewParams for a request: model,
// max tokens (with several override layers), optional extended-thinking
// budget, system prompt blocks (with cache breakpoints), messages, and tools.
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
	model := a.providerOptions.model(a.providerOptions.modelType)
	var thinkingParam anthropic.ThinkingConfigParamUnion
	cfg := config.Get()
	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if a.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}
	temperature := anthropic.Float(0)

	// Max-token precedence (lowest to highest): model default, user model
	// config, provider options, context-limit adjustment.
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}
	if a.isThinkingEnabled() {
		// Thinking budget is 80% of max tokens; thinking requires temperature 1.
		// NOTE(review): the budget is computed BEFORE the provider-level
		// maxTokens override below — confirm this ordering is intentional.
		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(maxTokens) * 0.8))
		temperature = anthropic.Float(1)
	}
	// Override max tokens if set in provider options
	if a.providerOptions.maxTokens > 0 {
		maxTokens = a.providerOptions.maxTokens
	}

	// Use adjusted max tokens if context limit was hit
	if a.adjustedMaxTokens > 0 {
		maxTokens = int64(a.adjustedMaxTokens)
	}

	systemBlocks := []anthropic.TextBlockParam{}

	// Add custom system prompt prefix if configured
	if a.providerOptions.systemPromptPrefix != "" {
		systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
			Text: a.providerOptions.systemPromptPrefix,
			CacheControl: anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			},
		})
	}

	// The system message itself is always marked as a cache breakpoint.
	systemBlocks = append(systemBlocks, anthropic.TextBlockParam{
		Text: a.providerOptions.systemMessage,
		CacheControl: anthropic.CacheControlEphemeralParam{
			Type: "ephemeral",
		},
	})

	return anthropic.MessageNewParams{
		Model:       anthropic.Model(model.ID),
		MaxTokens:   maxTokens,
		Temperature: temperature,
		Messages:    messages,
		Tools:       tools,
		Thinking:    thinkingParam,
		System:      systemBlocks,
	}
}
259
260func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
261 cfg := config.Get()
262
263 attempts := 0
264 for {
265 attempts++
266 // Prepare messages on each attempt in case max_tokens was adjusted
267 preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
268 if cfg.Options.Debug {
269 jsonData, _ := json.Marshal(preparedMessages)
270 slog.Debug("Prepared messages", "messages", string(jsonData))
271 }
272
273 var opts []option.RequestOption
274 if a.isThinkingEnabled() {
275 opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
276 }
277 anthropicResponse, err := a.client.Messages.New(
278 ctx,
279 preparedMessages,
280 opts...,
281 )
282 // If there is an error we are going to see if we can retry the call
283 if err != nil {
284 slog.Error("Error in Anthropic API call", "error", err)
285 retry, after, retryErr := a.shouldRetry(attempts, err)
286 if retryErr != nil {
287 return nil, retryErr
288 }
289 if retry {
290 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
291 select {
292 case <-ctx.Done():
293 return nil, ctx.Err()
294 case <-time.After(time.Duration(after) * time.Millisecond):
295 continue
296 }
297 }
298 return nil, retryErr
299 }
300
301 content := ""
302 for _, block := range anthropicResponse.Content {
303 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
304 content += text.Text
305 }
306 }
307
308 return &ProviderResponse{
309 Content: content,
310 ToolCalls: a.toolCalls(*anthropicResponse),
311 Usage: a.usage(*anthropicResponse),
312 }, nil
313 }
314}
315
// stream performs a streaming message request and translates SSE events into
// ProviderEvent values on the returned channel. The worker goroutine retries
// retryable errors (same policy as send) and closes the channel when the
// stream ends cleanly, a terminal error occurs, or the context is cancelled.
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	cfg := config.Get()
	attempts := 0
	eventChan := make(chan ProviderEvent)
	go func() {
		for {
			attempts++
			// Prepare messages on each attempt in case max_tokens was adjusted
			preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
			if cfg.Options.Debug {
				jsonData, _ := json.Marshal(preparedMessages)
				slog.Debug("Prepared messages", "messages", string(jsonData))
			}

			var opts []option.RequestOption
			if a.isThinkingEnabled() {
				// Interleaved thinking requires a beta header.
				opts = append(opts, option.WithHeaderAdd("anthropic-beta", "interleaved-thinking-2025-05-14"))
			}

			anthropicStream := a.client.Messages.NewStreaming(
				ctx,
				preparedMessages,
				opts...,
			)
			// Accumulate every event so the complete message (content, tool
			// calls, usage, stop reason) is available at message_stop.
			accumulatedMessage := anthropic.Message{}

			// Tracks the in-flight tool call so input_json_delta events can
			// be attributed to it.
			currentToolCallID := ""
			for anthropicStream.Next() {
				event := anthropicStream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					slog.Warn("Error accumulating message", "error", err)
					continue
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					switch event.ContentBlock.Type {
					case "text":
						eventChan <- ProviderEvent{Type: EventContentStart}
					case "tool_use":
						currentToolCallID = event.ContentBlock.ID
						eventChan <- ProviderEvent{
							Type: EventToolUseStart,
							ToolCall: &message.ToolCall{
								ID:       event.ContentBlock.ID,
								Name:     event.ContentBlock.Name,
								Finished: false,
							},
						}
					}

				case anthropic.ContentBlockDeltaEvent:
					// Dispatch on the delta type: thinking, signature, text,
					// or partial tool-call input JSON.
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "signature_delta" && event.Delta.Signature != "" {
						eventChan <- ProviderEvent{
							Type:      EventSignatureDelta,
							Signature: event.Delta.Signature,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					} else if event.Delta.Type == "input_json_delta" {
						if currentToolCallID != "" {
							eventChan <- ProviderEvent{
								Type: EventToolUseDelta,
								ToolCall: &message.ToolCall{
									ID:       currentToolCallID,
									Finished: false,
									Input:    event.Delta.PartialJSON,
								},
							}
						}
					}
				case anthropic.ContentBlockStopEvent:
					// A stop event closes either the in-flight tool call or a
					// text block, depending on what is currently open.
					if currentToolCallID != "" {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID: currentToolCallID,
							},
						}
						currentToolCallID = ""
					} else {
						eventChan <- ProviderEvent{Type: EventContentStop}
					}

				case anthropic.MessageStopEvent:
					// Flatten accumulated text blocks into the final content
					// and emit the completion event.
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    a.toolCalls(accumulatedMessage),
							Usage:        a.usage(accumulatedMessage),
							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
						},
						Content: content,
					}
				}
			}

			err := anthropicStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				// Clean end of stream.
				close(eventChan)
				return
			}

			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			if ctx.Err() != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
			}

			close(eventChan)
			return
		}
	}()
	return eventChan
}
467
// shouldRetry inspects an API error and decides whether the request should be
// retried. It returns (retry, delayMillis, terminalErr); a non-nil
// terminalErr means the failure should be surfaced to the caller.
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *anthropic.Error
	if !errors.As(err, &apiErr) {
		// Not an Anthropic API error (e.g. transport failure): not retryable here.
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// 401: the API key may have rotated/expired — re-resolve it, rebuild the
	// client, and retry immediately.
	if apiErr.StatusCode == 401 {
		a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
		return true, 0, nil
	}

	// Handle context limit exceeded error (400 Bad Request)
	if apiErr.StatusCode == 400 {
		if adjusted, ok := a.handleContextLimitError(apiErr); ok {
			a.adjustedMaxTokens = adjusted
			slog.Debug("Adjusted max_tokens due to context limit", "new_max_tokens", adjusted)
			return true, 0, nil
		}
	}

	// Retry only for rate limiting (429), overload (529), or error messages
	// that indicate the service is overloaded/rate limited.
	isOverloaded := strings.Contains(apiErr.Error(), "overloaded") || strings.Contains(apiErr.Error(), "rate limit exceeded")
	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 && !isOverloaded {
		return false, 0, err
	}

	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	// Exponential backoff (2s, 4s, 8s, ...) plus a fixed 20% margin.
	// NOTE(review): the "jitter" is deterministic, not random — confirm
	// whether randomized jitter was intended.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	// A Retry-After header (interpreted as seconds) takes precedence.
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}
514
515// handleContextLimitError parses context limit error and returns adjusted max_tokens
516func (a *anthropicClient) handleContextLimitError(apiErr *anthropic.Error) (int, bool) {
517 // Parse error message like: "input length and max_tokens exceed context limit: 154978 + 50000 > 200000"
518 errorMsg := apiErr.Error()
519
520 matches := contextLimitRegex.FindStringSubmatch(errorMsg)
521
522 if len(matches) != 4 {
523 return 0, false
524 }
525
526 inputTokens, err1 := strconv.Atoi(matches[1])
527 contextLimit, err2 := strconv.Atoi(matches[3])
528
529 if err1 != nil || err2 != nil {
530 return 0, false
531 }
532
533 // Calculate safe max_tokens with a buffer of 1000 tokens
534 safeMaxTokens := contextLimit - inputTokens - 1000
535
536 // Ensure we don't go below a minimum threshold
537 safeMaxTokens = max(safeMaxTokens, 1000)
538
539 return safeMaxTokens, true
540}
541
542func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
543 var toolCalls []message.ToolCall
544
545 for _, block := range msg.Content {
546 switch variant := block.AsAny().(type) {
547 case anthropic.ToolUseBlock:
548 toolCall := message.ToolCall{
549 ID: variant.ID,
550 Name: variant.Name,
551 Input: string(variant.Input),
552 Type: string(variant.Type),
553 Finished: true,
554 }
555 toolCalls = append(toolCalls, toolCall)
556 }
557 }
558
559 return toolCalls
560}
561
562func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
563 return TokenUsage{
564 InputTokens: msg.Usage.InputTokens,
565 OutputTokens: msg.Usage.OutputTokens,
566 CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
567 CacheReadTokens: msg.Usage.CacheReadInputTokens,
568 }
569}
570
571func (a *anthropicClient) Model() catwalk.Model {
572 return a.providerOptions.model(a.providerOptions.modelType)
573}