package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
)

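// anthropicClient talks to the Anthropic Messages API (directly or via AWS
// Bedrock) and adapts requests and responses to the provider-agnostic types
// used by the rest of this package.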
type anthropicClient struct {
	providerOptions providerClientOptions
	useBedrock      bool
	client          anthropic.Client
}

type AnthropicClient ProviderClient

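// newAnthropicClient builds an AnthropicClient from the shared provider
// options, optionally routing requests through AWS Bedrock.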
func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
	return &anthropicClient{
		providerOptions: opts,
		useBedrock:      useBedrock,
		client:          createAnthropicClient(opts, useBedrock),
	}
}

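// createAnthropicClient assembles the underlying SDK client, wiring in the
// API key when one is set and the default AWS config when Bedrock is enabled.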
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
	anthropicClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if useBedrock {
		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
	}
	return anthropic.NewClient(anthropicClientOptions...)
}

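// convertMessages translates the internal message history into Anthropic
// message params. The text blocks of the last two messages are marked with
// ephemeral cache control (unless caching is disabled) so Anthropic's prompt
// caching can reuse the stable conversation prefix.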
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
	for i, msg := range messages {
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cache && !a.providerOptions.disableCache {
				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
			}
			var contentBlocks []anthropic.ContentBlockParamUnion
			contentBlocks = append(contentBlocks, content)
			for _, binaryContent := range msg.BinaryContent() {
				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
				contentBlocks = append(contentBlocks, imageBlock)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cache && !a.providerOptions.disableCache {
					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			if len(blocks) == 0 {
				logging.Warn("Assistant message has no content blocks; this should not happen")
				continue
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}
	return
}

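// convertTools maps the agent's tool definitions onto Anthropic tool params,
// attaching ephemeral cache control to the last tool so the static tool
// definitions can be cached along with the system prompt.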
func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
				// TODO: figure out how we can tell claude the required fields?
			},
		}

		if i == len(tools)-1 && !a.providerOptions.disableCache {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

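// finishReason normalizes an Anthropic stop reason to the internal
// FinishReason values; "stop_sequence" is treated the same as "end_turn".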
func (a *anthropicClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "end_turn":
		return message.FinishReasonEndTurn
	case "max_tokens":
		return message.FinishReasonMaxTokens
	case "tool_use":
		return message.FinishReasonToolUse
	case "stop_sequence":
		return message.FinishReasonEndTurn
	default:
		return message.FinishReasonUnknown
	}
}

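// preparedMessages assembles the final MessageNewParams: it resolves the model
// and max-token budget (model default, then model config, then provider
// options), enables extended thinking with temperature 1 and an 80%
// thinking-token budget when the model can reason and thinking is configured,
// and cache-controls the system prompt.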
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
	model := a.providerOptions.model(a.providerOptions.modelType)
	var thinkingParam anthropic.ThinkingConfigParamUnion
	cfg := config.Get()
	modelConfig := cfg.Models.Large
	if a.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}
	temperature := anthropic.Float(0)

	if a.Model().CanReason && modelConfig.Think {
		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
		temperature = anthropic.Float(1)
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if a.providerOptions.maxTokens > 0 {
		maxTokens = a.providerOptions.maxTokens
	}

	return anthropic.MessageNewParams{
		Model:       anthropic.Model(model.ID),
		MaxTokens:   maxTokens,
		Temperature: temperature,
		Messages:    messages,
		Tools:       tools,
		Thinking:    thinkingParam,
		System: []anthropic.TextBlockParam{
			{
				Text: a.providerOptions.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	}
}

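// send performs a blocking (non-streaming) completion request, retrying with
// backoff on rate-limit and overload errors via shouldRetry.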
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(preparedMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	for {
		attempts++
		anthropicResponse, err := a.client.Messages.New(
			ctx,
			preparedMessages,
		)
		// If the request failed, check whether the call can be retried.
		if err != nil {
			logging.Error("Error in Anthropic API call", "error", err)
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, retryErr
		}

		content := ""
		for _, block := range anthropicResponse.Content {
			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
				content += text.Text
			}
		}

		return &ProviderResponse{
			Content:   content,
			ToolCalls: a.toolCalls(*anthropicResponse),
			Usage:     a.usage(*anthropicResponse),
		}, nil
	}
}

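// stream issues a streaming completion request and forwards SDK events on the
// returned channel as ProviderEvents (content deltas, thinking deltas,
// tool-use start/delta/stop, and a final EventComplete). Rate-limited or
// overloaded requests are retried with the same backoff policy as send.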
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(preparedMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	eventChan := make(chan ProviderEvent)
	go func() {
		for {
			attempts++
			anthropicStream := a.client.Messages.NewStreaming(
				ctx,
				preparedMessages,
			)
			accumulatedMessage := anthropic.Message{}

			currentToolCallID := ""
			for anthropicStream.Next() {
				event := anthropicStream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					logging.Warn("Error accumulating message", "error", err)
					continue
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					switch event.ContentBlock.Type {
					case "text":
						eventChan <- ProviderEvent{Type: EventContentStart}
					case "tool_use":
						currentToolCallID = event.ContentBlock.ID
						eventChan <- ProviderEvent{
							Type: EventToolUseStart,
							ToolCall: &message.ToolCall{
								ID:       event.ContentBlock.ID,
								Name:     event.ContentBlock.Name,
								Finished: false,
							},
						}
					}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					} else if event.Delta.Type == "input_json_delta" {
						if currentToolCallID != "" {
							eventChan <- ProviderEvent{
								Type: EventToolUseDelta,
								ToolCall: &message.ToolCall{
									ID:       currentToolCallID,
									Finished: false,
									Input:    event.Delta.PartialJSON,
								},
							}
						}
					}
				case anthropic.ContentBlockStopEvent:
					if currentToolCallID != "" {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID: currentToolCallID,
							},
						}
						currentToolCallID = ""
					} else {
						eventChan <- ProviderEvent{Type: EventContentStop}
					}

				case anthropic.MessageStopEvent:
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    a.toolCalls(accumulatedMessage),
							Usage:        a.usage(accumulatedMessage),
							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
						},
						Content: content,
					}
				}
			}

			err := anthropicStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				close(eventChan)
				return
			}
			// If the stream failed, check whether the call can be retried.
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			if ctx.Err() != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
			}

			close(eventChan)
			return
		}
	}()
	return eventChan
}

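// shouldRetry reports whether a failed request should be retried and, if so,
// how many milliseconds to wait first. A 401 re-resolves the API key, rebuilds
// the client, and retries immediately; 429 and 529 are retried up to
// maxRetries, waiting for the Retry-After header when present or otherwise an
// exponential backoff starting at 2s plus a 20% margin. Everything else is
// returned as a terminal error.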
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *anthropic.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if apiErr.StatusCode == 401 {
		a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

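// toolCalls extracts the completed tool-use blocks from an Anthropic message
// as internal ToolCall values.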
func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range msg.Content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:       variant.ID,
				Name:     variant.Name,
				Input:    string(variant.Input),
				Type:     string(variant.Type),
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

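// usage converts Anthropic token accounting, including prompt-cache creation
// and read counts, into the provider-agnostic TokenUsage struct.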
func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
	return TokenUsage{
		InputTokens:         msg.Usage.InputTokens,
		OutputTokens:        msg.Usage.OutputTokens,
		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
	}
}

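// Model returns the model configured for this client's model type (large or small).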
func (a *anthropicClient) Model() config.Model {
	return a.providerOptions.model(a.providerOptions.modelType)
}