anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"strings"
  9	"time"
 10
 11	"github.com/anthropics/anthropic-sdk-go"
 12	"github.com/anthropics/anthropic-sdk-go/bedrock"
 13	"github.com/anthropics/anthropic-sdk-go/option"
 14	"github.com/kujtimiihoxha/termai/internal/llm/models"
 15	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17)
 18
 19type anthropicProvider struct {
 20	client        anthropic.Client
 21	model         models.Model
 22	maxTokens     int64
 23	apiKey        string
 24	systemMessage string
 25	useBedrock    bool
 26	disableCache  bool
 27}
 28
 29type AnthropicOption func(*anthropicProvider)
 30
 31func WithAnthropicSystemMessage(message string) AnthropicOption {
 32	return func(a *anthropicProvider) {
 33		a.systemMessage = message
 34	}
 35}
 36
 37func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 38	return func(a *anthropicProvider) {
 39		a.maxTokens = maxTokens
 40	}
 41}
 42
 43func WithAnthropicModel(model models.Model) AnthropicOption {
 44	return func(a *anthropicProvider) {
 45		a.model = model
 46	}
 47}
 48
 49func WithAnthropicKey(apiKey string) AnthropicOption {
 50	return func(a *anthropicProvider) {
 51		a.apiKey = apiKey
 52	}
 53}
 54
 55func WithAnthropicBedrock() AnthropicOption {
 56	return func(a *anthropicProvider) {
 57		a.useBedrock = true
 58	}
 59}
 60
 61func WithAnthropicDisableCache() AnthropicOption {
 62	return func(a *anthropicProvider) {
 63		a.disableCache = true
 64	}
 65}
 66
 67func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 68	provider := &anthropicProvider{
 69		maxTokens: 1024,
 70	}
 71
 72	for _, opt := range opts {
 73		opt(provider)
 74	}
 75
 76	if provider.systemMessage == "" {
 77		return nil, errors.New("system message is required")
 78	}
 79
 80	anthropicOptions := []option.RequestOption{}
 81
 82	if provider.apiKey != "" {
 83		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
 84	}
 85	if provider.useBedrock {
 86		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 87	}
 88
 89	provider.client = anthropic.NewClient(anthropicOptions...)
 90	return provider, nil
 91}
 92
 93func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 94	anthropicMessages := a.convertToAnthropicMessages(messages)
 95	anthropicTools := a.convertToAnthropicTools(tools)
 96
 97	response, err := a.client.Messages.New(
 98		ctx,
 99		anthropic.MessageNewParams{
100			Model:       anthropic.Model(a.model.APIModel),
101			MaxTokens:   a.maxTokens,
102			Temperature: anthropic.Float(0),
103			Messages:    anthropicMessages,
104			Tools:       anthropicTools,
105			System: []anthropic.TextBlockParam{
106				{
107					Text: a.systemMessage,
108					CacheControl: anthropic.CacheControlEphemeralParam{
109						Type: "ephemeral",
110					},
111				},
112			},
113		},
114	)
115	if err != nil {
116		return nil, err
117	}
118
119	content := ""
120	for _, block := range response.Content {
121		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
122			content += text.Text
123		}
124	}
125
126	toolCalls := a.extractToolCalls(response.Content)
127	tokenUsage := a.extractTokenUsage(response.Usage)
128
129	return &ProviderResponse{
130		Content:   content,
131		ToolCalls: toolCalls,
132		Usage:     tokenUsage,
133	}, nil
134}
135
136func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
137	anthropicMessages := a.convertToAnthropicMessages(messages)
138	anthropicTools := a.convertToAnthropicTools(tools)
139
140	var thinkingParam anthropic.ThinkingConfigParamUnion
141	lastMessage := messages[len(messages)-1]
142	temperature := anthropic.Float(0)
143	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
144		thinkingParam = anthropic.ThinkingConfigParamUnion{
145			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
146				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
147				Type:         "enabled",
148			},
149		}
150		temperature = anthropic.Float(1)
151	}
152
153	eventChan := make(chan ProviderEvent)
154
155	go func() {
156		defer close(eventChan)
157
158		const maxRetries = 8
159		attempts := 0
160
161		for {
162
163			attempts++
164
165			stream := a.client.Messages.NewStreaming(
166				ctx,
167				anthropic.MessageNewParams{
168					Model:       anthropic.Model(a.model.APIModel),
169					MaxTokens:   a.maxTokens,
170					Temperature: temperature,
171					Messages:    anthropicMessages,
172					Tools:       anthropicTools,
173					Thinking:    thinkingParam,
174					System: []anthropic.TextBlockParam{
175						{
176							Text: a.systemMessage,
177							CacheControl: anthropic.CacheControlEphemeralParam{
178								Type: "ephemeral",
179							},
180						},
181					},
182				},
183			)
184
185			accumulatedMessage := anthropic.Message{}
186
187			for stream.Next() {
188				event := stream.Current()
189				err := accumulatedMessage.Accumulate(event)
190				if err != nil {
191					eventChan <- ProviderEvent{Type: EventError, Error: err}
192					return // Don't retry on accumulation errors
193				}
194
195				switch event := event.AsAny().(type) {
196				case anthropic.ContentBlockStartEvent:
197					eventChan <- ProviderEvent{Type: EventContentStart}
198
199				case anthropic.ContentBlockDeltaEvent:
200					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
201						eventChan <- ProviderEvent{
202							Type:     EventThinkingDelta,
203							Thinking: event.Delta.Thinking,
204						}
205					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
206						eventChan <- ProviderEvent{
207							Type:    EventContentDelta,
208							Content: event.Delta.Text,
209						}
210					}
211
212				case anthropic.ContentBlockStopEvent:
213					eventChan <- ProviderEvent{Type: EventContentStop}
214
215				case anthropic.MessageStopEvent:
216					content := ""
217					for _, block := range accumulatedMessage.Content {
218						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
219							content += text.Text
220						}
221					}
222
223					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
224					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
225
226					eventChan <- ProviderEvent{
227						Type: EventComplete,
228						Response: &ProviderResponse{
229							Content:      content,
230							ToolCalls:    toolCalls,
231							Usage:        tokenUsage,
232							FinishReason: string(accumulatedMessage.StopReason),
233						},
234					}
235				}
236			}
237
238			err := stream.Err()
239			if err == nil {
240				return
241			}
242
243			var apierr *anthropic.Error
244			if !errors.As(err, &apierr) {
245				eventChan <- ProviderEvent{Type: EventError, Error: err}
246				return
247			}
248
249			if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
250				eventChan <- ProviderEvent{Type: EventError, Error: err}
251				return
252			}
253
254			if attempts > maxRetries {
255				eventChan <- ProviderEvent{
256					Type:  EventError,
257					Error: errors.New("maximum retry attempts reached for rate limit (429)"),
258				}
259				return
260			}
261
262			retryMs := 0
263			retryAfterValues := apierr.Response.Header.Values("Retry-After")
264			if len(retryAfterValues) > 0 {
265				var retryAfterSec int
266				if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
267					retryMs = retryAfterSec * 1000
268					eventChan <- ProviderEvent{
269						Type: EventWarning,
270						Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
271					}
272				}
273			} else {
274				eventChan <- ProviderEvent{
275					Type: EventWarning,
276					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
277				}
278
279				backoffMs := 2000 * (1 << (attempts - 1))
280				jitterMs := int(float64(backoffMs) * 0.2)
281				retryMs = backoffMs + jitterMs
282			}
283			select {
284			case <-ctx.Done():
285				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
286				return
287			case <-time.After(time.Duration(retryMs) * time.Millisecond):
288				continue
289			}
290
291		}
292	}()
293
294	return eventChan, nil
295}
296
297func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
298	var toolCalls []message.ToolCall
299
300	for _, block := range content {
301		switch variant := block.AsAny().(type) {
302		case anthropic.ToolUseBlock:
303			toolCall := message.ToolCall{
304				ID:    variant.ID,
305				Name:  variant.Name,
306				Input: string(variant.Input),
307				Type:  string(variant.Type),
308			}
309			toolCalls = append(toolCalls, toolCall)
310		}
311	}
312
313	return toolCalls
314}
315
316func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
317	return TokenUsage{
318		InputTokens:         usage.InputTokens,
319		OutputTokens:        usage.OutputTokens,
320		CacheCreationTokens: usage.CacheCreationInputTokens,
321		CacheReadTokens:     usage.CacheReadInputTokens,
322	}
323}
324
325func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
326	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
327
328	for i, tool := range tools {
329		info := tool.Info()
330		toolParam := anthropic.ToolParam{
331			Name:        info.Name,
332			Description: anthropic.String(info.Description),
333			InputSchema: anthropic.ToolInputSchemaParam{
334				Properties: info.Parameters,
335			},
336		}
337
338		if i == len(tools)-1 && !a.disableCache {
339			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
340				Type: "ephemeral",
341			}
342		}
343
344		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
345	}
346
347	return anthropicTools
348}
349
350func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
351	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
352	cachedBlocks := 0
353
354	for _, msg := range messages {
355		switch msg.Role {
356		case message.User:
357			content := anthropic.NewTextBlock(msg.Content().String())
358			if cachedBlocks < 2 && !a.disableCache {
359				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
360					Type: "ephemeral",
361				}
362				cachedBlocks++
363			}
364			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
365
366		case message.Assistant:
367			blocks := []anthropic.ContentBlockParamUnion{}
368			if msg.Content().String() != "" {
369				content := anthropic.NewTextBlock(msg.Content().String())
370				if cachedBlocks < 2 && !a.disableCache {
371					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
372						Type: "ephemeral",
373					}
374					cachedBlocks++
375				}
376				blocks = append(blocks, content)
377			}
378
379			for _, toolCall := range msg.ToolCalls() {
380				var inputMap map[string]any
381				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
382				if err != nil {
383					continue
384				}
385				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
386			}
387
388			if len(blocks) > 0 {
389				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
390			}
391
392		case message.Tool:
393			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
394			for i, toolResult := range msg.ToolResults() {
395				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
396			}
397			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
398		}
399	}
400
401	return anthropicMessages
402}