1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/anthropics/anthropic-sdk-go"
 13	"github.com/anthropics/anthropic-sdk-go/bedrock"
 14	"github.com/anthropics/anthropic-sdk-go/option"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/fur/provider"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/logging"
 19	"github.com/charmbracelet/crush/internal/message"
 20)
 21
 22type anthropicClient struct {
 23	providerOptions providerClientOptions
 24	client          anthropic.Client
 25}
 26
 27type AnthropicClient ProviderClient
 28
 29func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 30	anthropicClientOptions := []option.RequestOption{}
 31	if opts.apiKey != "" {
 32		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 33	}
 34	if useBedrock {
 35		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 36	}
 37
 38	client := anthropic.NewClient(anthropicClientOptions...)
 39	return &anthropicClient{
 40		providerOptions: opts,
 41		client:          client,
 42	}
 43}
 44
 45func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
 46	for i, msg := range messages {
 47		cache := false
 48		if i > len(messages)-3 {
 49			cache = true
 50		}
 51		switch msg.Role {
 52		case message.User:
 53			content := anthropic.NewTextBlock(msg.Content().String())
 54			if cache && !a.providerOptions.disableCache {
 55				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 56					Type: "ephemeral",
 57				}
 58			}
 59			var contentBlocks []anthropic.ContentBlockParamUnion
 60			contentBlocks = append(contentBlocks, content)
 61			for _, binaryContent := range msg.BinaryContent() {
 62				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
 63				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
 64				contentBlocks = append(contentBlocks, imageBlock)
 65			}
 66			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))
 67
 68		case message.Assistant:
 69			blocks := []anthropic.ContentBlockParamUnion{}
 70			if msg.Content().String() != "" {
 71				content := anthropic.NewTextBlock(msg.Content().String())
 72				if cache && !a.providerOptions.disableCache {
 73					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
 74						Type: "ephemeral",
 75					}
 76				}
 77				blocks = append(blocks, content)
 78			}
 79
 80			for _, toolCall := range msg.ToolCalls() {
 81				var inputMap map[string]any
 82				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
 83				if err != nil {
 84					continue
 85				}
 86				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
 87			}
 88
 89			if len(blocks) == 0 {
 90				logging.Warn("There is a message without content, investigate, this should not happen")
 91				continue
 92			}
 93			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
 94
 95		case message.Tool:
 96			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
 97			for i, toolResult := range msg.ToolResults() {
 98				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
 99			}
