1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/anthropics/anthropic-sdk-go"
 13	"github.com/anthropics/anthropic-sdk-go/bedrock"
 14	"github.com/anthropics/anthropic-sdk-go/option"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/fur/provider"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/logging"
 19	"github.com/charmbracelet/crush/internal/message"
 20)
 21
// anthropicClient implements ProviderClient against the Anthropic Messages
// API, optionally routed through AWS Bedrock (see newAnthropicClient).
type anthropicClient struct {
	providerOptions providerClientOptions // shared provider settings (API key, model resolver, cache flag, ...)
	client          anthropic.Client      // underlying Anthropic SDK client
}

// AnthropicClient is the Anthropic-backed ProviderClient type returned by
// newAnthropicClient.
type AnthropicClient ProviderClient
 28
 29func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient {
 30	anthropicClientOptions := []option.RequestOption{}
 31	if opts.apiKey != "" {
 32		anthropicClientOptions = append(anthropicClientOptions, option.WithAPIKey(opts.apiKey))
 33	}
 34	if useBedrock {
 35		anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 36	}
 37
 38	client := anthropic.NewClient(anthropicClientOptions...)
 39	return &anthropicClient{
 40		providerOptions: opts,
 41		client:          client,
 42	}
 43}
 44
// convertMessages translates internal messages into Anthropic MessageParam
// values. The last two messages are tagged with an ephemeral cache-control
// breakpoint (unless caching is disabled) so the conversation prefix can be
// prompt-cached by the API.
func (a *anthropicClient) convertMessages(messages []message.Message) (anthropicMessages []anthropic.MessageParam) {
	for i, msg := range messages {
		cache := false
		// Only the final two messages are candidates for a cache breakpoint.
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cache && !a.providerOptions.disableCache {
				content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
			}
			var contentBlocks []anthropic.ContentBlockParamUnion
			contentBlocks = append(contentBlocks, content)
			// Attach any binary attachments (e.g. images) as base64 image blocks.
			for _, binaryContent := range msg.BinaryContent() {
				base64Image := binaryContent.String(provider.InferenceProviderAnthropic)
				imageBlock := anthropic.NewImageBlockBase64(binaryContent.MIMEType, base64Image)
				contentBlocks = append(contentBlocks, imageBlock)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(contentBlocks...))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cache && !a.providerOptions.disableCache {
					content.OfText.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					// Skip tool calls whose recorded input is not valid JSON.
					continue
				}
				blocks = append(blocks, anthropic.NewToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			// The API rejects empty assistant messages; drop and log instead.
			if len(blocks) == 0 {
				logging.Warn("There is a message without content, investigate, this should not happen")
				continue
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))

		case message.Tool:
			// Tool results are sent back to the API as a user message
			// containing one tool_result block per result.
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}
	return
}
105
106func (a *anthropicClient) convertTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
107	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
108
109	for i, tool := range tools {
110		info := tool.Info()
111		toolParam := anthropic.ToolParam{
112			Name:        info.Name,
113			Description: anthropic.String(info.Description),
114			InputSchema: anthropic.ToolInputSchemaParam{
115				Properties: info.Parameters,
116				// TODO: figure out how we can tell claude the required fields?
117			},
118		}
119
120		if i == len(tools)-1 && !a.providerOptions.disableCache {
121			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
122				Type: "ephemeral",
123			}
124		}
125
126		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
127	}
128
129	return anthropicTools
130}
131
132func (a *anthropicClient) finishReason(reason string) message.FinishReason {
133	switch reason {
134	case "end_turn":
135		return message.FinishReasonEndTurn
136	case "max_tokens":
137		return message.FinishReasonMaxTokens
138	case "tool_use":
139		return message.FinishReasonToolUse
140	case "stop_sequence":
141		return message.FinishReasonEndTurn
142	default:
143		return message.FinishReasonUnknown
144	}
145}
146
// preparedMessages assembles the final MessageNewParams for an API call:
// model ID, max tokens (per-model config override wins over the model's
// default), temperature, tools, thinking config, and a system prompt block
// that is always marked cacheable.
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
	model := a.providerOptions.model(a.providerOptions.modelType)
	var thinkingParam anthropic.ThinkingConfigParamUnion
	// TODO: Implement a proper thinking function
	// lastMessage := messages[len(messages)-1]
	// isUser := lastMessage.Role == anthropic.MessageParamRoleUser
	// messageContent := ""
	temperature := anthropic.Float(0)
	// if isUser {
	// 	for _, m := range lastMessage.Content {
	// 		if m.OfText != nil && m.OfText.Text != "" {
	// 			messageContent = m.OfText.Text
	// 		}
	// 	}
	// 	if messageContent != "" && a.shouldThink != nil && a.options.shouldThink(messageContent) {
	// 		thinkingParam = anthropic.ThinkingConfigParamOfEnabled(int64(float64(a.providerOptions.maxTokens) * 0.8))
	// 		temperature = anthropic.Float(1)
	// 	}
	// }

	// Pick the config section matching this client's model type.
	cfg := config.Get()
	modelConfig := cfg.Models.Large
	if a.providerOptions.modelType == config.SmallModel {
		modelConfig = cfg.Models.Small
	}
	// A positive configured MaxTokens overrides the model's default cap.
	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	return anthropic.MessageNewParams{
		Model:       anthropic.Model(model.ID),
		MaxTokens:   maxTokens,
		Temperature: temperature,
		Messages:    messages,
		Tools:       tools,
		Thinking:    thinkingParam,
		// The system prompt is always flagged for prompt caching.
		System: []anthropic.TextBlockParam{
			{
				Text: a.providerOptions.systemMessage,
				CacheControl: anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				},
			},
		},
	}
}
194
195func (a *anthropicClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
196	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
197	cfg := config.Get()
198	if cfg.Options.Debug {
199		jsonData, _ := json.Marshal(preparedMessages)
200		logging.Debug("Prepared messages", "messages", string(jsonData))
201	}
202
203	attempts := 0
204	for {
205		attempts++
206		anthropicResponse, err := a.client.Messages.New(
207			ctx,
208			preparedMessages,
209		)
210		// If there is an error we are going to see if we can retry the call
211		if err != nil {
212			logging.Error("Error in Anthropic API call", "error", err)
213			retry, after, retryErr := a.shouldRetry(attempts, err)
214			if retryErr != nil {
215				return nil, retryErr
216			}
217			if retry {
218				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
219				select {
220				case <-ctx.Done():
221					return nil, ctx.Err()
222				case <-time.After(time.Duration(after) * time.Millisecond):
223					continue
224				}
225			}
226			return nil, retryErr
227		}
228
229		content := ""
230		for _, block := range anthropicResponse.Content {
231			if text, ok := block.AsAny().(anthropic.TextBlock); ok {
232				content += text.Text
233			}
234		}
235
236		return &ProviderResponse{
237			Content:   content,
238			ToolCalls: a.toolCalls(*anthropicResponse),
239			Usage:     a.usage(*anthropicResponse),
240		}, nil
241	}
242}
243
244func (a *anthropicClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
245	preparedMessages := a.preparedMessages(a.convertMessages(messages), a.convertTools(tools))
246	cfg := config.Get()
247	if cfg.Options.Debug {
248		// jsonData, _ := json.Marshal(preparedMessages)
249		// logging.Debug("Prepared messages", "messages", string(jsonData))
250	}
251	attempts := 0
252	eventChan := make(chan ProviderEvent)
253	go func() {
254		for {
255			attempts++
256			anthropicStream := a.client.Messages.NewStreaming(
257				ctx,
258				preparedMessages,
259			)
260			accumulatedMessage := anthropic.Message{}
261
262			currentToolCallID := ""
263			for anthropicStream.Next() {
264				event := anthropicStream.Current()
265				err := accumulatedMessage.Accumulate(event)
266				if err != nil {
267					logging.Warn("Error accumulating message", "error", err)
268					continue
269				}
270
271				switch event := event.AsAny().(type) {
272				case anthropic.ContentBlockStartEvent:
273					switch event.ContentBlock.Type {
274					case "text":
275						eventChan <- ProviderEvent{Type: EventContentStart}
276					case "tool_use":
277						currentToolCallID = event.ContentBlock.ID
278						eventChan <- ProviderEvent{
279							Type: EventToolUseStart,
280							ToolCall: &message.ToolCall{
281								ID:       event.ContentBlock.ID,
282								Name:     event.ContentBlock.Name,
283								Finished: false,
284							},
285						}
286					}
287
288				case anthropic.ContentBlockDeltaEvent:
289					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
290						eventChan <- ProviderEvent{
291							Type:     EventThinkingDelta,
292							Thinking: event.Delta.Thinking,
293						}
294					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
295						eventChan <- ProviderEvent{
296							Type:    EventContentDelta,
297							Content: event.Delta.Text,
298						}
299					} else if event.Delta.Type == "input_json_delta" {
300						if currentToolCallID != "" {
301							eventChan <- ProviderEvent{
302								Type: EventToolUseDelta,
303								ToolCall: &message.ToolCall{
304									ID:       currentToolCallID,
305									Finished: false,
306									Input:    event.Delta.PartialJSON,
307								},
308							}
309						}
310					}
311				case anthropic.ContentBlockStopEvent:
312					if currentToolCallID != "" {
313						eventChan <- ProviderEvent{
314							Type: EventToolUseStop,
315							ToolCall: &message.ToolCall{
316								ID: currentToolCallID,
317							},
318						}
319						currentToolCallID = ""
320					} else {
321						eventChan <- ProviderEvent{Type: EventContentStop}
322					}
323
324				case anthropic.MessageStopEvent:
325					content := ""
326					for _, block := range accumulatedMessage.Content {
327						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
328							content += text.Text
329						}
330					}
331
332					eventChan <- ProviderEvent{
333						Type: EventComplete,
334						Response: &ProviderResponse{
335							Content:      content,
336							ToolCalls:    a.toolCalls(accumulatedMessage),
337							Usage:        a.usage(accumulatedMessage),
338							FinishReason: a.finishReason(string(accumulatedMessage.StopReason)),
339						},
340						Content: content,
341					}
342				}
343			}
344
345			err := anthropicStream.Err()
346			if err == nil || errors.Is(err, io.EOF) {
347				close(eventChan)
348				return
349			}
350			// If there is an error we are going to see if we can retry the call
351			retry, after, retryErr := a.shouldRetry(attempts, err)
352			if retryErr != nil {
353				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
354				close(eventChan)
355				return
356			}
357			if retry {
358				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
359				select {
360				case <-ctx.Done():
361					// context cancelled
362					if ctx.Err() != nil {
363						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
364					}
365					close(eventChan)
366					return
367				case <-time.After(time.Duration(after) * time.Millisecond):
368					continue
369				}
370			}
371			if ctx.Err() != nil {
372				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
373			}
374
375			close(eventChan)
376			return
377		}
378	}()
379	return eventChan
380}
381
382func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, error) {
383	var apierr *anthropic.Error
384	if !errors.As(err, &apierr) {
385		return false, 0, err
386	}
387
388	if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
389		return false, 0, err
390	}
391
392	if attempts > maxRetries {
393		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
394	}
395
396	retryMs := 0
397	retryAfterValues := apierr.Response.Header.Values("Retry-After")
398
399	backoffMs := 2000 * (1 << (attempts - 1))
400	jitterMs := int(float64(backoffMs) * 0.2)
401	retryMs = backoffMs + jitterMs
402	if len(retryAfterValues) > 0 {
403		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
404			retryMs = retryMs * 1000
405		}
406	}
407	return true, int64(retryMs), nil
408}
409
410func (a *anthropicClient) toolCalls(msg anthropic.Message) []message.ToolCall {
411	var toolCalls []message.ToolCall
412
413	for _, block := range msg.Content {
414		switch variant := block.AsAny().(type) {
415		case anthropic.ToolUseBlock:
416			toolCall := message.ToolCall{
417				ID:       variant.ID,
418				Name:     variant.Name,
419				Input:    string(variant.Input),
420				Type:     string(variant.Type),
421				Finished: true,
422			}
423			toolCalls = append(toolCalls, toolCall)
424		}
425	}
426
427	return toolCalls
428}
429
430func (a *anthropicClient) usage(msg anthropic.Message) TokenUsage {
431	return TokenUsage{
432		InputTokens:         msg.Usage.InputTokens,
433		OutputTokens:        msg.Usage.OutputTokens,
434		CacheCreationTokens: msg.Usage.CacheCreationInputTokens,
435		CacheReadTokens:     msg.Usage.CacheReadInputTokens,
436	}
437}
438
// Model returns the resolved model configuration for this client's
// configured model type.
func (a *anthropicClient) Model() config.Model {
	return a.providerOptions.model(a.providerOptions.modelType)
}
442
443// TODO: check if we need
// DefaultShouldThinkFn reports whether the prompt mentions "think"
// (case-insensitive substring match); used to decide when extended thinking
// should be enabled.
func DefaultShouldThinkFn(s string) bool {
	lowered := strings.ToLower(s)
	return strings.Contains(lowered, "think")
}