package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/bedrock"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

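// anthropicProvider implements Provider on top of the Anthropic Messages API,
// either directly with an API key or via AWS Bedrock.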
type anthropicProvider struct {
	client        anthropic.Client
	model         models.Model
	maxTokens     int64
	apiKey        string
	systemMessage string
	useBedrock    bool
	disableCache  bool
}

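// AnthropicOption configures an anthropicProvider during construction.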
type AnthropicOption func(*anthropicProvider)

func WithAnthropicSystemMessage(message string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.systemMessage = message
	}
}

func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
	return func(a *anthropicProvider) {
		a.maxTokens = maxTokens
	}
}

func WithAnthropicModel(model models.Model) AnthropicOption {
	return func(a *anthropicProvider) {
		a.model = model
	}
}

func WithAnthropicKey(apiKey string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.apiKey = apiKey
	}
}

func WithAnthropicBedrock() AnthropicOption {
	return func(a *anthropicProvider) {
		a.useBedrock = true
	}
}

func WithAnthropicDisableCache() AnthropicOption {
	return func(a *anthropicProvider) {
		a.disableCache = true
	}
}

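// NewAnthropicProvider constructs an Anthropic-backed Provider from the given
// options. A system message is required; maxTokens defaults to 1024 when not
// set. Illustrative usage (someModel and apiKey are placeholders):
//
//	p, err := NewAnthropicProvider(
//		WithAnthropicSystemMessage("You are a helpful assistant."),
//		WithAnthropicModel(someModel),
//		WithAnthropicKey(apiKey),
//	)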
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
	provider := &anthropicProvider{
		maxTokens: 1024,
	}

	for _, opt := range opts {
		opt(provider)
	}

	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	anthropicOptions := []option.RequestOption{}

	if provider.apiKey != "" {
		anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
	}
	if provider.useBedrock {
		anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
	}

	provider.client = anthropic.NewClient(anthropicOptions...)
	return provider, nil
}

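// SendMessages issues a single, non-streaming Messages API request and returns
// the concatenated text content, any tool calls, and token usage.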
func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	response, err := a.client.Messages.New(
		ctx,
		anthropic.MessageNewParams{
			Model:       anthropic.Model(a.model.APIModel),
			MaxTokens:   a.maxTokens,
			Temperature: anthropic.Float(0),
			Messages:    anthropicMessages,
			Tools:       anthropicTools,
			System: []anthropic.TextBlockParam{
				{
					Text: a.systemMessage,
					CacheControl: anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					},
				},
			},
		},
	)
	if err != nil {
		return nil, err
	}

	content := ""
	for _, block := range response.Content {
		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
			content += text.Text
		}
	}

	toolCalls := a.extractToolCalls(response.Content)
	tokenUsage := a.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

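// StreamResponse streams a Messages API response as ProviderEvents on the
// returned channel. Rate-limit (429) and overloaded (529) errors are retried
// with exponential backoff, honoring a Retry-After header when present.
// Extended thinking is enabled when the latest user message mentions "think".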
func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	// Enable extended thinking when the most recent message is a user message that mentions "think".
	var thinkingParam anthropic.ThinkingConfigParamUnion
	temperature := anthropic.Float(0)
	if len(messages) > 0 {
		lastMessage := messages[len(messages)-1]
		if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
			thinkingParam = anthropic.ThinkingConfigParamUnion{
				OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
					BudgetTokens: int64(float64(a.maxTokens) * 0.8),
					Type:         "enabled",
				},
			}
			temperature = anthropic.Float(1)
		}
	}

	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		const maxRetries = 8
		attempts := 0

		for {
			// If this isn't the first attempt, we're retrying
			if attempts > 0 {
				if attempts > maxRetries {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: errors.New("maximum retry attempts reached for rate limit (429)"),
					}
					return
				}

				// Inform the user that we're retrying, with the attempt number
				eventChan <- ProviderEvent{
					Type: EventWarning,
					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
				}

				// Exponential backoff (2s, 4s, 8s, ... doubling per attempt) plus a fixed 20% margin
				backoffMs := 2000 * (1 << (attempts - 1))
				jitterMs := int(float64(backoffMs) * 0.2)
				totalBackoffMs := backoffMs + jitterMs

				// Sleep with backoff, respecting context cancellation
				select {
				case <-ctx.Done():
					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					return
				case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
					// Continue with retry
				}
			}

			attempts++

			// Create a new streaming request
			stream := a.client.Messages.NewStreaming(
				ctx,
				anthropic.MessageNewParams{
					Model:       anthropic.Model(a.model.APIModel),
					MaxTokens:   a.maxTokens,
					Temperature: temperature,
					Messages:    anthropicMessages,
					Tools:       anthropicTools,
					Thinking:    thinkingParam,
					System: []anthropic.TextBlockParam{
						{
							Text: a.systemMessage,
							CacheControl: anthropic.CacheControlEphemeralParam{
								Type: "ephemeral",
							},
						},
					},
				},
			)

			// Process stream events
			accumulatedMessage := anthropic.Message{}
			streamSuccess := false

			// Process the stream until completion or error
			for stream.Next() {
				event := stream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return // Don't retry on accumulation errors
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					eventChan <- ProviderEvent{Type: EventContentStart}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					}

				case anthropic.ContentBlockStopEvent:
					eventChan <- ProviderEvent{Type: EventContentStop}

				case anthropic.MessageStopEvent:
					streamSuccess = true
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    toolCalls,
							Usage:        tokenUsage,
							FinishReason: string(accumulatedMessage.StopReason),
						},
					}
				}
			}

			// If the stream completed successfully, we're done
			if streamSuccess {
				return
			}

			// Check for stream errors
			err := stream.Err()
			if err != nil {
				var apierr *anthropic.Error
				if errors.As(err, &apierr) {
					// Retry on rate-limit (429) and overloaded (529) responses
					if apierr.StatusCode == 429 || apierr.StatusCode == 529 {
						// Check for a Retry-After header
						if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
							// Parse the Retry-After value (seconds)
							var retryAfterSec int
							if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
								retryMs := retryAfterSec * 1000

								// Inform the user of the retry with the specific wait time
								eventChan <- ProviderEvent{
									Type: EventWarning,
									Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
								}

								// Sleep respecting context cancellation
								select {
								case <-ctx.Done():
									eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
									return
								case <-time.After(time.Duration(retryMs) * time.Millisecond):
									// Continue with the retry after the specified delay
									continue
								}
							}
						}

						// Fall back to exponential backoff if the Retry-After header is missing or unparsable
						continue
					}
				}

				// For non-rate-limit errors, report and exit
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}
		}
	}()

	return eventChan, nil
}

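// extractToolCalls converts tool_use blocks in the response content into
// message.ToolCall values.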
func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:    variant.ID,
				Name:  variant.Name,
				Input: string(variant.Input),
				Type:  string(variant.Type),
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

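// extractTokenUsage maps the API usage block, including cache creation and
// cache read tokens, onto TokenUsage.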
func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
	return TokenUsage{
		InputTokens:         usage.InputTokens,
		OutputTokens:        usage.OutputTokens,
		CacheCreationTokens: usage.CacheCreationInputTokens,
		CacheReadTokens:     usage.CacheReadInputTokens,
	}
}

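// convertToAnthropicTools maps tool definitions to Anthropic tool parameters,
// adding an ephemeral cache breakpoint to the last tool unless caching is disabled.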
func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
			},
		}

		if i == len(tools)-1 && !a.disableCache {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

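// convertToAnthropicMessages converts internal messages into Anthropic message
// parameters, marking up to two text blocks with ephemeral cache control unless
// caching is disabled. Tool results are sent back as user messages, as the
// Messages API expects.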
func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
	cachedBlocks := 0

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cachedBlocks < 2 && !a.disableCache {
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
				cachedBlocks++
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cachedBlocks < 2 && !a.disableCache {
					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
					cachedBlocks++
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			// Skip empty assistant messages completely
			if len(blocks) > 0 {
				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
			}

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}

	return anthropicMessages
}