anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/anthropics/anthropic-sdk-go"
 13	"github.com/anthropics/anthropic-sdk-go/bedrock"
 14	"github.com/anthropics/anthropic-sdk-go/option"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/llm/models"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/logging"
 19	"github.com/charmbracelet/crush/internal/message"
 20)
 21
 22type anthropicClient struct {
 23	providerOptions providerClientOptions
 24	client          anthropic.Client
 25}
 26
 27type AnthropicClient ProviderClient
 28
 29func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 30	anthropicClientOptions := []option.RequestOption{}
 31	if opts.apiKey != "" {
 32		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 33	}
 34	if useBedrock {
 35		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 36	}
 37
 38	client := anthropic.NewClient(anthropicClientOptions...)
 39	return &anthropicClient{
 40		providerOptions: opts,
 41		client:          client,
 42	}
 43}
 44
 45func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
 46	for i, msg := range messages {
 47		cache := false
 48		if i > len(messages)-3 {
 49			cache = true
 50		}
 51		switch msg.Role {
 52		case message.User:
 53			content := anthropic.NewTextBlock(msg.Content().String())
 54			if cache && !a.providerOptions.disableCache {
 55				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 56					Type: "ephemeral",
 57				}
 58			}
 59			var contentBlocks []anthropic.ContentBlockParamUnion
 60			contentBlocks = append(contentBlocks, content)
 61			for _, binaryContent := range msg.BinaryContent() {
 62				base64Image := binaryContent.String(models.ProviderAnthropic)
 63				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
 64				contentBlocks = append(contentBlocks, imageBlock)
 65			}
 66			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
 67
 68		case message.Assistant:
 69			blocks := []anthropic.ContentBlockParamUnion{}
 70			if msg.Content().String() != "" {
 71				content := anthropic.NewTextBlock(msg.Content().String())
 72				if cache && !a.providerOptions.disableCache {
 73					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 74						Type: "ephemeral",
 75					}
 76				}
 77				blocks = append(blocks, content)
 78			}
 79
 80			for _, toolCall := range msg.ToolCalls() {
 81				var inputMap map[string]any
 82				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
 83				if err != nil {
 84					continue
 85				}
 86				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
 87			}
 88
 89			if len(blocks) == 0 {
 90				logging.Warn("There is a message without content, investigate, this should not happen")
 91				continue
 92			}
 93			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
 94
 95		case message.Tool:
 96			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
 97			for i, toolResult := range msg.ToolResults() {
 98				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
 99			}
