anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"strings"
  9	"time"
 10
 11	"github.com/anthropics/anthropic-sdk-go"
 12	"github.com/anthropics/anthropic-sdk-go/option"
 13	"github.com/kujtimiihoxha/termai/internal/llm/models"
 14	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 15	"github.com/kujtimiihoxha/termai/internal/message"
 16)
 17
 18type anthropicProvider struct {
 19	client        anthropic.Client
 20	model         models.Model
 21	maxTokens     int64
 22	apiKey        string
 23	systemMessage string
 24}
 25
// AnthropicOption configures an anthropicProvider. Options are passed to
// NewAnthropicProvider, which applies them in the order given.
type AnthropicOption func(*anthropicProvider)
 27
 28func WithAnthropicSystemMessage(message string) AnthropicOption {
 29	return func(a *anthropicProvider) {
 30		a.systemMessage = message
 31	}
 32}
 33
 34func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 35	return func(a *anthropicProvider) {
 36		a.maxTokens = maxTokens
 37	}
 38}
 39
 40func WithAnthropicModel(model models.Model) AnthropicOption {
 41	return func(a *anthropicProvider) {
 42		a.model = model
 43	}
 44}
 45
 46func WithAnthropicKey(apiKey string) AnthropicOption {
 47	return func(a *anthropicProvider) {
 48		a.apiKey = apiKey
 49	}
 50}
 51
 52func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 53	provider := &anthropicProvider{
 54		maxTokens: 1024,
 55	}
 56
 57	for _, opt := range opts {
 58		opt(provider)
 59	}
 60
 61	if provider.systemMessage == "" {
 62		return nil, errors.New("system message is required")
 63	}
 64
 65	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
 66	return provider, nil
 67}
 68
 69func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 70	anthropicMessages := a.convertToAnthropicMessages(messages)
 71	anthropicTools := a.convertToAnthropicTools(tools)
 72
 73	response, err := a.client.Messages.New(
 74		ctx,
 75		anthropic.MessageNewParams{
 76			Model:       anthropic.Model(a.model.APIModel),
 77			MaxTokens:   a.maxTokens,
 78			Temperature: anthropic.Float(0),
 79			Messages:    anthropicMessages,
 80			Tools:       anthropicTools,
 81			System: []anthropic.TextBlockParam{
 82				{
 83					Text: a.systemMessage,
 84					CacheControl: anthropic.CacheControlEphemeralParam{
 85						Type: "ephemeral",
 86					},
 87				},
 88			},
 89		},
 90	)
 91	if err != nil {
 92		return nil, err
 93	}
 94
 95	content := ""
 96	for _, block := range response.Content {
 97		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
 98			content += text.Text
 99		}
