anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/anthropics/anthropic-sdk-go"
 13	"github.com/anthropics/anthropic-sdk-go/bedrock"
 14	"github.com/anthropics/anthropic-sdk-go/option"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/fur/provider"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/logging"
 19	"github.com/charmbracelet/crush/internal/message"
 20)
 21
// anthropicClient is the provider client implementation backed by the
// Anthropic Go SDK.
//
// providerOptions holds the settings the methods in this file read when
// building requests (API key, model, max tokens, system message and the
// disableCache switch). client is the SDK client through which Messages
// requests are issued.
type anthropicClient struct {
	providerOptions providerClientOptions
	client          anthropic.Client
}
 26
// AnthropicClient is the type returned by newAnthropicClient. It is declared
// with ProviderClient as its underlying type; ProviderClient itself is not
// defined in this section of the package.
type AnthropicClient ProviderClient
 28
 29func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 30	anthropicClientOptions := []option.RequestOption{}
 31	if opts.apiKey != "" {
 32		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 33	}
 34	if useBedrock {
 35		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 36	}
 37
 38	client := anthropic.NewClient(anthropicClientOptions...)
 39	return &anthropicClient{
 40		providerOptions: opts,
 41		client:          client,
 42	}
 43}
 44
 45func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
 46	for i, msg := range messages {
 47		cache := false
 48		if i > len(messages)-3 {
 49			cache = true
 50		}
 51		switch msg.Role {
 52		case message.User:
 53			content := anthropic.NewTextBlock(msg.Content().String())
 54			if cache && !a.providerOptions.disableCache {
 55				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 56					Type: "ephemeral",
 57				}
 58			}
 59			var contentBlocks []anthropic.ContentBlockParamUnion
 60			contentBlocks = append(contentBlocks, content)
 61			for _, binaryContent := range msg.BinaryContent() {
 62				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
 63				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
 64				contentBlocks = append(contentBlocks, imageBlock)
 65			}
 66			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
 67
 68		case message.Assistant:
 69			blocks := []anthropic.ContentBlockParamUnion{}
 70			if msg.Content().String() != "" {
 71				content := anthropic.NewTextBlock(msg.Content().String())
 72				if cache && !a.providerOptions.disableCache {
 73					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 74						Type: "ephemeral",
 75					}
 76				}
 77				blocks = append(blocks, content)
 78			}
 79
 80			for _, toolCall := range msg.ToolCalls() {
 81				var inputMap map[string]any
 82				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
 83				if err != nil {
 84					continue
 85				}
 86				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
 87			}
 88
 89			if len(blocks) == 0 {
 90				logging.Warn("There is a message without content, investigate, this should not happen")
 91				continue
 92			}
 93			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
 94
 95		case message.Tool:
 96			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
 97			for i, toolResult := range msg.ToolResults() {
 98				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
 99			}
