anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"log"
  9	"strings"
 10	"time"
 11
 12	"github.com/anthropics/anthropic-sdk-go"
 13	"github.com/anthropics/anthropic-sdk-go/option"
 14	"github.com/kujtimiihoxha/termai/internal/llm/models"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17)
 18
 19type anthropicProvider struct {
 20	client        anthropic.Client
 21	model         models.Model
 22	maxTokens     int64
 23	apiKey        string
 24	systemMessage string
 25}
 26
 27type AnthropicOption func(*anthropicProvider)
 28
 29func WithAnthropicSystemMessage(message string) AnthropicOption {
 30	return func(a *anthropicProvider) {
 31		a.systemMessage = message
 32	}
 33}
 34
 35func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 36	return func(a *anthropicProvider) {
 37		a.maxTokens = maxTokens
 38	}
 39}
 40
 41func WithAnthropicModel(model models.Model) AnthropicOption {
 42	return func(a *anthropicProvider) {
 43		a.model = model
 44	}
 45}
 46
 47func WithAnthropicKey(apiKey string) AnthropicOption {
 48	return func(a *anthropicProvider) {
 49		a.apiKey = apiKey
 50	}
 51}
 52
 53func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 54	provider := &anthropicProvider{
 55		maxTokens: 1024,
 56	}
 57
 58	for _, opt := range opts {
 59		opt(provider)
 60	}
 61
 62	if provider.systemMessage == "" {
 63		return nil, errors.New("system message is required")
 64	}
 65
 66	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
 67	return provider, nil
 68}
 69
 70func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 71	anthropicMessages := a.convertToAnthropicMessages(messages)
 72	anthropicTools := a.convertToAnthropicTools(tools)
 73
 74	response, err := a.client.Messages.New(
 75		ctx,
 76		anthropic.MessageNewParams{
 77			Model:       anthropic.Model(a.model.APIModel),
 78			MaxTokens:   a.maxTokens,
 79			Temperature: anthropic.Float(0),
 80			Messages:    anthropicMessages,
 81			Tools:       anthropicTools,
 82			System: []anthropic.TextBlockParam{
 83				{
 84					Text: a.systemMessage,
 85					CacheControl: anthropic.CacheControlEphemeralParam{
 86						Type: "ephemeral",
 87					},
 88				},
 89			},
 90		},
 91		option.WithMaxRetries(8),
 92	)
 93	if err != nil {
 94		return nil, err
 95	}
 96
 97	content := ""
 98	for _, block := range response.Content {
 99		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
100			content += text.Text
101		}
102	}
103
104	toolCalls := a.extractToolCalls(response.Content)
105	tokenUsage := a.extractTokenUsage(response.Usage)
106
107	return &ProviderResponse{
108		Content:   content,
109		ToolCalls: toolCalls,
110		Usage:     tokenUsage,
111	}, nil
112}
113
114func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
115	anthropicMessages := a.convertToAnthropicMessages(messages)
116	anthropicTools := a.convertToAnthropicTools(tools)
117
118	var thinkingParam anthropic.ThinkingConfigParamUnion
119	lastMessage := messages[len(messages)-1]
120	temperature := anthropic.Float(0)
121	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
122		thinkingParam = anthropic.ThinkingConfigParamUnion{
123			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
124				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
125				Type:         "enabled",
126			},
127		}
128		temperature = anthropic.Float(1)
129	}
130
131	eventChan := make(chan ProviderEvent)
132
133	go func() {
134		defer close(eventChan)
135
136		const maxRetries = 8
137		attempts := 0
138
139		for {
140			// If this isn't the first attempt, we're retrying
141			if attempts > 0 {
142				if attempts > maxRetries {
143					eventChan <- ProviderEvent{
144						Type:  EventError,
145						Error: errors.New("maximum retry attempts reached for rate limit (429)"),
146					}
147					return
148				}
149
150				// Inform user we're retrying with attempt number
151				eventChan <- ProviderEvent{
152					Type:    EventContentDelta,
153					Content: fmt.Sprintf("\n\n[Retrying due to rate limit... attempt %d of %d]\n\n", attempts, maxRetries),
154				}
155
156				// Calculate backoff with exponential backoff and jitter
157				backoffMs := 2000 * (1 << (attempts - 1)) // 2s, 4s, 8s, 16s, 32s
158				jitterMs := int(float64(backoffMs) * 0.2)
159				totalBackoffMs := backoffMs + jitterMs
160
161				// Sleep with backoff, respecting context cancellation
162				select {
163				case <-ctx.Done():
164					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
165					return
166				case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
167					// Continue with retry
168				}
169			}
170
171			attempts++
172
173			// Create new streaming request
174			stream := a.client.Messages.NewStreaming(
175				ctx,
176				anthropic.MessageNewParams{
177					Model:       anthropic.Model(a.model.APIModel),
178					MaxTokens:   a.maxTokens,
179					Temperature: temperature,
180					Messages:    anthropicMessages,
181					Tools:       anthropicTools,
182					Thinking:    thinkingParam,
183					System: []anthropic.TextBlockParam{
184						{
185							Text: a.systemMessage,
186							CacheControl: anthropic.CacheControlEphemeralParam{
187								Type: "ephemeral",
188							},
189						},
190					},
191				},
192			)
193
194			// Process stream events
195			accumulatedMessage := anthropic.Message{}
196			streamSuccess := false
197
198			// Process the stream until completion or error
199			for stream.Next() {
200				event := stream.Current()
201				err := accumulatedMessage.Accumulate(event)
202				if err != nil {
203					eventChan <- ProviderEvent{Type: EventError, Error: err}
204					return // Don't retry on accumulation errors
205				}
206
207				switch event := event.AsAny().(type) {
208				case anthropic.ContentBlockStartEvent:
209					eventChan <- ProviderEvent{Type: EventContentStart}
210
211				case anthropic.ContentBlockDeltaEvent:
212					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
213						eventChan <- ProviderEvent{
214							Type:     EventThinkingDelta,
215							Thinking: event.Delta.Thinking,
216						}
217					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
218						eventChan <- ProviderEvent{
219							Type:    EventContentDelta,
220							Content: event.Delta.Text,
221						}
222					}
223
224				case anthropic.ContentBlockStopEvent:
225					eventChan <- ProviderEvent{Type: EventContentStop}
226
227				case anthropic.MessageStopEvent:
228					streamSuccess = true
229					content := ""
230					for _, block := range accumulatedMessage.Content {
231						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
232							content += text.Text
233						}
234					}
235
236					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
237					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
238
239					eventChan <- ProviderEvent{
240						Type: EventComplete,
241						Response: &ProviderResponse{
242							Content:      content,
243							ToolCalls:    toolCalls,
244							Usage:        tokenUsage,
245							FinishReason: string(accumulatedMessage.StopReason),
246						},
247					}
248				}
249			}
250
251			// If the stream completed successfully, we're done
252			if streamSuccess {
253				return
254			}
255
256			// Check for stream errors
257			err := stream.Err()
258			if err != nil {
259				log.Println("error", err)
260
261				var apierr *anthropic.Error
262				if errors.As(err, &apierr) && apierr.StatusCode == 429 {
263					continue
264				}
265
266				// For non-rate limit errors, report and exit
267				eventChan <- ProviderEvent{Type: EventError, Error: err}
268				return
269			}
270		}
271	}()
272
273	return eventChan, nil
274}
275
276func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
277	var toolCalls []message.ToolCall
278
279	for _, block := range content {
280		switch variant := block.AsAny().(type) {
281		case anthropic.ToolUseBlock:
282			toolCall := message.ToolCall{
283				ID:    variant.ID,
284				Name:  variant.Name,
285				Input: string(variant.Input),
286				Type:  string(variant.Type),
287			}
288			toolCalls = append(toolCalls, toolCall)
289		}
290	}
291
292	return toolCalls
293}
294
295func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
296	return TokenUsage{
297		InputTokens:         usage.InputTokens,
298		OutputTokens:        usage.OutputTokens,
299		CacheCreationTokens: usage.CacheCreationInputTokens,
300		CacheReadTokens:     usage.CacheReadInputTokens,
301	}
302}
303
304func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
305	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
306
307	for i, tool := range tools {
308		info := tool.Info()
309		toolParam := anthropic.ToolParam{
310			Name:        info.Name,
311			Description: anthropic.String(info.Description),
312			InputSchema: anthropic.ToolInputSchemaParam{
313				Properties: info.Parameters,
314			},
315		}
316
317		if i == len(tools)-1 {
318			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
319				Type: "ephemeral",
320			}
321		}
322
323		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
324	}
325
326	return anthropicTools
327}
328
329func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
330	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
331	cachedBlocks := 0
332
333	for _, msg := range messages {
334		switch msg.Role {
335		case message.User:
336			content := anthropic.NewTextBlock(msg.Content().String())
337			if cachedBlocks < 2 {
338				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
339					Type: "ephemeral",
340				}
341				cachedBlocks++
342			}
343			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
344
345		case message.Assistant:
346			blocks := []anthropic.ContentBlockParamUnion{}
347			if msg.Content().String() != "" {
348				content := anthropic.NewTextBlock(msg.Content().String())
349				if cachedBlocks < 2 {
350					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
351						Type: "ephemeral",
352					}
353					cachedBlocks++
354				}
355				blocks = append(blocks, content)
356			}
357
358			for _, toolCall := range msg.ToolCalls() {
359				var inputMap map[string]any
360				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
361				if err != nil {
362					continue
363				}
364				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
365			}
366
367			// Skip empty assistant messages completely
368			if len(blocks) > 0 {
369				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
370			}
371
372		case message.Tool:
373			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
374			for i, toolResult := range msg.ToolResults() {
375				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
376			}
377			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
378		}
379	}
380
381	return anthropicMessages
382}
383