100	}
101
102	toolCalls := a.extractToolCalls(response.Content)
103	tokenUsage := a.extractTokenUsage(response.Usage)
104
105	return &ProviderResponse{
106		Content:   content,
107		ToolCalls: toolCalls,
108		Usage:     tokenUsage,
109	}, nil
110}
111
112func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
113	anthropicMessages := a.convertToAnthropicMessages(messages)
114	anthropicTools := a.convertToAnthropicTools(tools)
115
116	var thinkingParam anthropic.ThinkingConfigParamUnion
117	lastMessage := messages[len(messages)-1]
118	temperature := anthropic.Float(0)
119	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
120		thinkingParam = anthropic.ThinkingConfigParamUnion{
121			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
122				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
123				Type:         "enabled",
124			},
125		}
126		temperature = anthropic.Float(1)
127	}
128
129	eventChan := make(chan ProviderEvent)
130
131	go func() {
132		defer close(eventChan)
133
134		const maxRetries = 8
135		attempts := 0
136
137		for {
138
139			attempts++
140
141			stream := a.client.Messages.NewStreaming(
142				ctx,
143				anthropic.MessageNewParams{
144					Model:       anthropic.Model(a.model.APIModel),
145					MaxTokens:   a.maxTokens,
146					Temperature: temperature,
147					Messages:    anthropicMessages,
148					Tools:       anthropicTools,
149					Thinking:    thinkingParam,
150					System: []anthropic.TextBlockParam{
151						{
152							Text: a.systemMessage,
153							CacheControl: anthropic.CacheControlEphemeralParam{
154								Type: "ephemeral",
155							},
156						},
157					},
158				},
159			)
160
161			accumulatedMessage := anthropic.Message{}
162
163			for stream.Next() {
164				event := stream.Current()
165				err := accumulatedMessage.Accumulate(event)
166				if err != nil {
167					eventChan <- ProviderEvent{Type: EventError, Error: err}
168					return // Don't retry on accumulation errors
169				}
170
171				switch event := event.AsAny().(type) {
172				case anthropic.ContentBlockStartEvent:
173					eventChan <- ProviderEvent{Type: EventContentStart}
174
175				case anthropic.ContentBlockDeltaEvent:
176					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
177						eventChan <- ProviderEvent{
178							Type:     EventThinkingDelta,
179							Thinking: event.Delta.Thinking,
180						}
181					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
182						eventChan <- ProviderEvent{
183							Type:    EventContentDelta,
184							Content: event.Delta.Text,
185						}
186					}
187
188				case anthropic.ContentBlockStopEvent:
189					eventChan <- ProviderEvent{Type: EventContentStop}
190
191				case anthropic.MessageStopEvent:
192					content := ""
193					for _, block := range accumulatedMessage.Content {
194						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
195							content += text.Text
196						}
197					}
198
199					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
200					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
201
202					eventChan <- ProviderEvent{
203						Type: EventComplete,
204						Response: &ProviderResponse{
205							Content:      content,
206							ToolCalls:    toolCalls,
207							Usage:        tokenUsage,
208							FinishReason: string(accumulatedMessage.StopReason),
209						},
210					}
211				}
212			}
213
214			err := stream.Err()
215			if err == nil {
216				return
217			}
218
219			var apierr *anthropic.Error
220			if !errors.As(err, &apierr) {
221				eventChan <- ProviderEvent{Type: EventError, Error: err}
222				return
223			}
224
225			if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
226				eventChan <- ProviderEvent{Type: EventError, Error: err}
227				return
228			}
229
230			if attempts > maxRetries {
231				eventChan <- ProviderEvent{
232					Type:  EventError,
233					Error: errors.New("maximum retry attempts reached for rate limit (429)"),
234				}
235				return
236			}
237
238			retryMs := 0
239			retryAfterValues := apierr.Response.Header.Values("Retry-After")
240			if len(retryAfterValues) > 0 {
241				var retryAfterSec int
242				if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
243					retryMs = retryAfterSec * 1000
244					eventChan <- ProviderEvent{
245						Type: EventWarning,
246						Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
247					}
248				}
249			} else {
250				eventChan <- ProviderEvent{
251					Type: EventWarning,
252					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
253				}
254
255				backoffMs := 2000 * (1 << (attempts - 1))
256				jitterMs := int(float64(backoffMs) * 0.2)
257				retryMs = backoffMs + jitterMs
258			}
259			select {
260			case <-ctx.Done():
261				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
262				return
263			case <-time.After(time.Duration(retryMs) * time.Millisecond):
264				continue
265			}
266
267		}
268	}()
269
270	return eventChan, nil
271}
272
273func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
274	var toolCalls []message.ToolCall
275
276	for _, block := range content {
277		switch variant := block.AsAny().(type) {
278		case anthropic.ToolUseBlock:
279			toolCall := message.ToolCall{
280				ID:    variant.ID,
281				Name:  variant.Name,
282				Input: string(variant.Input),
283				Type:  string(variant.Type),
284			}
285			toolCalls = append(toolCalls, toolCall)
286		}
287	}
288
289	return toolCalls
290}
291
292func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
293	return TokenUsage{
294		InputTokens:         usage.InputTokens,
295		OutputTokens:        usage.OutputTokens,
296		CacheCreationTokens: usage.CacheCreationInputTokens,
297		CacheReadTokens:     usage.CacheReadInputTokens,
298	}
299}
300
301func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
302	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
303
304	for i, tool := range tools {
305		info := tool.Info()
306		toolParam := anthropic.ToolParam{
307			Name:        info.Name,
308			Description: anthropic.String(info.Description),
309			InputSchema: anthropic.ToolInputSchemaParam{
310				Properties: info.Parameters,
311			},
312		}
313
314		if i == len(tools)-1 {
315			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
316				Type: "ephemeral",
317			}
318		}
319
320		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
321	}
322
323	return anthropicTools
324}
325
326func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
327	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
328	cachedBlocks := 0
329
330	for _, msg := range messages {
331		switch msg.Role {
332		case message.User:
333			content := anthropic.NewTextBlock(msg.Content().String())
334			if cachedBlocks < 2 {
335				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
336					Type: "ephemeral",
337				}
338				cachedBlocks++
339			}
340			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
341
342		case message.Assistant:
343			blocks := []anthropic.ContentBlockParamUnion{}
344			if msg.Content().String() != "" {
345				content := anthropic.NewTextBlock(msg.Content().String())
346				if cachedBlocks < 2 {
347					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
348						Type: "ephemeral",
349					}
350					cachedBlocks++
351				}
352				blocks = append(blocks, content)
353			}
354
355			for _, toolCall := range msg.ToolCalls() {
356				var inputMap map[string]any
357				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
358				if err != nil {
359					continue
360				}
361				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
362			}
363
364			if len(blocks) > 0 {
365				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
366			}
367
368		case message.Tool:
369			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
370			for i, toolResult := range msg.ToolResults() {
371				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
372			}
373			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
374		}
375	}
376
377	return anthropicMessages
378}