anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"strings"
  9	"time"
 10
 11	"github.com/anthropics/anthropic-sdk-go"
 12	"github.com/anthropics/anthropic-sdk-go/bedrock"
 13	"github.com/anthropics/anthropic-sdk-go/option"
 14	"github.com/kujtimiihoxha/termai/internal/llm/models"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17)
 18
// anthropicProvider implements the Provider interface on top of the
// Anthropic Messages API, optionally routed through AWS Bedrock.
type anthropicProvider struct {
	client        anthropic.Client // configured SDK client (direct API or Bedrock)
	model         models.Model     // model descriptor; APIModel is the wire identifier
	maxTokens     int64            // response token budget (defaults to 1024 in NewAnthropicProvider)
	apiKey        string           // Anthropic API key; ignored when useBedrock is set
	systemMessage string           // system prompt sent with every request; required
	useBedrock    bool             // route requests through AWS Bedrock instead of the Anthropic API
	disableCache  bool             // when true, no cache_control markers are attached to requests
}
 28
// AnthropicOption mutates an anthropicProvider during construction.
// Pass options to NewAnthropicProvider.
type AnthropicOption func(*anthropicProvider)

// WithAnthropicSystemMessage sets the system prompt. It is required;
// NewAnthropicProvider fails without it.
func WithAnthropicSystemMessage(message string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.systemMessage = message
	}
}

// WithAnthropicMaxTokens overrides the default response token budget (1024).
func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
	return func(a *anthropicProvider) {
		a.maxTokens = maxTokens
	}
}

// WithAnthropicModel selects the model used for all requests.
func WithAnthropicModel(model models.Model) AnthropicOption {
	return func(a *anthropicProvider) {
		a.model = model
	}
}

// WithAnthropicKey sets the Anthropic API key used to authenticate.
func WithAnthropicKey(apiKey string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.apiKey = apiKey
	}
}

// WithAnthropicBedrock routes requests through AWS Bedrock using the
// default AWS configuration chain.
func WithAnthropicBedrock() AnthropicOption {
	return func(a *anthropicProvider) {
		a.useBedrock = true
	}
}