100			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
101		}
102	}
103	return
104}
105
106func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
107	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
108
109	for i, tool := range tools {
110		info := tool.Info()
111		toolParam := anthropic.ToolParam{
112			Name:        info.Name,
113			Description: anthropic.String(info.Description),
114			InputSchema: anthropic.ToolInputSchemaParam{
115				Properties: info.Parameters,
116				// TODO: figure out how we can tell claude the required fields?
117			},
118		}
119
120		if i == len(tools)-1 && !a.providerOptions.disableCache {
121			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
122				Type: "ephemeral",
123			}
124		}
125
126		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
127	}
128
129	return anthropicTools
130}
131
132func (a *anthropicClient) finishReason(reason string) message.FinishReason {
133	switch reason {
134	case "end_turn":
135		return message.FinishReasonEndTurn
136	case "max_tokens":
137		return message.FinishReasonMaxTokens
138	case "tool_use":
139		return message.FinishReasonToolUse
140	case "stop_sequence":
141		return message.FinishReasonEndTurn
142	default:
143		return message.FinishReasonUnknown
144	}
145}
146
147func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
148	model := a.providerOptions.model(a.providerOptions.modelType)
149	var thinkingParam anthropic.ThinkingConfigParamUnion
150	// TODO: Implement a proper thinking function
151	// lastMessage := messages[len(messages)-1]
152	// isUser := lastMessage.Role == anthropic.MessageParamRoleUser
153	// messageContent := ""
154	temperature := anthropic.Float(0)
155	// if isUser {
156	// 	for _, m := range lastMessage.Content {
157	// 		if m.OfText != nil && m.OfText.Text != "" {
158	// 			messageContent = m.OfText.Text
159	// 		}
160	// 	}
161	// 	if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) {
162	// 		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
163	// 		temperature = anthropic.Float(1)
164	// 	}
165	// }
166
167	return anthropic.MessageNewParams{
168		Model:       anthropic.Model(model.ID),
169		MaxTokens:   a.providerOptions.maxTokens,
170		Temperature: temperature,
171		Messages:    messages,
172		Tools:       tools,
173		Thinking:    thinkingParam,
174		System: []anthropic.TextBlockParam{
175			{
176				Text: a.providerOptions.systemMessage,
177				CacheControl: anthropic.CacheControlEphemeralParam{
178					Type: "ephemeral",
179				},
180			},
181		},
182	}
183}
184
185func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
186	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
187	cfg := config.Get()
188	if cfg.Options.Debug {
189		jsonData, _ := json.Marshal(preparedMessages)
190		logging.Debug("Prepared messages", "messages", string(jsonData))
191	}
192
193	attempts := 0
194	for {
195		attempts++
196		anthropicResponse, err := a.client.Messages.New(
197			ctx,
198			preparedMessages,
199		)
200		// If there is an error we are going to see if we can retry the call
201		if err != nil {
202			logging.Error("Error in Anthropic API call", "error", err)
203			retry, after, retryErr := a.shouldRetry(attempts, err)
204			if retryErr != nil {
205				return nil, retryErr
206			}
207			if retry {
208				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
209				select {
210				case <-ctx.Done():
211					return nil, ctx.Err()
212				case <-time.After(time.Duration(after) * time.Millisecond):
213					continue
214				}
215			}
216			return nil, retryErr
217		}
218
219		content := ""
220		for _, block := range anthropicResponse.Content {
221			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
222				content += text.Text
223			}
224		}
225
226		return &ProviderResponse{
227			Content:   content,
228			ToolCalls: a.toolCalls(*anthropicResponse),
229			Usage:     a.usage(*anthropicResponse),
230		}, nil
231	}
232}
233
234func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
235	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
236	cfg := config.Get()
237	if cfg.Options.Debug {
238		// jsonData, _ := json.Marshal(preparedMessages)
239		// logging.Debug("Prepared messages", "messages", string(jsonData))
240	}
241	attempts := 0
242	eventChan := make(chan ProviderEvent)
243	go func() {
244		for {
245			attempts++
246			anthropicStream := a.client.Messages.NewStreaming(
247				ctx,
248				preparedMessages,
249			)
250			accumulatedMessage := anthropic.Message{}
251
252			currentToolCallID := ""
253			for anthropicStream.Next() {
254				event := anthropicStream.Current()
255				err := accumulatedMessage.Accumulate(event)
256				if err != nil {
257					logging.Warn("Error accumulating message", "error", err)
258					continue
259				}
260
261				switch event := event.AsAny().(type) {
262				case anthropic.ContentBlockStartEvent:
263					switch event.ContentBlock.Type {
264					case "text":
265						eventChan <- ProviderEvent{Type: EventContentStart}
266					case "tool_use":
267						currentToolCallID = event.ContentBlock.ID
268						eventChan <- ProviderEvent{
269							Type: EventToolUseStart,
270							ToolCall: &message.ToolCall{
271								ID:       event.ContentBlock.ID,
272								Name:     event.ContentBlock.Name,
273								Finished: false,
274							},
275						}
276					}
277
278				case anthropic.ContentBlockDeltaEvent:
279					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
280						eventChan <- ProviderEvent{
281							Type:     EventThinkingDelta,
282							Thinking: event.Delta.Thinking,
283						}
284					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
285						eventChan <- ProviderEvent{
286							Type:    EventContentDelta,
287							Content: event.Delta.Text,
288						}
289					} else if event.Delta.Type == "input_json_delta" {
290						if currentToolCallID != "" {
291							eventChan <- ProviderEvent{
292								Type: EventToolUseDelta,
293								ToolCall: &message.ToolCall{
294									ID:       currentToolCallID,
295									Finished: false,
296									Input:    event.Delta.PartialJSON,
297								},
298							}
299						}
300					}
301				case anthropic.ContentBlockStopEvent:
302					if currentToolCallID != "" {
303						eventChan <- ProviderEvent{
304							Type: EventToolUseStop,
305							ToolCall: &message.ToolCall{
306								ID: currentToolCallID,
307							},
308						}
309						currentToolCallID = ""
310					} else {
311						eventChan <- ProviderEvent{Type: EventContentStop}
312					}
313
314				case anthropic.MessageStopEvent:
315					content := ""
316					for _, block := range accumulatedMessage.Content {
317						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
318							content += text.Text
319						}
320					}
321
322					eventChan <- ProviderEvent{
323						Type: EventComplete,
324						Response: &ProviderResponse{
325							Content:      content,
326							ToolCalls:    a.toolCalls(accumulatedMessage),
327							Usage:        a.usage(accumulatedMessage),
328							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
329						},
330						Content: content,
331					}
332				}
333			}
334
335			err := anthropicStream.Err()
336			if err == nil || errors.Is(err, io.EOF) {
337				close(eventChan)
338				return
339			}
340			// If there is an error we are going to see if we can retry the call
341			retry, after, retryErr := a.shouldRetry(attempts, err)
342			if retryErr != nil {
343				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
344				close(eventChan)
345				return
346			}
347			if retry {
348				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
349				select {
350				case <-ctx.Done():
351					// context cancelled
352					if ctx.Err() != nil {
353						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
354					}
355					close(eventChan)
356					return
357				case <-time.After(time.Duration(after) * time.Millisecond):
358					continue
359				}
360			}
361			if ctx.Err() != nil {
362				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
363			}
364
365			close(eventChan)
366			return
367		}
368	}()
369	return eventChan
370}
371
372func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
373	var apierr *anthropic.Error
374	if !errors.As(err, &apierr) {
375		return false, 0, err
376	}
377
378	if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
379		return false, 0, err
380	}
381
382	if attempts > maxRetries {
383		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
384	}
385
386	retryMs := 0
387	retryAfterValues := apierr.Response.Header.Values("Retry-After")
388
389	backoffMs := 2000 * (1 << (attempts - 1))
390	jitterMs := int(float64(backoffMs) * 0.2)
391	retryMs = backoffMs + jitterMs
392	if len(retryAfterValues) > 0 {
393		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
394			retryMs = retryMs * 1000
395		}
396	}
397	return true, int64(retryMs), nil
398}
399
400func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
401	var toolCalls []message.ToolCall
402
403	for _, block := range msg.Content {
404		switch variant := block.AsAny().(type) {
405		case anthropic.ToolUseBlock:
406			toolCall := message.ToolCall{
407				ID:       variant.ID,
408				Name:     variant.Name,
409				Input:    string(variant.Input),
410				Type:     string(variant.Type),
411				Finished: true,
412			}
413			toolCalls = append(toolCalls, toolCall)
414		}
415	}
416
417	return toolCalls
418}
419
420func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
421	return TokenUsage{
422		InputTokens:         msg.Usage.InputTokens,
423		OutputTokens:        msg.Usage.OutputTokens,
424		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
425		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
426	}
427}
428
429func (a *anthropicClient) Model() config.Model {
430	return a.providerOptions.model(a.providerOptions.modelType)
431}
432
// DefaultShouldThinkFn reports whether s contains the substring "think",
// matched case-insensitively by lowercasing s before the search.
//
// TODO: check if we need
func DefaultShouldThinkFn(s string) bool {
	const keyword = "think"
	lowered := strings.ToLower(s)
	return strings.Contains(lowered, keyword)
}