100			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
101		}
102	}
103	return
104}
105
106func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
107	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
108
109	for i, tool := range tools {
110		info := tool.Info()
111		toolParam := anthropic.ToolParam{
112			Name:        info.Name,
113			Description: anthropic.String(info.Description),
114			InputSchema: anthropic.ToolInputSchemaParam{
115				Properties: info.Parameters,
116				// TODO: figure out how we can tell claude the required fields?
117			},
118		}
119
120		if i == len(tools)-1 && !a.providerOptions.disableCache {
121			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
122				Type: "ephemeral",
123			}
124		}
125
126		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
127	}
128
129	return anthropicTools
130}
131
132func (a *anthropicClient) finishReason(reason string) message.FinishReason {
133	switch reason {
134	case "end_turn":
135		return message.FinishReasonEndTurn
136	case "max_tokens":
137		return message.FinishReasonMaxTokens
138	case "tool_use":
139		return message.FinishReasonToolUse
140	case "stop_sequence":
141		return message.FinishReasonEndTurn
142	default:
143		return message.FinishReasonUnknown
144	}
145}
146
147func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
148	model := a.providerOptions.model(a.providerOptions.modelType)
149	var thinkingParam anthropic.ThinkingConfigParamUnion
150	// TODO: Implement a proper thinking function
151	// lastMessage := messages[len(messages)-1]
152	// isUser := lastMessage.Role == anthropic.MessageParamRoleUser
153	// messageContent := ""
154	temperature := anthropic.Float(0)
155	// if isUser {
156	// 	for _, m := range lastMessage.Content {
157	// 		if m.OfText != nil && m.OfText.Text != "" {
158	// 			messageContent = m.OfText.Text
159	// 		}
160	// 	}
161	// 	if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) {
162	// 		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
163	// 		temperature = anthropic.Float(1)
164	// 	}
165	// }
166
167	cfg := config.Get()
168	modelConfig := cfg.Models.Large
169	if a.providerOptions.modelType == config.SmallModel {
170		modelConfig = cfg.Models.Small
171	}
172	maxTokens := model.DefaultMaxTokens
173	if modelConfig.MaxTokens > 0 {
174		maxTokens = modelConfig.MaxTokens
175	}
176
177	// Override max tokens if set in provider options
178	if a.providerOptions.maxTokens > 0 {
179		maxTokens = a.providerOptions.maxTokens
180	}
181
182	return anthropic.MessageNewParams{
183		Model:       anthropic.Model(model.ID),
184		MaxTokens:   maxTokens,
185		Temperature: temperature,
186		Messages:    messages,
187		Tools:       tools,
188		Thinking:    thinkingParam,
189		System: []anthropic.TextBlockParam{
190			{
191				Text: a.providerOptions.systemMessage,
192				CacheControl: anthropic.CacheControlEphemeralParam{
193					Type: "ephemeral",
194				},
195			},
196		},
197	}
198}
199
200func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
201	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
202	cfg := config.Get()
203	if cfg.Options.Debug {
204		jsonData, _ := json.Marshal(preparedMessages)
205		logging.Debug("Prepared messages", "messages", string(jsonData))
206	}
207
208	attempts := 0
209	for {
210		attempts++
211		anthropicResponse, err := a.client.Messages.New(
212			ctx,
213			preparedMessages,
214		)
215		// If there is an error we are going to see if we can retry the call
216		if err != nil {
217			logging.Error("Error in Anthropic API call", "error", err)
218			retry, after, retryErr := a.shouldRetry(attempts, err)
219			if retryErr != nil {
220				return nil, retryErr
221			}
222			if retry {
223				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
224				select {
225				case <-ctx.Done():
226					return nil, ctx.Err()
227				case <-time.After(time.Duration(after) * time.Millisecond):
228					continue
229				}
230			}
231			return nil, retryErr
232		}
233
234		content := ""
235		for _, block := range anthropicResponse.Content {
236			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
237				content += text.Text
238			}
239		}
240
241		return &ProviderResponse{
242			Content:   content,
243			ToolCalls: a.toolCalls(*anthropicResponse),
244			Usage:     a.usage(*anthropicResponse),
245		}, nil
246	}
247}
248
// stream performs a streaming Messages API call and returns a channel of
// provider events. The channel is always closed when the producer goroutine
// exits: after a clean finish, after a terminal error, or once ctx is
// cancelled.
//
// Errors that shouldRetry marks as retryable restart the request after the
// delay it returns. Events already emitted by a failed attempt are not
// retracted. Any other terminal error is emitted as one EventError before the
// channel closes. After ctx is cancelled, event delivery is best-effort, so a
// consumer that stops reading cannot strand the goroutine.
//
// The prepared request is not logged here, even in debug mode.
func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))

	eventChan := make(chan ProviderEvent)
	go func() {
		defer close(eventChan)

		// emit hands an event to the consumer but gives up once ctx is
		// cancelled, so the goroutine never blocks forever on an unbuffered
		// send. It reports whether the event was delivered.
		emit := func(event ProviderEvent) bool {
			select {
			case eventChan <- event:
				return true
			case <-ctx.Done():
				return false
			}
		}

		for attempts := 1; ; attempts++ {
			err := a.streamAttempt(ctx, preparedMessages, emit)
			if err == nil || errors.Is(err, io.EOF) {
				return
			}

			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := a.shouldRetry(attempts, err)
			if retryErr != nil {
				emit(ProviderEvent{Type: EventError, Error: retryErr})
				return
			}
			if !retry {
				if ctx.Err() != nil {
					emit(ProviderEvent{Type: EventError, Error: ctx.Err()})
				}
				return
			}

			logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
			select {
			case <-ctx.Done():
				emit(ProviderEvent{Type: EventError, Error: ctx.Err()})
				return
			case <-time.After(time.Duration(after) * time.Millisecond):
			}
		}
	}()
	return eventChan
}

