anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"strings"
 10	"time"
 11
 12	"github.com/anthropics/anthropic-sdk-go"
 13	"github.com/anthropics/anthropic-sdk-go/bedrock"
 14	"github.com/anthropics/anthropic-sdk-go/option"
 15	"github.com/kujtimiihoxha/termai/internal/llm/models"
 16	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 17	"github.com/kujtimiihoxha/termai/internal/message"
 18)
 19
 20type anthropicProvider struct {
 21	client        anthropic.Client
 22	model         models.Model
 23	maxTokens     int64
 24	apiKey        string
 25	systemMessage string
 26	useBedrock    bool
 27	disableCache  bool
 28}
 29
 30type AnthropicOption func(*anthropicProvider)
 31
 32func WithAnthropicSystemMessage(message string) AnthropicOption {
 33	return func(a *anthropicProvider) {
 34		a.systemMessage = message
 35	}
 36}
 37
 38func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 39	return func(a *anthropicProvider) {
 40		a.maxTokens = maxTokens
 41	}
 42}
 43
 44func WithAnthropicModel(model models.Model) AnthropicOption {
 45	return func(a *anthropicProvider) {
 46		a.model = model
 47	}
 48}
 49
 50func WithAnthropicKey(apiKey string) AnthropicOption {
 51	return func(a *anthropicProvider) {
 52		a.apiKey = apiKey
 53	}
 54}
 55
 56func WithAnthropicBedrock() AnthropicOption {
 57	return func(a *anthropicProvider) {
 58		a.useBedrock = true
 59	}
 60}
 61
 62func WithAnthropicDisableCache() AnthropicOption {
 63	return func(a *anthropicProvider) {
 64		a.disableCache = true
 65	}
 66}
 67
 68func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 69	provider := &anthropicProvider{
 70		maxTokens: 1024,
 71	}
 72
 73	for _, opt := range opts {
 74		opt(provider)
 75	}
 76
 77	if provider.systemMessage == "" {
 78		return nil, errors.New("system message is required")
 79	}
 80
 81	anthropicOptions := []option.RequestOption{}
 82
 83	if provider.apiKey != "" {
 84		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
 85	}
 86	if provider.useBedrock {
 87		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
 88	}
 89
 90	provider.client = anthropic.NewClient(anthropicOptions...)
 91	return provider, nil
 92}
 93
 94func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 95	anthropicMessages := a.convertToAnthropicMessages(messages)
 96	anthropicTools := a.convertToAnthropicTools(tools)
 97
 98	response, err := a.client.Messages.New(
 99		ctx,
100		anthropic.MessageNewParams{
101			Model:       anthropic.Model(a.model.APIModel),
102			MaxTokens:   a.maxTokens,
103			Temperature: anthropic.Float(0),
104			Messages:    anthropicMessages,
105			Tools:       anthropicTools,
106			System: []anthropic.TextBlockParam{
107				{
108					Text: a.systemMessage,
109					CacheControl: anthropic.CacheControlEphemeralParam{
110						Type: "ephemeral",
111					},
112				},
113			},
114		},
115	)
116	if err != nil {
117		return nil, err
118	}
119
120	content := ""
121	for _, block := range response.Content {
122		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
123			content += text.Text
124		}
125	}
126
127	toolCalls := a.extractToolCalls(response.Content)
128	tokenUsage := a.extractTokenUsage(response.Usage)
129
130	return &ProviderResponse{
131		Content:   content,
132		ToolCalls: toolCalls,
133		Usage:     tokenUsage,
134	}, nil
135}
136
137func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
138	anthropicMessages := a.convertToAnthropicMessages(messages)
139	anthropicTools := a.convertToAnthropicTools(tools)
140
141	var thinkingParam anthropic.ThinkingConfigParamUnion
142	lastMessage := messages[len(messages)-1]
143	temperature := anthropic.Float(0)
144	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
145		thinkingParam = anthropic.ThinkingConfigParamUnion{
146			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
147				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
148				Type:         "enabled",
149			},
150		}
151		temperature = anthropic.Float(1)
152	}
153
154	eventChan := make(chan ProviderEvent)
155
156	go func() {
157		defer close(eventChan)
158
159		const maxRetries = 8
160		attempts := 0
161
162		for {
163
164			attempts++
165
166			stream := a.client.Messages.NewStreaming(
167				ctx,
168				anthropic.MessageNewParams{
169					Model:       anthropic.Model(a.model.APIModel),
170					MaxTokens:   a.maxTokens,
171					Temperature: temperature,
172					Messages:    anthropicMessages,
173					Tools:       anthropicTools,
174					Thinking:    thinkingParam,
175					System: []anthropic.TextBlockParam{
176						{
177							Text: a.systemMessage,
178							CacheControl: anthropic.CacheControlEphemeralParam{
179								Type: "ephemeral",
180							},
181						},
182					},
183				},
184			)
185
186			accumulatedMessage := anthropic.Message{}
187
188			for stream.Next() {
189				event := stream.Current()
190				err := accumulatedMessage.Accumulate(event)
191				if err != nil {
192					eventChan <- ProviderEvent{Type: EventError, Error: err}
193					return // Don't retry on accumulation errors
194				}
195
196				switch event := event.AsAny().(type) {
197				case anthropic.ContentBlockStartEvent:
198					eventChan <- ProviderEvent{Type: EventContentStart}
199
200				case anthropic.ContentBlockDeltaEvent:
201					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
202						eventChan <- ProviderEvent{
203							Type:     EventThinkingDelta,
204							Thinking: event.Delta.Thinking,
205						}
206					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
207						eventChan <- ProviderEvent{
208							Type:    EventContentDelta,
209							Content: event.Delta.Text,
210						}
211					}
212
213				case anthropic.ContentBlockStopEvent:
214					eventChan <- ProviderEvent{Type: EventContentStop}
215
216				case anthropic.MessageStopEvent:
217					content := ""
218					for _, block := range accumulatedMessage.Content {
219						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
220							content += text.Text
221						}
222					}
223
224					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
225					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
226
227					eventChan <- ProviderEvent{
228						Type: EventComplete,
229						Response: &ProviderResponse{
230							Content:      content,
231							ToolCalls:    toolCalls,
232							Usage:        tokenUsage,
233							FinishReason: string(accumulatedMessage.StopReason),
234						},
235					}
236				}
237			}
238
239			err := stream.Err()
240			if err == nil || errors.Is(err, io.EOF) {
241				return
242			}
243
244			var apierr *anthropic.Error
245			if !errors.As(err, &apierr) {
246				eventChan <- ProviderEvent{Type: EventError, Error: err}
247				return
248			}
249
250			if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
251				eventChan <- ProviderEvent{Type: EventError, Error: err}
252				return
253			}
254
255			if attempts > maxRetries {
256				eventChan <- ProviderEvent{
257					Type:  EventError,
258					Error: errors.New("maximum retry attempts reached for rate limit (429)"),
259				}
260				return
261			}
262
263			retryMs := 0
264			retryAfterValues := apierr.Response.Header.Values("Retry-After")
265			if len(retryAfterValues) > 0 {
266				var retryAfterSec int
267				if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
268					retryMs = retryAfterSec * 1000
269					eventChan <- ProviderEvent{
270						Type: EventWarning,
271						Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
272					}
273				}
274			} else {
275				eventChan <- ProviderEvent{
276					Type: EventWarning,
277					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
278				}
279
280				backoffMs := 2000 * (1 << (attempts - 1))
281				jitterMs := int(float64(backoffMs) * 0.2)
282				retryMs = backoffMs + jitterMs
283			}
284			select {
285			case <-ctx.Done():
286				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
287				return
288			case <-time.After(time.Duration(retryMs) * time.Millisecond):
289				continue
290			}
291
292		}
293	}()
294
295	return eventChan, nil
296}
297
298func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
299	var toolCalls []message.ToolCall
300
301	for _, block := range content {
302		switch variant := block.AsAny().(type) {
303		case anthropic.ToolUseBlock:
304			toolCall := message.ToolCall{
305				ID:    variant.ID,
306				Name:  variant.Name,
307				Input: string(variant.Input),
308				Type:  string(variant.Type),
309			}
310			toolCalls = append(toolCalls, toolCall)
311		}
312	}
313
314	return toolCalls
315}
316
317func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
318	return TokenUsage{
319		InputTokens:         usage.InputTokens,
320		OutputTokens:        usage.OutputTokens,
321		CacheCreationTokens: usage.CacheCreationInputTokens,
322		CacheReadTokens:     usage.CacheReadInputTokens,
323	}
324}
325
326func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
327	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
328
329	for i, tool := range tools {
330		info := tool.Info()
331		toolParam := anthropic.ToolParam{
332			Name:        info.Name,
333			Description: anthropic.String(info.Description),
334			InputSchema: anthropic.ToolInputSchemaParam{
335				Properties: info.Parameters,
336			},
337		}
338
339		if i == len(tools)-1 && !a.disableCache {
340			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
341				Type: "ephemeral",
342			}
343		}
344
345		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
346	}
347
348	return anthropicTools
349}
350
351func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
352	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
353	cachedBlocks := 0
354
355	for _, msg := range messages {
356		switch msg.Role {
357		case message.User:
358			content := anthropic.NewTextBlock(msg.Content().String())
359			if cachedBlocks < 2 && !a.disableCache {
360				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
361					Type: "ephemeral",
362				}
363				cachedBlocks++
364			}
365			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
366
367		case message.Assistant:
368			blocks := []anthropic.ContentBlockParamUnion{}
369			if msg.Content().String() != "" {
370				content := anthropic.NewTextBlock(msg.Content().String())
371				if cachedBlocks < 2 && !a.disableCache {
372					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
373						Type: "ephemeral",
374					}
375					cachedBlocks++
376				}
377				blocks = append(blocks, content)
378			}
379
380			for _, toolCall := range msg.ToolCalls() {
381				var inputMap map[string]any
382				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
383				if err != nil {
384					continue
385				}
386				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
387			}
388
389			if len(blocks) > 0 {
390				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
391			}
392
393		case message.Tool:
394			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
395			for i, toolResult := range msg.ToolResults() {
396				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
397			}
398			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
399		}
400	}
401
402	return anthropicMessages
403}