1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "log"
9 "strings"
10 "time"
11
12 "github.com/anthropics/anthropic-sdk-go"
13 "github.com/anthropics/anthropic-sdk-go/option"
14 "github.com/kujtimiihoxha/termai/internal/llm/models"
15 "github.com/kujtimiihoxha/termai/internal/llm/tools"
16 "github.com/kujtimiihoxha/termai/internal/message"
17)
18
19type anthropicProvider struct {
20 client anthropic.Client
21 model models.Model
22 maxTokens int64
23 apiKey string
24 systemMessage string
25}
26
27type AnthropicOption func(*anthropicProvider)
28
29func WithAnthropicSystemMessage(message string) AnthropicOption {
30 return func(a *anthropicProvider) {
31 a.systemMessage = message
32 }
33}
34
35func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
36 return func(a *anthropicProvider) {
37 a.maxTokens = maxTokens
38 }
39}
40
41func WithAnthropicModel(model models.Model) AnthropicOption {
42 return func(a *anthropicProvider) {
43 a.model = model
44 }
45}
46
47func WithAnthropicKey(apiKey string) AnthropicOption {
48 return func(a *anthropicProvider) {
49 a.apiKey = apiKey
50 }
51}
52
53func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
54 provider := &anthropicProvider{
55 maxTokens: 1024,
56 }
57
58 for _, opt := range opts {
59 opt(provider)
60 }
61
62 if provider.systemMessage == "" {
63 return nil, errors.New("system message is required")
64 }
65
66 provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
67 return provider, nil
68}
69
70func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
71 anthropicMessages := a.convertToAnthropicMessages(messages)
72 anthropicTools := a.convertToAnthropicTools(tools)
73
74 response, err := a.client.Messages.New(
75 ctx,
76 anthropic.MessageNewParams{
77 Model: anthropic.Model(a.model.APIModel),
78 MaxTokens: a.maxTokens,
79 Temperature: anthropic.Float(0),
80 Messages: anthropicMessages,
81 Tools: anthropicTools,
82 System: []anthropic.TextBlockParam{
83 {
84 Text: a.systemMessage,
85 CacheControl: anthropic.CacheControlEphemeralParam{
86 Type: "ephemeral",
87 },
88 },
89 },
90 },
91 option.WithMaxRetries(8),
92 )
93 if err != nil {
94 return nil, err
95 }
96
97 content := ""
98 for _, block := range response.Content {
99 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
100 content += text.Text
101 }
102 }
103
104 toolCalls := a.extractToolCalls(response.Content)
105 tokenUsage := a.extractTokenUsage(response.Usage)
106
107 return &ProviderResponse{
108 Content: content,
109 ToolCalls: toolCalls,
110 Usage: tokenUsage,
111 }, nil
112}
113
114func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
115 anthropicMessages := a.convertToAnthropicMessages(messages)
116 anthropicTools := a.convertToAnthropicTools(tools)
117
118 var thinkingParam anthropic.ThinkingConfigParamUnion
119 lastMessage := messages[len(messages)-1]
120 temperature := anthropic.Float(0)
121 if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
122 thinkingParam = anthropic.ThinkingConfigParamUnion{
123 OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
124 BudgetTokens: int64(float64(a.maxTokens) * 0.8),
125 Type: "enabled",
126 },
127 }
128 temperature = anthropic.Float(1)
129 }
130
131 eventChan := make(chan ProviderEvent)
132
133 go func() {
134 defer close(eventChan)
135
136 const maxRetries = 8
137 attempts := 0
138
139 for {
140 // If this isn't the first attempt, we're retrying
141 if attempts > 0 {
142 if attempts > maxRetries {
143 eventChan <- ProviderEvent{
144 Type: EventError,
145 Error: errors.New("maximum retry attempts reached for rate limit (429)"),
146 }
147 return
148 }
149
150 // Inform user we're retrying with attempt number
151 eventChan <- ProviderEvent{
152 Type: EventContentDelta,
153 Content: fmt.Sprintf("\n\n[Retrying due to rate limit... attempt %d of %d]\n\n", attempts, maxRetries),
154 }
155
156 // Calculate backoff with exponential backoff and jitter
157 backoffMs := 2000 * (1 << (attempts - 1)) // 2s, 4s, 8s, 16s, 32s
158 jitterMs := int(float64(backoffMs) * 0.2)
159 totalBackoffMs := backoffMs + jitterMs
160
161 // Sleep with backoff, respecting context cancellation
162 select {
163 case <-ctx.Done():
164 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
165 return
166 case <-time.After(time.Duration(totalBackoffMs) * time.Millisecond):
167 // Continue with retry
168 }
169 }
170
171 attempts++
172
173 // Create new streaming request
174 stream := a.client.Messages.NewStreaming(
175 ctx,
176 anthropic.MessageNewParams{
177 Model: anthropic.Model(a.model.APIModel),
178 MaxTokens: a.maxTokens,
179 Temperature: temperature,
180 Messages: anthropicMessages,
181 Tools: anthropicTools,
182 Thinking: thinkingParam,
183 System: []anthropic.TextBlockParam{
184 {
185 Text: a.systemMessage,
186 CacheControl: anthropic.CacheControlEphemeralParam{
187 Type: "ephemeral",
188 },
189 },
190 },
191 },
192 )
193
194 // Process stream events
195 accumulatedMessage := anthropic.Message{}
196 streamSuccess := false
197
198 // Process the stream until completion or error
199 for stream.Next() {
200 event := stream.Current()
201 err := accumulatedMessage.Accumulate(event)
202 if err != nil {
203 eventChan <- ProviderEvent{Type: EventError, Error: err}
204 return // Don't retry on accumulation errors
205 }
206
207 switch event := event.AsAny().(type) {
208 case anthropic.ContentBlockStartEvent:
209 eventChan <- ProviderEvent{Type: EventContentStart}
210
211 case anthropic.ContentBlockDeltaEvent:
212 if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
213 eventChan <- ProviderEvent{
214 Type: EventThinkingDelta,
215 Thinking: event.Delta.Thinking,
216 }
217 } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
218 eventChan <- ProviderEvent{
219 Type: EventContentDelta,
220 Content: event.Delta.Text,
221 }
222 }
223
224 case anthropic.ContentBlockStopEvent:
225 eventChan <- ProviderEvent{Type: EventContentStop}
226
227 case anthropic.MessageStopEvent:
228 streamSuccess = true
229 content := ""
230 for _, block := range accumulatedMessage.Content {
231 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
232 content += text.Text
233 }
234 }
235
236 toolCalls := a.extractToolCalls(accumulatedMessage.Content)
237 tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
238
239 eventChan <- ProviderEvent{
240 Type: EventComplete,
241 Response: &ProviderResponse{
242 Content: content,
243 ToolCalls: toolCalls,
244 Usage: tokenUsage,
245 FinishReason: string(accumulatedMessage.StopReason),
246 },
247 }
248 }
249 }
250
251 // If the stream completed successfully, we're done
252 if streamSuccess {
253 return
254 }
255
256 // Check for stream errors
257 err := stream.Err()
258 if err != nil {
259 log.Println("error", err)
260
261 var apierr *anthropic.Error
262 if errors.As(err, &apierr) && apierr.StatusCode == 429 {
263 continue
264 }
265
266 // For non-rate limit errors, report and exit
267 eventChan <- ProviderEvent{Type: EventError, Error: err}
268 return
269 }
270 }
271 }()
272
273 return eventChan, nil
274}
275
276func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
277 var toolCalls []message.ToolCall
278
279 for _, block := range content {
280 switch variant := block.AsAny().(type) {
281 case anthropic.ToolUseBlock:
282 toolCall := message.ToolCall{
283 ID: variant.ID,
284 Name: variant.Name,
285 Input: string(variant.Input),
286 Type: string(variant.Type),
287 }
288 toolCalls = append(toolCalls, toolCall)
289 }
290 }
291
292 return toolCalls
293}
294
295func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
296 return TokenUsage{
297 InputTokens: usage.InputTokens,
298 OutputTokens: usage.OutputTokens,
299 CacheCreationTokens: usage.CacheCreationInputTokens,
300 CacheReadTokens: usage.CacheReadInputTokens,
301 }
302}
303
304func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
305 anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
306
307 for i, tool := range tools {
308 info := tool.Info()
309 toolParam := anthropic.ToolParam{
310 Name: info.Name,
311 Description: anthropic.String(info.Description),
312 InputSchema: anthropic.ToolInputSchemaParam{
313 Properties: info.Parameters,
314 },
315 }
316
317 if i == len(tools)-1 {
318 toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
319 Type: "ephemeral",
320 }
321 }
322
323 anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
324 }
325
326 return anthropicTools
327}
328
329func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
330 anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
331 cachedBlocks := 0
332
333 for _, msg := range messages {
334 switch msg.Role {
335 case message.User:
336 content := anthropic.NewTextBlock(msg.Content().String())
337 if cachedBlocks < 2 {
338 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
339 Type: "ephemeral",
340 }
341 cachedBlocks++
342 }
343 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
344
345 case message.Assistant:
346 blocks := []anthropic.ContentBlockParamUnion{}
347 if msg.Content().String() != "" {
348 content := anthropic.NewTextBlock(msg.Content().String())
349 if cachedBlocks < 2 {
350 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
351 Type: "ephemeral",
352 }
353 cachedBlocks++
354 }
355 blocks = append(blocks, content)
356 }
357
358 for _, toolCall := range msg.ToolCalls() {
359 var inputMap map[string]any
360 err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
361 if err != nil {
362 continue
363 }
364 blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
365 }
366
367 // Skip empty assistant messages completely
368 if len(blocks) > 0 {
369 anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
370 }
371
372 case message.Tool:
373 results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
374 for i, toolResult := range msg.ToolResults() {
375 results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
376 }
377 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
378 }
379 }
380
381 return anthropicMessages
382}
383