anthropic.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

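// anthropicProvider implements Provider on top of the Anthropic Messages API,
// with optional AWS Bedrock transport and prompt caching.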
type anthropicProvider struct {
	client        anthropic.Client
	model         models.Model
	maxTokens     int64
	apiKey        string
	systemMessage string
	useBedrock    bool
	disableCache  bool
}

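// AnthropicOption configures an anthropicProvider during construction.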
type AnthropicOption func(*anthropicProvider)

// WithAnthropicSystemMessage sets the system prompt sent with every request.
func WithAnthropicSystemMessage(message string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.systemMessage = message
	}
}

// WithAnthropicMaxTokens sets the maximum number of output tokens.
func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
	return func(a *anthropicProvider) {
		a.maxTokens = maxTokens
	}
}

// WithAnthropicModel selects the model used for requests.
func WithAnthropicModel(model models.Model) AnthropicOption {
	return func(a *anthropicProvider) {
		a.model = model
	}
}

// WithAnthropicKey sets the Anthropic API key.
func WithAnthropicKey(apiKey string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.apiKey = apiKey
	}
}

// WithAnthropicBedrock routes requests through AWS Bedrock.
func WithAnthropicBedrock() AnthropicOption {
	return func(a *anthropicProvider) {
		a.useBedrock = true
	}
}

// WithAnthropicDisableCache disables prompt caching.
func WithAnthropicDisableCache() AnthropicOption {
	return func(a *anthropicProvider) {
		a.disableCache = true
	}
}

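// NewAnthropicProvider builds a Provider backed by the Anthropic Messages API.
// A system message is required; maxTokens defaults to 1024 when no option
// overrides it.
//
// A minimal usage sketch (apiKey and model are placeholders supplied by the
// caller, not values defined in this package):
//
//	p, err := NewAnthropicProvider(
//		WithAnthropicKey(apiKey),
//		WithAnthropicModel(model),
//		WithAnthropicSystemMessage("You are a helpful assistant."),
//		WithAnthropicMaxTokens(4096),
//	)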
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
	provider := &anthropicProvider{
		maxTokens: 1024,
	}

	for _, opt := range opts {
		opt(provider)
	}

	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	anthropicOptions := []option.RequestOption{}

	if provider.apiKey != "" {
		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
	}
	if provider.useBedrock {
		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
	}

	provider.client = anthropic.NewClient(anthropicOptions...)
	return provider, nil
}

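// SendMessages sends the conversation and tool definitions in a single
// blocking request and returns the concatenated text content, any tool calls,
// and token usage from the response.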
func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	messages = cleanupMessages(messages)
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	response, err := a.client.Messages.New(
		ctx,
		anthropic.MessageNewParams{
			Model:       anthropic.Model(a.model.APIModel),
			MaxTokens:   a.maxTokens,
			Temperature: anthropic.Float(0),
			Messages:    anthropicMessages,
			Tools:       anthropicTools,
			System: []anthropic.TextBlockParam{
				{
					Text: a.systemMessage,
					CacheControl: anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					},
				},
			},
		},
	)
	if err != nil {
		return nil, err
	}

	content := ""
	for _, block := range response.Content {
		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
			content += text.Text
		}
	}

	toolCalls := a.extractToolCalls(response.Content)
	tokenUsage := a.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

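// StreamResponse streams the completion as ProviderEvents on the returned
// channel. It enables extended thinking when the latest user message mentions
// "think", and it retries rate-limit (429) and overload (529) errors with
// exponential backoff, honoring the Retry-After header when present. The
// channel is closed once the stream finishes or an error event is emitted.
//
// A minimal consumer sketch (hypothetical caller code):
//
//	events, err := p.StreamResponse(ctx, msgs, nil)
//	if err != nil {
//		return err
//	}
//	for ev := range events {
//		switch ev.Type {
//		case EventContentDelta:
//			fmt.Print(ev.Content)
//		case EventError:
//			return ev.Error
//		}
//	}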
func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	messages = cleanupMessages(messages)
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	// Enable extended thinking when the latest user message mentions "think":
	// the request then uses temperature 1 and budgets 80% of maxTokens for
	// thinking.
	var thinkingParam anthropic.ThinkingConfigParamUnion
	temperature := anthropic.Float(0)
	if len(messages) > 0 { // guard against an empty conversation
		lastMessage := messages[len(messages)-1]
		if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
			thinkingParam = anthropic.ThinkingConfigParamUnion{
				OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
					BudgetTokens: int64(float64(a.maxTokens) * 0.8),
					Type:         "enabled",
				},
			}
			temperature = anthropic.Float(1)
		}
	}

	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		const maxRetries = 8
		attempts := 0

		for {
			attempts++

			stream := a.client.Messages.NewStreaming(
				ctx,
				anthropic.MessageNewParams{
					Model:       anthropic.Model(a.model.APIModel),
					MaxTokens:   a.maxTokens,
					Temperature: temperature,
					Messages:    anthropicMessages,
					Tools:       anthropicTools,
					Thinking:    thinkingParam,
					System: []anthropic.TextBlockParam{
						{
							Text: a.systemMessage,
							CacheControl: anthropic.CacheControlEphemeralParam{
								Type: "ephemeral",
							},
						},
					},
				},
			)

			accumulatedMessage := anthropic.Message{}

			for stream.Next() {
				event := stream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return // Don't retry on accumulation errors
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					eventChan <- ProviderEvent{Type: EventContentStart}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					}

				case anthropic.ContentBlockStopEvent:
					eventChan <- ProviderEvent{Type: EventContentStop}

				case anthropic.MessageStopEvent:
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    toolCalls,
							Usage:        tokenUsage,
							FinishReason: string(accumulatedMessage.StopReason),
						},
					}
				}
			}

			err := stream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				return
			}

			var apierr *anthropic.Error
			if !errors.As(err, &apierr) {
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}

			// Only rate-limit (429) and overload (529) responses are retried.
			if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}

			if attempts > maxRetries {
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: errors.New("maximum retry attempts reached for rate limiting (429) or overload (529)"),
				}
				return
			}

			retryMs := 0
			// Honor the Retry-After header (in seconds) when the API provides one.
			retryAfterValues := apierr.Response.Header.Values("Retry-After")
			if len(retryAfterValues) > 0 {
				var retryAfterSec int
				if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
					retryMs = retryAfterSec * 1000
					eventChan <- ProviderEvent{
						Type: EventWarning,
						Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
					}
				}
			} else {
				eventChan <- ProviderEvent{
					Type: EventWarning,
					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
				}

				// No Retry-After header: fall back to exponential backoff with a
				// fixed 20% margin added on top.
				backoffMs := 2000 * (1 << (attempts - 1))
				jitterMs := int(float64(backoffMs) * 0.2)
				retryMs = backoffMs + jitterMs
			}
			select {
			case <-ctx.Done():
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
				return
			case <-time.After(time.Duration(retryMs) * time.Millisecond):
				continue
			}
		}
	}()

	return eventChan, nil
}

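// extractToolCalls collects tool use blocks from the response content and
// converts them into message.ToolCall values.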
func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:    variant.ID,
				Name:  variant.Name,
				Input: string(variant.Input),
				Type:  string(variant.Type),
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

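// extractTokenUsage maps the API usage counts, including cache creation and
// cache read tokens, onto TokenUsage.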
func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
	return TokenUsage{
		InputTokens:         usage.InputTokens,
		OutputTokens:        usage.OutputTokens,
		CacheCreationTokens: usage.CacheCreationInputTokens,
		CacheReadTokens:     usage.CacheReadInputTokens,
	}
}

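// convertToAnthropicTools converts tool definitions into Anthropic tool
// params, marking the last tool with ephemeral cache control unless caching
// is disabled.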
func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
			},
		}

		if i == len(tools)-1 && !a.disableCache {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

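// convertToAnthropicMessages converts internal messages into Anthropic
// message params, applying ephemeral cache control to at most two text
// blocks and sending tool results back as user messages.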
func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
	cachedBlocks := 0

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cachedBlocks < 2 && !a.disableCache {
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
				cachedBlocks++
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cachedBlocks < 2 && !a.disableCache {
					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
					cachedBlocks++
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					// Skip tool calls whose recorded input is not valid JSON.
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			if len(blocks) > 0 {
				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
			}

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}

	return anthropicMessages
}