// WithAnthropicDisableCache disables prompt caching: no cache_control
// markers will be attached to system prompts, tools, or messages.
func WithAnthropicDisableCache() AnthropicOption {
	return func(a *anthropicProvider) {
		a.disableCache = true
	}
}
 66
 67func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 68	provider := &anthropicProvider{
 69		maxTokens: 1024,
 70	}
 71
 72	for _, opt := range opts {
 73		opt(provider)
 74	}
 75
 76	if provider.systemMessage == "" {
 77		return nil, errors.New("system message is required")
 78	}
 79
 80	anthropicOptions := []option.RequestOption{}
 81
 82	if provider.apiKey != "" {
 83		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
 84	}
 85	if provider.useBedrock {
 86		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 87	}
 88
 89	provider.client = anthropic.NewClient(anthropicOptions...)
 90	return provider, nil
 91}
 92
 93func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 94	anthropicMessages := a.convertToAnthropicMessages(messages)
 95	anthropicTools := a.convertToAnthropicTools(tools)
 96
 97	response, err := a.client.Messages.New(
 98		ctx,
 99		anthropic.MessageNewParams{
100			Model:       anthropic.Model(a.model.APIModel),
101			MaxTokens:   a.maxTokens,
102			Temperature: anthropic.Float(0),
103			Messages:    anthropicMessages,
104			Tools:       anthropicTools,
105			System: []anthropic.TextBlockParam{
106				{
107					Text: a.systemMessage,
108					CacheControl: anthropic.CacheControlEphemeralParam{
109						Type: "ephemeral",
110					},
111				},
112			},
113		},
114	)
115	if err != nil {
116		return nil, err
117	}
118
119	content := ""
120	for _, block := range response.Content {
121		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
122			content += text.Text
123		}
124	}
125
126	toolCalls := a.extractToolCalls(response.Content)
127	tokenUsage := a.extractTokenUsage(response.Usage)
128
129	return &ProviderResponse{
130		Content:   content,
131		ToolCalls: toolCalls,
132		Usage:     tokenUsage,
133	}, nil
134}
135
136func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
137	anthropicMessages := a.convertToAnthropicMessages(messages)
138	anthropicTools := a.convertToAnthropicTools(tools)
139
140	var thinkingParam anthropic.ThinkingConfigParamUnion
141	lastMessage := messages[len(messages)-1]
142	temperature := anthropic.Float(0)
143	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
144		thinkingParam = anthropic.ThinkingConfigParamUnion{
145			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
146				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
147				Type:         "enabled",
148			},
149		}
150		temperature = anthropic.Float(1)
151	}
152
153	eventChan := make(chan ProviderEvent)
154
155	go func() {
156		defer close(eventChan)
157
158		const maxRetries = 8
159		attempts := 0
160
161		for {
162			// If this isn't the first attempt, we're retrying
163			if attempts > 0 {
164				if attempts > maxRetries {
165					eventChan <- ProviderEvent{
166						Type:  EventError,
167						Error: errors.New("maximum retry attempts reached for rate limit (429)"),
168					}
169					return
170				}
171
172				// Inform user we're retrying with attempt number
173				eventChan <- ProviderEvent{
174					Type: EventWarning,
175					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
176				}
177
178				// Calculate backoff with exponential backoff and jitter
179				backoffMs := 2000 * (1 << (attempts - 1)) // 2s, 4s, 8s, 16s, 32s
180				jitterMs := int(float64(backoffMs) * 0.2)
181				totalBackoffMs := backoffMs + jitterMs
182
183				// Sleep with backoff, respecting context cancellation
184				select {
185				case <-ctx.Done():
186					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
187					return
188				case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
189					// Continue with retry
190				}
191			}
192
193			attempts++
194
195			// Create new streaming request
196			stream := a.client.Messages.NewStreaming(
197				ctx,
198				anthropic.MessageNewParams{
199					Model:       anthropic.Model(a.model.APIModel),
200					MaxTokens:   a.maxTokens,
201					Temperature: temperature,
202					Messages:    anthropicMessages,
203					Tools:       anthropicTools,
204					Thinking:    thinkingParam,
205					System: []anthropic.TextBlockParam{
206						{
207							Text: a.systemMessage,
208							CacheControl: anthropic.CacheControlEphemeralParam{
209								Type: "ephemeral",
210							},
211						},
212					},
213				},
214			)
215
216			// Process stream events
217			accumulatedMessage := anthropic.Message{}
218			streamSuccess := false
219
220			// Process the stream until completion or error
221			for stream.Next() {
222				event := stream.Current()
223				err := accumulatedMessage.Accumulate(event)
224				if err != nil {
225					eventChan <- ProviderEvent{Type: EventError, Error: err}
226					return // Don't retry on accumulation errors
227				}
228
229				switch event := event.AsAny().(type) {
230				case anthropic.ContentBlockStartEvent:
231					eventChan <- ProviderEvent{Type: EventContentStart}
232
233				case anthropic.ContentBlockDeltaEvent:
234					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
235						eventChan <- ProviderEvent{
236							Type:     EventThinkingDelta,
237							Thinking: event.Delta.Thinking,
238						}
239					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
240						eventChan <- ProviderEvent{
241							Type:    EventContentDelta,
242							Content: event.Delta.Text,
243						}
244					}
245
246				case anthropic.ContentBlockStopEvent:
247					eventChan <- ProviderEvent{Type: EventContentStop}
248
249				case anthropic.MessageStopEvent:
250					streamSuccess = true
251					content := ""
252					for _, block := range accumulatedMessage.Content {
253						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
254							content += text.Text
255						}
256					}
257
258					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
259					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
260
261					eventChan <- ProviderEvent{
262						Type: EventComplete,
263						Response: &ProviderResponse{
264							Content:      content,
265							ToolCalls:    toolCalls,
266							Usage:        tokenUsage,
267							FinishReason: string(accumulatedMessage.StopReason),
268						},
269					}
270				}
271			}
272
273			// If the stream completed successfully, we're done
274			if streamSuccess {
275				return
276			}
277
278			// Check for stream errors
279			err := stream.Err()
280			if err != nil {
281				var apierr *anthropic.Error
282				if errors.As(err, &apierr) {
283					if apierr.StatusCode == 429 || apierr.StatusCode == 529 {
284						// Check for Retry-After header
285						if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
286							// Parse the retry after value (seconds)
287							var retryAfterSec int
288							if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
289								retryMs := retryAfterSec * 1000
290
291								// Inform user of retry with specific wait time
292								eventChan <- ProviderEvent{
293									Type: EventWarning,
294									Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
295								}
296
297								// Sleep respecting context cancellation
298								select {
299								case <-ctx.Done():
300									eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
301									return
302								case <-time.After(time.Duration(retryMs) * time.Millisecond):
303									// Continue with retry after specified delay
304									continue
305								}
306							}
307						}
308
309						// Fall back to exponential backoff if Retry-After parsing failed
310						continue
311					}
312				}
313
314				// For non-rate limit errors, report and exit
315				eventChan <- ProviderEvent{Type: EventError, Error: err}
316				return
317			}
318		}
319	}()
320
321	return eventChan, nil
322}
323
324func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
325	var toolCalls []message.ToolCall
326
327	for _, block := range content {
328		switch variant := block.AsAny().(type) {
329		case anthropic.ToolUseBlock:
330			toolCall := message.ToolCall{
331				ID:    variant.ID,
332				Name:  variant.Name,
333				Input: string(variant.Input),
334				Type:  string(variant.Type),
335			}
336			toolCalls = append(toolCalls, toolCall)
337		}
338	}
339
340	return toolCalls
341}
342
343func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
344	return TokenUsage{
345		InputTokens:         usage.InputTokens,
346		OutputTokens:        usage.OutputTokens,
347		CacheCreationTokens: usage.CacheCreationInputTokens,
348		CacheReadTokens:     usage.CacheReadInputTokens,
349	}
350}
351
352func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
353	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
354
355	for i, tool := range tools {
356		info := tool.Info()
357		toolParam := anthropic.ToolParam{
358			Name:        info.Name,
359			Description: anthropic.String(info.Description),
360			InputSchema: anthropic.ToolInputSchemaParam{
361				Properties: info.Parameters,
362			},
363		}
364
365		if i == len(tools)-1 && !a.disableCache {
366			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
367				Type: "ephemeral",
368			}
369		}
370
371		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
372	}
373
374	return anthropicTools
375}
376
// convertToAnthropicMessages translates internal messages into the SDK's
// message parameters. Up to two text blocks (across user and assistant
// messages, in encounter order) get ephemeral cache_control markers,
// unless caching is disabled. Tool results are sent as user-role messages,
// per the Anthropic message format.
func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
	cachedBlocks := 0 // number of cache_control markers attached so far (max 2)

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cachedBlocks < 2 && !a.disableCache {
				// Mutate the union's text-block variant in place to add the marker.
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
				cachedBlocks++
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cachedBlocks < 2 && !a.disableCache {
					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
					cachedBlocks++
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					// Unparseable tool input: drop this tool_use block rather
					// than send malformed JSON to the API.
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			// Skip empty assistant messages completely
			if len(blocks) > 0 {
				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
			}

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}

	return anthropicMessages
}