package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
)

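// anthropicClient talks to the Anthropic Messages API (directly or via AWS
// Bedrock) and adapts requests and responses to the provider-agnostic types
// used by the rest of this package.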
type anthropicClient struct {
	providerOptions providerClientOptions
	useBedrock      bool
	client          anthropic.Client
}

type AnthropicClient ProviderClient

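// newAnthropicClient builds an AnthropicClient from the shared provider
// options, optionally routing requests through AWS Bedrock.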
func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
	return &anthropicClient{
		providerOptions: opts,
		useBedrock:      useBedrock,
		client:          createAnthropicClient(opts, useBedrock),
	}
}

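// createAnthropicClient assembles the underlying SDK client, wiring in the
// API key when one is set and the default AWS config when Bedrock is enabled.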
func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client {
	anthropicClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if useBedrock {
		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
	}
	return anthropic.NewClient(anthropicClientOptions...)
}

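// convertMessages translates the internal message history into Anthropic
// message params. The text blocks of the last two messages are marked with
// ephemeral cache control (unless caching is disabled) so Anthropic's prompt
// caching can reuse the stable conversation prefix.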
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
	for i, msg := range messages {
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cache && !a.providerOptions.disableCache {
				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
			}
			var contentBlocks []anthropic.ContentBlockParamUnion
			contentBlocks = append(contentBlocks, content)
			for _, binaryContent := range msg.BinaryContent() {
				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
				contentBlocks = append(contentBlocks, imageBlock)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cache && !a.providerOptions.disableCache {
					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			if len(blocks) == 0 {
				logging.Warn("Assistant message has no content blocks; this should not happen")
				continue
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}
	return
}

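// convertTools maps the agent's tool definitions onto Anthropic tool params,
// attaching ephemeral cache control to the last tool so the static tool
// definitions can be cached along with the system prompt.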
func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
				// TODO: figure out how we can tell claude the required fields?
			},
		}

		if i == len(tools)-1 && !a.providerOptions.disableCache {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

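// finishReason normalizes an Anthropic stop reason to the internal
// FinishReason values; "stop_sequence" is treated the same as "end_turn".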
func (a *anthropicClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "end_turn":
		return message.FinishReasonEndTurn
	case "max_tokens":
		return message.FinishReasonMaxTokens
	case "tool_use":
		return message.FinishReasonToolUse
	case "stop_sequence":
		return message.FinishReasonEndTurn
	default:
		return message.FinishReasonUnknown
	}
}

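// preparedMessages assembles the final MessageNewParams: it resolves the model
// and max-token budget (model default, then model config, then provider
// options), enables extended thinking with temperature 1 and an 80%
// thinking-token budget when the model can reason and thinking is configured,
// and cache-controls the system prompt.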
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
	model := a.providerOptions.model(a.providerOptions.modelType)
	var thinkingParam anthropic.ThinkingConfigParamUnion
	cfg := config.Get()
	modelConfig := cfg.Models.Large
	if a.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}
	temperature := anthropic.Float(0)

	if a.Model().CanReason && modelConfig.Think {
		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
		temperature = anthropic.Float(1)
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if a.providerOptions.maxTokens > 0 {
		maxTokens = a.providerOptions.maxTokens
	}

	return anthropic.MessageNewParams{
		Model:       anthropic.Model(model.ID),
		MaxTokens:   maxTokens,
		Temperature: temperature,
		Messages:    messages,
		Tools:       tools,
		Thinking:    thinkingParam,
		System: []anthropic.TextBlockParam{
			{
				Text: a.providerOptions.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	}
}

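// send performs a blocking (non-streaming) completion request, retrying with
// backoff on rate-limit and overload errors via shouldRetry.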
func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(preparedMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	for {
		attempts++
		anthropicResponse, err := a.client.Messages.New(
			ctx,
			preparedMessages,
		)
		// If the request failed, check whether the call can be retried.
		if err != nil {
			logging.Error("Error in Anthropic API call", "error", err)
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, retryErr
		}

		content := ""
		for _, block := range anthropicResponse.Content {
			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
				content += text.Text
			}
		}

		return &ProviderResponse{
			Content:   content,
			ToolCalls: a.toolCalls(*anthropicResponse),
			Usage:     a.usage(*anthropicResponse),
		}, nil
	}
}

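// stream issues a streaming completion request and forwards SDK events on the
// returned channel as ProviderEvents (content deltas, thinking deltas,
// tool-use start/delta/stop, and a final EventComplete). Rate-limited or
// overloaded requests are retried with the same backoff policy as send.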
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(preparedMessages)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	eventChan := make(chan ProviderEvent)
	go func() {
		for {
			attempts++
			anthropicStream := a.client.Messages.NewStreaming(
				ctx,
				preparedMessages,
			)
			accumulatedMessage := anthropic.Message{}

			currentToolCallID := ""
			for anthropicStream.Next() {
				event := anthropicStream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					logging.Warn("Error accumulating message", "error", err)
					continue
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					switch event.ContentBlock.Type {
					case "text":
						eventChan <- ProviderEvent{Type: EventContentStart}
					case "tool_use":
						currentToolCallID = event.ContentBlock.ID
						eventChan <- ProviderEvent{
							Type: EventToolUseStart,
							ToolCall: &message.ToolCall{
								ID:       event.ContentBlock.ID,
								Name:     event.ContentBlock.Name,
								Finished: false,
							},
						}
					}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					} else if event.Delta.Type == "input_json_delta" {
						if currentToolCallID != "" {
							eventChan <- ProviderEvent{
								Type: EventToolUseDelta,
								ToolCall: &message.ToolCall{
									ID:       currentToolCallID,
									Finished: false,
									Input:    event.Delta.PartialJSON,
								},
							}
						}
					}
				case anthropic.ContentBlockStopEvent:
					if currentToolCallID != "" {
						eventChan <- ProviderEvent{
							Type: EventToolUseStop,
							ToolCall: &message.ToolCall{
								ID: currentToolCallID,
							},
						}
						currentToolCallID = ""
					} else {
						eventChan <- ProviderEvent{Type: EventContentStop}
					}

				case anthropic.MessageStopEvent:
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    a.toolCalls(accumulatedMessage),
							Usage:        a.usage(accumulatedMessage),
							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
						},
						Content: content,
					}
				}
			}

			err := anthropicStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				close(eventChan)
				return
			}
			// If the stream failed, check whether the call can be retried.
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			if ctx.Err() != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
			}

			close(eventChan)
			return
		}
	}()
	return eventChan
}

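// shouldRetry reports whether a failed request should be retried and, if so,
// how many milliseconds to wait first. A 401 re-resolves the API key, rebuilds
// the client, and retries immediately; 429 and 529 are retried up to
// maxRetries, waiting for the Retry-After header when present or otherwise an
// exponential backoff starting at 2s plus a 20% margin. Everything else is
// returned as a terminal error.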
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *anthropic.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if apiErr.StatusCode == 401 {
		a.providerOptions.apiKey, err = config.ResolveAPIKey(a.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		a.client = createAnthropicClient(a.providerOptions, a.useBedrock)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 529 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

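// toolCalls extracts the completed tool-use blocks from an Anthropic message
// as internal ToolCall values.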
func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range msg.Content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:       variant.ID,
				Name:     variant.Name,
				Input:    string(variant.Input),
				Type:     string(variant.Type),
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

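// usage converts Anthropic token accounting, including prompt-cache creation
// and read counts, into the provider-agnostic TokenUsage struct.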
func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
	return TokenUsage{
		InputTokens:         msg.Usage.InputTokens,
		OutputTokens:        msg.Usage.OutputTokens,
		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
	}
}

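// Model returns the model configured for this client's model type (large or small).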
func (a *anthropicClient) Model() config.Model {
	return a.providerOptions.model(a.providerOptions.modelType)
}