package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/anthropics/anthropic-sdk-go"
	"github.com/anthropics/anthropic-sdk-go/option"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)

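// anthropicProvider implements the Provider interface on top of the Anthropic Messages API.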
type anthropicProvider struct {
	client        anthropic.Client
	model         models.Model
	maxTokens     int64
	apiKey        string
	systemMessage string
}

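// AnthropicOption configures an anthropicProvider at construction time.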
type AnthropicOption func(*anthropicProvider)

func WithAnthropicSystemMessage(message string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.systemMessage = message
	}
}

func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
	return func(a *anthropicProvider) {
		a.maxTokens = maxTokens
	}
}

func WithAnthropicModel(model models.Model) AnthropicOption {
	return func(a *anthropicProvider) {
		a.model = model
	}
}

func WithAnthropicKey(apiKey string) AnthropicOption {
	return func(a *anthropicProvider) {
		a.apiKey = apiKey
	}
}

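// NewAnthropicProvider constructs a Provider backed by Anthropic.
// A system message is required; maxTokens defaults to 1024.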
func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
	provider := &anthropicProvider{
		maxTokens: 1024,
	}

	for _, opt := range opts {
		opt(provider)
	}

	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
	return provider, nil
}

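// SendMessages sends a single non-streaming request and returns the complete
// response, including any tool calls and token usage.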
func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	response, err := a.client.Messages.New(
		ctx,
		anthropic.MessageNewParams{
			Model:       anthropic.Model(a.model.APIModel),
			MaxTokens:   a.maxTokens,
			Temperature: anthropic.Float(0),
			Messages:    anthropicMessages,
			Tools:       anthropicTools,
			System: []anthropic.TextBlockParam{
				{
					Text: a.systemMessage,
					CacheControl: anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					},
				},
			},
		},
	)
	if err != nil {
		return nil, err
	}

	content := ""
	for _, block := range response.Content {
		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
			content += text.Text
		}
	}

	toolCalls := a.extractToolCalls(response.Content)
	tokenUsage := a.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

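// StreamResponse streams the model's reply as ProviderEvents on the returned channel.
// Extended thinking is enabled when the last user message mentions "think", and
// rate-limit (429) and overloaded (529) errors are retried with exponential backoff.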
func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	anthropicMessages := a.convertToAnthropicMessages(messages)
	anthropicTools := a.convertToAnthropicTools(tools)

	var thinkingParam anthropic.ThinkingConfigParamUnion
	lastMessage := messages[len(messages)-1]
	temperature := anthropic.Float(0)
	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
		thinkingParam = anthropic.ThinkingConfigParamUnion{
			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
				Type:         "enabled",
			},
		}
		temperature = anthropic.Float(1)
	}

	eventChan := make(chan ProviderEvent)

	go func() {
		defer close(eventChan)

		const maxRetries = 8
		attempts := 0

		for {
			// If this isn't the first attempt, we're retrying
			if attempts > 0 {
				if attempts > maxRetries {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: errors.New("maximum retry attempts reached for rate limit (429)"),
					}
					return
				}

				// Inform user we're retrying with attempt number
				eventChan <- ProviderEvent{
					Type: EventWarning,
					Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
				}

				// Exponential backoff with jitter: 2s, 4s, 8s, ... doubling each retry
				backoffMs := 2000 * (1 << (attempts - 1))
				jitterMs := int(float64(backoffMs) * 0.2)
				totalBackoffMs := backoffMs + jitterMs

				// Sleep with backoff, respecting context cancellation
				select {
				case <-ctx.Done():
					eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					return
				case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
					// Continue with retry
				}
			}

			attempts++

			// Create new streaming request
			stream := a.client.Messages.NewStreaming(
				ctx,
				anthropic.MessageNewParams{
					Model:       anthropic.Model(a.model.APIModel),
					MaxTokens:   a.maxTokens,
					Temperature: temperature,
					Messages:    anthropicMessages,
					Tools:       anthropicTools,
					Thinking:    thinkingParam,
					System: []anthropic.TextBlockParam{
						{
							Text: a.systemMessage,
							CacheControl: anthropic.CacheControlEphemeralParam{
								Type: "ephemeral",
							},
						},
					},
				},
			)

			// Process stream events
			accumulatedMessage := anthropic.Message{}
			streamSuccess := false

			// Process the stream until completion or error
			for stream.Next() {
				event := stream.Current()
				err := accumulatedMessage.Accumulate(event)
				if err != nil {
					eventChan <- ProviderEvent{Type: EventError, Error: err}
					return // Don't retry on accumulation errors
				}

				switch event := event.AsAny().(type) {
				case anthropic.ContentBlockStartEvent:
					eventChan <- ProviderEvent{Type: EventContentStart}

				case anthropic.ContentBlockDeltaEvent:
					if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
						eventChan <- ProviderEvent{
							Type:     EventThinkingDelta,
							Thinking: event.Delta.Thinking,
						}
					} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: event.Delta.Text,
						}
					}

				case anthropic.ContentBlockStopEvent:
					eventChan <- ProviderEvent{Type: EventContentStop}

				case anthropic.MessageStopEvent:
					streamSuccess = true
					content := ""
					for _, block := range accumulatedMessage.Content {
						if text, ok := block.AsAny().(anthropic.TextBlock); ok {
							content += text.Text
						}
					}

					toolCalls := a.extractToolCalls(accumulatedMessage.Content)
					tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)

					eventChan <- ProviderEvent{
						Type: EventComplete,
						Response: &ProviderResponse{
							Content:      content,
							ToolCalls:    toolCalls,
							Usage:        tokenUsage,
							FinishReason: string(accumulatedMessage.StopReason),
						},
					}
				}
			}

			// If the stream completed successfully, we're done
			if streamSuccess {
				return
			}

			// Check for stream errors
			err := stream.Err()
			if err != nil {
				var apierr *anthropic.Error
				if errors.As(err, &apierr) {
					if apierr.StatusCode == 429 || apierr.StatusCode == 529 {
						// Check for Retry-After header
						if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
							// Parse the retry after value (seconds)
							var retryAfterSec int
							if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
								retryMs := retryAfterSec * 1000

								// Inform user of retry with specific wait time
								eventChan <- ProviderEvent{
									Type: EventWarning,
									Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
								}

								// Sleep respecting context cancellation
								select {
								case <-ctx.Done():
									eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
									return
								case <-time.After(time.Duration(retryMs) * time.Millisecond):
									// Continue with retry after specified delay
									continue
								}
							}
						}

						// Fall back to exponential backoff if Retry-After parsing failed
						continue
					}
				}

				// For non-rate limit errors, report and exit
				eventChan <- ProviderEvent{Type: EventError, Error: err}
				return
			}
		}
	}()

	return eventChan, nil
}

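// extractToolCalls collects tool_use blocks from the response content as message.ToolCalls.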
func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
	var toolCalls []message.ToolCall

	for _, block := range content {
		switch variant := block.AsAny().(type) {
		case anthropic.ToolUseBlock:
			toolCall := message.ToolCall{
				ID:    variant.ID,
				Name:  variant.Name,
				Input: string(variant.Input),
				Type:  string(variant.Type),
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

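// extractTokenUsage maps Anthropic usage counters onto the provider-agnostic TokenUsage type.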
func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
	return TokenUsage{
		InputTokens:         usage.InputTokens,
		OutputTokens:        usage.OutputTokens,
		CacheCreationTokens: usage.CacheCreationInputTokens,
		CacheReadTokens:     usage.CacheReadInputTokens,
	}
}

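// convertToAnthropicTools converts tool definitions into Anthropic tool params,
// marking the last tool as an ephemeral prompt-cache breakpoint.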
func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		toolParam := anthropic.ToolParam{
			Name:        info.Name,
			Description: anthropic.String(info.Description),
			InputSchema: anthropic.ToolInputSchemaParam{
				Properties: info.Parameters,
			},
		}

		if i == len(tools)-1 {
			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
				Type: "ephemeral",
			}
		}

		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
	}

	return anthropicTools
}

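// convertToAnthropicMessages converts internal messages into Anthropic message params.
// At most two text blocks are marked cacheable, empty assistant messages are skipped,
// and tool results are sent back as user messages.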
func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
	cachedBlocks := 0

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			content := anthropic.NewTextBlock(msg.Content().String())
			if cachedBlocks < 2 {
				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
					Type: "ephemeral",
				}
				cachedBlocks++
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))

		case message.Assistant:
			blocks := []anthropic.ContentBlockParamUnion{}
			if msg.Content().String() != "" {
				content := anthropic.NewTextBlock(msg.Content().String())
				if cachedBlocks < 2 {
					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
						Type: "ephemeral",
					}
					cachedBlocks++
				}
				blocks = append(blocks, content)
			}

			for _, toolCall := range msg.ToolCalls() {
				var inputMap map[string]any
				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
				if err != nil {
					continue
				}
				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
			}

			// Skip empty assistant messages completely
			if len(blocks) > 0 {
				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
			}

		case message.Tool:
			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
			for i, toolResult := range msg.ToolResults() {
				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
			}
			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
		}
	}

	return anthropicMessages
}