anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"strings"
  9	"time"
 10
 11	"github.com/anthropics/anthropic-sdk-go"
 12	"github.com/anthropics/anthropic-sdk-go/option"
 13	"github.com/kujtimiihoxha/termai/internal/llm/models"
 14	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 15	"github.com/kujtimiihoxha/termai/internal/message"
 16)
 17
// anthropicProvider is the Provider implementation backed by the Anthropic
// Messages API. It is configured through AnthropicOption values passed to
// NewAnthropicProvider.
type anthropicProvider struct {
	client        anthropic.Client // SDK client; built in NewAnthropicProvider from apiKey
	model         models.Model     // model whose APIModel is sent as the request's model ID
	maxTokens     int64            // MaxTokens sent with each request (NewAnthropicProvider defaults it to 1024)
	apiKey        string           // key handed to option.WithAPIKey when the client is created
	systemMessage string           // system prompt sent with every request; must be non-empty
}
 25
// AnthropicOption customizes an anthropicProvider while it is being built by
// NewAnthropicProvider; options are applied in the order they are passed.
type AnthropicOption func(*anthropicProvider)
 27
 28func WithAnthropicSystemMessage(message string) AnthropicOption {
 29	return func(a *anthropicProvider) {
 30		a.systemMessage = message
 31	}
 32}
 33
 34func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 35	return func(a *anthropicProvider) {
 36		a.maxTokens = maxTokens
 37	}
 38}
 39
 40func WithAnthropicModel(model models.Model) AnthropicOption {
 41	return func(a *anthropicProvider) {
 42		a.model = model
 43	}
 44}
 45
 46func WithAnthropicKey(apiKey string) AnthropicOption {
 47	return func(a *anthropicProvider) {
 48		a.apiKey = apiKey
 49	}
 50}
 51
 52func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 53	provider := &anthropicProvider{
 54		maxTokens: 1024,
 55	}
 56
 57	for _, opt := range opts {
 58		opt(provider)
 59	}
 60
 61	if provider.systemMessage == "" {
 62		return nil, errors.New("system message is required")
 63	}
 64
 65	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
 66	return provider, nil
 67}
 68
 69func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 70	anthropicMessages := a.convertToAnthropicMessages(messages)
 71	anthropicTools := a.convertToAnthropicTools(tools)
 72
 73	response, err := a.client.Messages.New(
 74		ctx,
 75		anthropic.MessageNewParams{
 76			Model:       anthropic.Model(a.model.APIModel),
 77			MaxTokens:   a.maxTokens,
 78			Temperature: anthropic.Float(0),
 79			Messages:    anthropicMessages,
 80			Tools:       anthropicTools,
 81			System: []anthropic.TextBlockParam{
 82				{
 83					Text: a.systemMessage,
 84					CacheControl: anthropic.CacheControlEphemeralParam{
 85						Type: "ephemeral",
 86					},
 87				},
 88			},
 89		},
 90	)
 91	if err != nil {
 92		return nil, err
 93	}
 94
 95	content := ""
 96	for _, block := range response.Content {
 97		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
 98			content += text.Text
 99		}
