package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

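// anthropicProvider implements Provider on top of the Anthropic Messages API,
// optionally routed through AWS Bedrock, with optional prompt caching.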
type anthropicProvider struct {
	client        anthropic.Client
	model         models.Model
	maxTokens     int64
	apiKey        string
	systemMessage string
	useBedrock    bool
	disableCache  bool
}

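// AnthropicOption configures an anthropicProvider during construction.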
type AnthropicOption func(*anthropicProvider)

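// WithAnthropicSystemMessage sets the system prompt sent with every request.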
func WithAnthropicSystemMessage(message string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.systemMessage = message
	}
}

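// WithAnthropicMaxTokens sets the maximum number of output tokens requested per completion.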
func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
	return func(a *anthropicProvider) {
		a.maxTokens = maxTokens
	}
}

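// WithAnthropicModel selects the model used for completions.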
func WithAnthropicModel(model models.Model) AnthropicOption {
	return func(a *anthropicProvider) {
		a.model = model
	}
}

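// WithAnthropicKey sets the Anthropic API key.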
func WithAnthropicKey(apiKey string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.apiKey = apiKey
	}
}

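// WithAnthropicBedrock routes requests through AWS Bedrock instead of the
// Anthropic API directly, using the default AWS configuration.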
func WithAnthropicBedrock() AnthropicOption {
	return func(a *anthropicProvider) {
		a.useBedrock = true
	}
}

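// WithAnthropicDisableCache disables the ephemeral cache-control markers
// otherwise attached to tools and message blocks.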
func WithAnthropicDisableCache() AnthropicOption {
	return func(a *anthropicProvider) {
		a.disableCache = true
	}
}

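// NewAnthropicProvider builds a Provider backed by the Anthropic API, or by
// AWS Bedrock when WithAnthropicBedrock is used. A system message is
// required; maxTokens defaults to 1024. Typical usage (option values are the
// caller's own):
//
//	provider, err := NewAnthropicProvider(
//		WithAnthropicSystemMessage(systemPrompt),
//		WithAnthropicModel(selectedModel),
//		WithAnthropicKey(apiKey),
//	)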
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
	provider := &anthropicProvider{
		maxTokens: 1024,
	}

	for _, opt := range opts {
		opt(provider)
	}

	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	anthropicOptions := []option.RequestOption{}

	if provider.apiKey != "" {
		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
	}
	if provider.useBedrock {
		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
	}

	provider.client = anthropic.NewClient(anthropicOptions...)
	return provider, nil
}

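// SendMessages sends a single, non-streaming completion request and returns
// the full response, including any tool calls and token usage.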
func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	messages = cleanupMessages(messages)
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	response, err := a.client.Messages.New(
		ctx,
		anthropic.MessageNewParams{
			Model:       anthropic.Model(a.model.APIModel),
			MaxTokens:   a.maxTokens,
			Temperature: anthropic.Float(0),
			Messages:    anthropicMessages,
			Tools:       anthropicTools,
			System: []anthropic.TextBlockParam{
				{
					Text: a.systemMessage,
					CacheControl: anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					},
				},
			},
		},
	)
	if err != nil {
		return nil, err
	}

	content := ""
	for _, block := range response.Content {
		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
			content += text.Text
		}
	}

	toolCalls := a.extractToolCalls(response.Content)
	tokenUsage := a.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

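// StreamResponse streams a completion over the returned channel, emitting
// content and thinking deltas as they arrive and a final EventComplete.
// Extended thinking is enabled when the latest user message mentions "think".
// Rate-limit (429) and overloaded (529) errors are retried with backoff.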
func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	messages = cleanupMessages(messages)
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	var thinkingParam anthropic.ThinkingConfigParamUnion
	lastMessage := messages[len(messages)-1]
	temperature := anthropic.Float(0)
	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
		thinkingParam = anthropic.ThinkingConfigParamUnion{
			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
				Type:         "enabled",
			},
		}
		temperature = anthropic.Float(1)
	}

	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		const maxRetries = 8
		attempts := 0

		for {
			attempts++

			stream := a.client.Messages.NewStreaming(
				ctx,
				anthropic.MessageNewParams{
					Model:       anthropic.Model(a.model.APIModel),
					MaxTokens:   a.maxTokens,
					Temperature: temperature,
					Messages:    anthropicMessages,
					Tools:       anthropicTools,
					Thinking:    thinkingParam,
					System: []anthropic.TextBlockParam{
						{
							Text: a.systemMessage,
							CacheControl: anthropic.CacheControlEphemeralParam{
								Type: "ephemeral",
							},
						},
					},
				},
			)

			accumulatedMessage := anthropic.Message{}

			for stream.Next() {
				event := stream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return // Don't retry on accumulation errors
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					eventChan <- ProviderEvent{Type: EventContentStart}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					}

				case anthropic.ContentBlockStopEvent:
					eventChan <- ProviderEvent{Type: EventContentStop}

				case anthropic.MessageStopEvent:
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    toolCalls,
							Usage:        tokenUsage,
							FinishReason: string(accumulatedMessage.StopReason),
						},
					}
				}
			}

			err := stream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				return
			}

			var apierr *anthropic.Error
			if !errors.As(err, &apierr) {
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}

			if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}

			if attempts > maxRetries {
				eventChan <- ProviderEvent{
					Type:  EventError,
					Error: errors.New("maximum retry attempts reached for rate limit (429) or overloaded (529) errors"),
				}
				return
			}

			retryMs := 0
			retryAfterValues := apierr.Response.Header.Values("Retry-After")
			if len(retryAfterValues) > 0 {
				var retryAfterSec int
				if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
					retryMs = retryAfterSec * 1000
					eventChan <- ProviderEvent{
						Type: EventWarning,
						Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
					}
				}
			} else {
				eventChan <- ProviderEvent{
					Type: EventWarning,
					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
				}

				backoffMs := 2000 * (1 << (attempts - 1))
				jitterMs := int(float64(backoffMs) * 0.2)
				retryMs = backoffMs + jitterMs
			}
			select {
			case <-ctx.Done():
				eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
				return
			case <-time.After(time.Duration(retryMs) * time.Millisecond):
				continue
			}
		}
	}()

	return eventChan, nil
}

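// extractToolCalls converts tool_use blocks from a response into the internal
// ToolCall representation.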
func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:    variant.ID,
				Name:  variant.Name,
				Input: string(variant.Input),
				Type:  string(variant.Type),
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

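// extractTokenUsage maps the API usage block, including prompt-cache counters,
// to the internal TokenUsage type.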
func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
	return TokenUsage{
		InputTokens:         usage.InputTokens,
		OutputTokens:        usage.OutputTokens,
		CacheCreationTokens: usage.CacheCreationInputTokens,
		CacheReadTokens:     usage.CacheReadInputTokens,
	}
}

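// convertToAnthropicTools converts the internal tool definitions into Anthropic
// tool parameters. Unless caching is disabled, the last tool is marked with
// ephemeral cache control so the tool definitions can be prompt-cached.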
func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
			},
		}

		if i == len(tools)-1 && !a.disableCache {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

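// convertToAnthropicMessages converts internal messages into Anthropic message
// parameters. Tool results are sent back as user messages, and up to two text
// blocks are marked with ephemeral cache control unless caching is disabled.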
func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
	cachedBlocks := 0

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cachedBlocks < 2 && !a.disableCache {
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
				cachedBlocks++
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cachedBlocks < 2 && !a.disableCache {
					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
					cachedBlocks++
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			if len(blocks) > 0 {
				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
			}

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}

	return anthropicMessages
}