// streamAttempt runs one streaming request. It translates SDK stream events
// into ProviderEvents and delivers them through emit, and it always closes the
// stream before returning. It returns the stream's terminal error (nil on a
// clean finish), or ctx.Err() if emit could not deliver because ctx was
// cancelled.
func (a *anthropicClient) streamAttempt(ctx context.Context, params anthropic.MessageNewParams, emit func(ProviderEvent) bool) error {
	anthropicStream := a.client.Messages.NewStreaming(ctx, params)
	defer anthropicStream.Close()

	accumulatedMessage := anthropic.Message{}
	currentToolCallID := ""
	for anthropicStream.Next() {
		event := anthropicStream.Current()
		if err := accumulatedMessage.Accumulate(event); err != nil {
			logging.Warn("Error accumulating message", "error", err)
			continue
		}

		var out *ProviderEvent
		switch event := event.AsAny().(type) {
		case anthropic.ContentBlockStartEvent:
			switch event.ContentBlock.Type {
			case "text":
				out = &ProviderEvent{Type: EventContentStart}
			case "tool_use":
				currentToolCallID = event.ContentBlock.ID
				out = &ProviderEvent{
					Type: EventToolUseStart,
					ToolCall: &message.ToolCall{
						ID:       event.ContentBlock.ID,
						Name:     event.ContentBlock.Name,
						Finished: false,
					},
				}
			}

		case anthropic.ContentBlockDeltaEvent:
			switch {
			case event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "":
				out = &ProviderEvent{
					Type:     EventThinkingDelta,
					Thinking: event.Delta.Thinking,
				}
			case event.Delta.Type == "text_delta" && event.Delta.Text != "":
				out = &ProviderEvent{
					Type:    EventContentDelta,
					Content: event.Delta.Text,
				}
			case event.Delta.Type == "input_json_delta" && currentToolCallID != "":
				out = &ProviderEvent{
					Type: EventToolUseDelta,
					ToolCall: &message.ToolCall{
						ID:       currentToolCallID,
						Finished: false,
						Input:    event.Delta.PartialJSON,
					},
				}
			}

		case anthropic.ContentBlockStopEvent:
			if currentToolCallID != "" {
				out = &ProviderEvent{
					Type: EventToolUseStop,
					ToolCall: &message.ToolCall{
						ID: currentToolCallID,
					},
				}
				currentToolCallID = ""
			} else {
				out = &ProviderEvent{Type: EventContentStop}
			}

		case anthropic.MessageStopEvent:
			var content strings.Builder
			for _, block := range accumulatedMessage.Content {
				if text, ok := block.AsAny().(anthropic.TextBlock); ok {
					content.WriteString(text.Text)
				}
			}
			out = &ProviderEvent{
				Type: EventComplete,
				Response: &ProviderResponse{
					Content:      content.String(),
					ToolCalls:    a.toolCalls(accumulatedMessage),
					Usage:        a.usage(accumulatedMessage),
					FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
				},
				Content: content.String(),
			}
		}

		if out != nil && !emit(*out) {
			return ctx.Err()
		}
	}

	return anthropicStream.Err()
}
386
// shouldRetry decides whether a failed Anthropic call should be retried.
//
// Only *anthropic.Error values with HTTP status 429 or 529 are retryable.
// The results are:
//   - (false, 0, err) for any other error, returned unchanged;
//   - (false, 0, wrapped err) once attempts exceeds maxRetries, with the
//     original error wrapped so callers can still inspect it;
//   - (true, delayMs, nil) otherwise.
//
// delayMs comes from the first Retry-After header value, read as whole seconds
// and converted to milliseconds, when it parses as a non-negative integer.
// Otherwise it is an exponential backoff of 2000ms * 2^(attempts-1) plus a
// fixed 20%.
func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apierr *anthropic.Error
	if !errors.As(err, &apierr) {
		return false, 0, err
	}

	if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries: %w", maxRetries, err)
	}

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	// Prefer the server's hint. The HTTP-date form of Retry-After, or any
	// unparseable or negative value, keeps the computed backoff.
	if apierr.Response != nil {
		if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
			var seconds int
			if _, scanErr := fmt.Sscanf(retryAfterValues[0], "%d", &seconds); scanErr == nil && seconds >= 0 {
				retryMs = seconds * 1000
			}
		}
	}
	return true, int64(retryMs), nil
}
414
415func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
416	var toolCalls []message.ToolCall
417
418	for _, block := range msg.Content {
419		switch variant := block.AsAny().(type) {
420		case anthropic.ToolUseBlock:
421			toolCall := message.ToolCall{
422				ID:       variant.ID,
423				Name:     variant.Name,
424				Input:    string(variant.Input),
425				Type:     string(variant.Type),
426				Finished: true,
427			}
428			toolCalls = append(toolCalls, toolCall)
429		}
430	}
431
432	return toolCalls
433}
434
435func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
436	return TokenUsage{
437		InputTokens:         msg.Usage.InputTokens,
438		OutputTokens:        msg.Usage.OutputTokens,
439		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
440		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
441	}
442}
443
444func (a *anthropicClient) Model() config.Model {
445	return a.providerOptions.model(a.providerOptions.modelType)
446}
447
// DefaultShouldThinkFn reports whether s contains "think" in any letter case.
// It is a plain substring match after lower-casing, so words such as
// "rethinking" also match.
//
// TODO: check if we need
func DefaultShouldThinkFn(s string) bool {
	lowered := strings.ToLower(s)
	return strings.Contains(lowered, "think")
}