100			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
101		}
102	}
103	return
104}
105
106func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
107	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
108
109	for i, tool := range tools {
110		info := tool.Info()
111		toolParam := anthropic.ToolParam{
112			Name:        info.Name,
113			Description: anthropic.String(info.Description),
114			InputSchema: anthropic.ToolInputSchemaParam{
115				Properties: info.Parameters,
116				// TODO: figure out how we can tell claude the required fields?
117			},
118		}
119
120		if i == len(tools)-1 && !a.providerOptions.disableCache {
121			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
122				Type: "ephemeral",
123			}
124		}
125
126		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
127	}
128
129	return anthropicTools
130}
131
132func (a *anthropicClient) finishReason(reason string) message.FinishReason {
133	switch reason {
134	case "end_turn":
135		return message.FinishReasonEndTurn
136	case "max_tokens":
137		return message.FinishReasonMaxTokens
138	case "tool_use":
139		return message.FinishReasonToolUse
140	case "stop_sequence":
141		return message.FinishReasonEndTurn
142	default:
143		return message.FinishReasonUnknown
144	}
145}
146
147func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
148	var thinkingParam anthropic.ThinkingConfigParamUnion
149	// TODO: Implement a proper thinking function
150	// lastMessage := messages[len(messages)-1]
151	// isUser := lastMessage.Role == anthropic.MessageParamRoleUser
152	// messageContent := ""
153	temperature := anthropic.Float(0)
154	// if isUser {
155	// 	for _, m := range lastMessage.Content {
156	// 		if m.OfText != nil && m.OfText.Text != "" {
157	// 			messageContent = m.OfText.Text
158	// 		}
159	// 	}
160	// 	if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) {
161	// 		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
162	// 		temperature = anthropic.Float(1)
163	// 	}
164	// }
165
166	return anthropic.MessageNewParams{
167		Model:       anthropic.Model(a.providerOptions.model.APIModel),
168		MaxTokens:   a.providerOptions.maxTokens,
169		Temperature: temperature,
170		Messages:    messages,
171		Tools:       tools,
172		Thinking:    thinkingParam,
173		System: []anthropic.TextBlockParam{
174			{
175				Text: a.providerOptions.systemMessage,
176				CacheControl: anthropic.CacheControlEphemeralParam{
177					Type: "ephemeral",
178				},
179			},
180		},
181	}
182}
183
184func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
185	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
186	cfg := config.Get()
187	if cfg.Debug {
188		jsonData, _ := json.Marshal(preparedMessages)
189		logging.Debug("Prepared messages", "messages", string(jsonData))
190	}
191
192	attempts := 0
193	for {
194		attempts++
195		anthropicResponse, err := a.client.Messages.New(
196			ctx,
197			preparedMessages,
198		)
199		// If there is an error we are going to see if we can retry the call
200		if err != nil {
201			logging.Error("Error in Anthropic API call", "error", err)
202			retry, after, retryErr := a.shouldRetry(attempts, err)
203			if retryErr != nil {
204				return nil, retryErr
205			}
206			if retry {
207				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
208				select {
209				case <-ctx.Done():
210					return nil, ctx.Err()
211				case <-time.After(time.Duration(after) * time.Millisecond):
212					continue
213				}
214			}
215			return nil, retryErr
216		}
217
218		content := ""
219		for _, block := range anthropicResponse.Content {
220			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
221				content += text.Text
222			}
223		}
224
225		return &ProviderResponse{
226			Content:   content,
227			ToolCalls: a.toolCalls(*anthropicResponse),
228			Usage:     a.usage(*anthropicResponse),
229		}, nil
230	}
231}
232
233func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
234	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
235	cfg := config.Get()
236	if cfg.Debug {
237		// jsonData, _ := json.Marshal(preparedMessages)
238		// logging.Debug("Prepared messages", "messages", string(jsonData))
239	}
240	attempts := 0
241	eventChan := make(chan ProviderEvent)
242	go func() {
243		for {
244			attempts++
245			anthropicStream := a.client.Messages.NewStreaming(
246				ctx,
247				preparedMessages,
248			)
249			accumulatedMessage := anthropic.Message{}
250
251			currentToolCallID := ""
252			for anthropicStream.Next() {
253				event := anthropicStream.Current()
254				err := accumulatedMessage.Accumulate(event)
255				if err != nil {
256					logging.Warn("Error accumulating message", "error", err)
257					continue
258				}
259
260				switch event := event.AsAny().(type) {
261				case anthropic.ContentBlockStartEvent:
262					switch event.ContentBlock.Type {
263					case "text":
264						eventChan <- ProviderEvent{Type: EventContentStart}
265					case "tool_use":
266						currentToolCallID = event.ContentBlock.ID
267						eventChan <- ProviderEvent{
268							Type: EventToolUseStart,
269							ToolCall: &message.ToolCall{
270								ID:       event.ContentBlock.ID,
271								Name:     event.ContentBlock.Name,
272								Finished: false,
273							},
274						}
275					}
276
277				case anthropic.ContentBlockDeltaEvent:
278					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
279						eventChan <- ProviderEvent{
280							Type:     EventThinkingDelta,
281							Thinking: event.Delta.Thinking,
282						}
283					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
284						eventChan <- ProviderEvent{
285							Type:    EventContentDelta,
286							Content: event.Delta.Text,
287						}
288					} else if event.Delta.Type == "input_json_delta" {
289						if currentToolCallID != "" {
290							eventChan <- ProviderEvent{
291								Type: EventToolUseDelta,
292								ToolCall: &message.ToolCall{
293									ID:       currentToolCallID,
294									Finished: false,
295									Input:    event.Delta.PartialJSON,
296								},
297							}
298						}
299					}
300				case anthropic.ContentBlockStopEvent:
301					if currentToolCallID != "" {
302						eventChan <- ProviderEvent{
303							Type: EventToolUseStop,
304							ToolCall: &message.ToolCall{
305								ID: currentToolCallID,
306							},
307						}
308						currentToolCallID = ""
309					} else {
310						eventChan <- ProviderEvent{Type: EventContentStop}
311					}
312
313				case anthropic.MessageStopEvent:
314					content := ""
315					for _, block := range accumulatedMessage.Content {
316						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
317							content += text.Text
318						}
319					}
320
321					eventChan <- ProviderEvent{
322						Type: EventComplete,
323						Response: &ProviderResponse{
324							Content:      content,
325							ToolCalls:    a.toolCalls(accumulatedMessage),
326							Usage:        a.usage(accumulatedMessage),
327							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
328						},
329						Content: content,
330					}
331				}
332			}
333
334			err := anthropicStream.Err()
335			if err == nil || errors.Is(err, io.EOF) {
336				close(eventChan)
337				return
338			}
339			// If there is an error we are going to see if we can retry the call
340			retry, after, retryErr := a.shouldRetry(attempts, err)
341			if retryErr != nil {
342				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
343				close(eventChan)
344				return
345			}
346			if retry {
347				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
348				select {
349				case <-ctx.Done():
350					// context cancelled
351					if ctx.Err() != nil {
352						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
353					}
354					close(eventChan)
355					return
356				case <-time.After(time.Duration(after) * time.Millisecond):
357					continue
358				}
359			}
360			if ctx.Err() != nil {
361				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
362			}
363
364			close(eventChan)
365			return
366		}
367	}()
368	return eventChan
369}
370
371func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
372	var apierr *anthropic.Error
373	if !errors.As(err, &apierr) {
374		return false, 0, err
375	}
376
377	if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
378		return false, 0, err
379	}
380
381	if attempts > maxRetries {
382		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
383	}
384
385	retryMs := 0
386	retryAfterValues := apierr.Response.Header.Values("Retry-After")
387
388	backoffMs := 2000 * (1 << (attempts - 1))
389	jitterMs := int(float64(backoffMs) * 0.2)
390	retryMs = backoffMs + jitterMs
391	if len(retryAfterValues) > 0 {
392		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
393			retryMs = retryMs * 1000
394		}
395	}
396	return true, int64(retryMs), nil
397}
398
399func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
400	var toolCalls []message.ToolCall
401
402	for _, block := range msg.Content {
403		switch variant := block.AsAny().(type) {
404		case anthropic.ToolUseBlock:
405			toolCall := message.ToolCall{
406				ID:       variant.ID,
407				Name:     variant.Name,
408				Input:    string(variant.Input),
409				Type:     string(variant.Type),
410				Finished: true,
411			}
412			toolCalls = append(toolCalls, toolCall)
413		}
414	}
415
416	return toolCalls
417}
418
419func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
420	return TokenUsage{
421		InputTokens:         msg.Usage.InputTokens,
422		OutputTokens:        msg.Usage.OutputTokens,
423		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
424		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
425	}
426}
427
// DefaultShouldThinkFn reports whether s contains the substring "think",
// compared after lower-casing s.
//
// TODO: check if we need
func DefaultShouldThinkFn(s string) bool {
	const keyword = "think"
	lowered := strings.ToLower(s)
	return strings.Contains(lowered, keyword)
}