100	}
101
102	toolCalls := a.extractToolCalls(response.Content)
103	tokenUsage := a.extractTokenUsage(response.Usage)
104
105	return &ProviderResponse{
106		Content:   content,
107		ToolCalls: toolCalls,
108		Usage:     tokenUsage,
109	}, nil
110}
111
112func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
113	anthropicMessages := a.convertToAnthropicMessages(messages)
114	anthropicTools := a.convertToAnthropicTools(tools)
115
116	var thinkingParam anthropic.ThinkingConfigParamUnion
117	lastMessage := messages[len(messages)-1]
118	temperature := anthropic.Float(0)
119	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
120		thinkingParam = anthropic.ThinkingConfigParamUnion{
121			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
122				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
123				Type:         "enabled",
124			},
125		}
126		temperature = anthropic.Float(1)
127	}
128
129	eventChan := make(chan ProviderEvent)
130
131	go func() {
132		defer close(eventChan)
133
134		const maxRetries = 8
135		attempts := 0
136
137		for {
138			// If this isn't the first attempt, we're retrying
139			if attempts > 0 {
140				if attempts > maxRetries {
141					eventChan <- ProviderEvent{
142						Type:  EventError,
143						Error: errors.New("maximum retry attempts reached for rate limit (429)"),
144					}
145					return
146				}
147
148				// Inform user we're retrying with attempt number
149				eventChan <- ProviderEvent{
150					Type: EventWarning,
151					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
152				}
153
154				// Calculate backoff with exponential backoff and jitter
155				backoffMs := 2000 * (1 << (attempts - 1)) // 2s, 4s, 8s, 16s, 32s
156				jitterMs := int(float64(backoffMs) * 0.2)
157				totalBackoffMs := backoffMs + jitterMs
158
159				// Sleep with backoff, respecting context cancellation
160				select {
161				case <-ctx.Done():
162					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
163					return
164				case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
165					// Continue with retry
166				}
167			}
168
169			attempts++
170
171			// Create new streaming request
172			stream := a.client.Messages.NewStreaming(
173				ctx,
174				anthropic.MessageNewParams{
175					Model:       anthropic.Model(a.model.APIModel),
176					MaxTokens:   a.maxTokens,
177					Temperature: temperature,
178					Messages:    anthropicMessages,
179					Tools:       anthropicTools,
180					Thinking:    thinkingParam,
181					System: []anthropic.TextBlockParam{
182						{
183							Text: a.systemMessage,
184							CacheControl: anthropic.CacheControlEphemeralParam{
185								Type: "ephemeral",
186							},
187						},
188					},
189				},
190			)
191
192			// Process stream events
193			accumulatedMessage := anthropic.Message{}
194			streamSuccess := false
195
196			// Process the stream until completion or error
197			for stream.Next() {
198				event := stream.Current()
199				err := accumulatedMessage.Accumulate(event)
200				if err != nil {
201					eventChan <- ProviderEvent{Type: EventError, Error: err}
202					return // Don't retry on accumulation errors
203				}
204
205				switch event := event.AsAny().(type) {
206				case anthropic.ContentBlockStartEvent:
207					eventChan <- ProviderEvent{Type: EventContentStart}
208
209				case anthropic.ContentBlockDeltaEvent:
210					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
211						eventChan <- ProviderEvent{
212							Type:     EventThinkingDelta,
213							Thinking: event.Delta.Thinking,
214						}
215					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
216						eventChan <- ProviderEvent{
217							Type:    EventContentDelta,
218							Content: event.Delta.Text,
219						}
220					}
221
222				case anthropic.ContentBlockStopEvent:
223					eventChan <- ProviderEvent{Type: EventContentStop}
224
225				case anthropic.MessageStopEvent:
226					streamSuccess = true
227					content := ""
228					for _, block := range accumulatedMessage.Content {
229						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
230							content += text.Text
231						}
232					}
233
234					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
235					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
236
237					eventChan <- ProviderEvent{
238						Type: EventComplete,
239						Response: &ProviderResponse{
240							Content:      content,
241							ToolCalls:    toolCalls,
242							Usage:        tokenUsage,
243							FinishReason: string(accumulatedMessage.StopReason),
244						},
245					}
246				}
247			}
248
249			// If the stream completed successfully, we're done
250			if streamSuccess {
251				return
252			}
253
254			// Check for stream errors
255			err := stream.Err()
256			if err != nil {
257				var apierr *anthropic.Error
258				if errors.As(err, &apierr) {
259					if apierr.StatusCode == 429 || apierr.StatusCode == 529 {
260						// Check for Retry-After header
261						if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
262							// Parse the retry after value (seconds)
263							var retryAfterSec int
264							if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
265								retryMs := retryAfterSec * 1000
266
267								// Inform user of retry with specific wait time
268								eventChan <- ProviderEvent{
269									Type: EventWarning,
270									Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
271								}
272
273								// Sleep respecting context cancellation
274								select {
275								case <-ctx.Done():
276									eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
277									return
278								case <-time.After(time.Duration(retryMs) * time.Millisecond):
279									// Continue with retry after specified delay
280									continue
281								}
282							}
283						}
284
285						// Fall back to exponential backoff if Retry-After parsing failed
286						continue
287					}
288				}
289
290				// For non-rate limit errors, report and exit
291				eventChan <- ProviderEvent{Type: EventError, Error: err}
292				return
293			}
294		}
295	}()
296
297	return eventChan, nil
298}
299
300func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
301	var toolCalls []message.ToolCall
302
303	for _, block := range content {
304		switch variant := block.AsAny().(type) {
305		case anthropic.ToolUseBlock:
306			toolCall := message.ToolCall{
307				ID:    variant.ID,
308				Name:  variant.Name,
309				Input: string(variant.Input),
310				Type:  string(variant.Type),
311			}
312			toolCalls = append(toolCalls, toolCall)
313		}
314	}
315
316	return toolCalls
317}
318
319func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
320	return TokenUsage{
321		InputTokens:         usage.InputTokens,
322		OutputTokens:        usage.OutputTokens,
323		CacheCreationTokens: usage.CacheCreationInputTokens,
324		CacheReadTokens:     usage.CacheReadInputTokens,
325	}
326}
327
328func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
329	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
330
331	for i, tool := range tools {
332		info := tool.Info()
333		toolParam := anthropic.ToolParam{
334			Name:        info.Name,
335			Description: anthropic.String(info.Description),
336			InputSchema: anthropic.ToolInputSchemaParam{
337				Properties: info.Parameters,
338			},
339		}
340
341		if i == len(tools)-1 {
342			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
343				Type: "ephemeral",
344			}
345		}
346
347		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
348	}
349
350	return anthropicTools
351}
352
353func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
354	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
355	cachedBlocks := 0
356
357	for _, msg := range messages {
358		switch msg.Role {
359		case message.User:
360			content := anthropic.NewTextBlock(msg.Content().String())
361			if cachedBlocks < 2 {
362				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
363					Type: "ephemeral",
364				}
365				cachedBlocks++
366			}
367			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
368
369		case message.Assistant:
370			blocks := []anthropic.ContentBlockParamUnion{}
371			if msg.Content().String() != "" {
372				content := anthropic.NewTextBlock(msg.Content().String())
373				if cachedBlocks < 2 {
374					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
375						Type: "ephemeral",
376					}
377					cachedBlocks++
378				}
379				blocks = append(blocks, content)
380			}
381
382			for _, toolCall := range msg.ToolCalls() {
383				var inputMap map[string]any
384				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
385				if err != nil {
386					continue
387				}
388				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
389			}
390
391			// Skip empty assistant messages completely
392			if len(blocks) > 0 {
393				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
394			}
395
396		case message.Tool:
397			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
398			for i, toolResult := range msg.ToolResults() {
399				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
400			}
401			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
402		}
403	}
404
405	return anthropicMessages
406}
407