1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "github.com/anthropics/anthropic-sdk-go"
12 "github.com/anthropics/anthropic-sdk-go/bedrock"
13 "github.com/anthropics/anthropic-sdk-go/option"
14 "github.com/kujtimiihoxha/termai/internal/llm/models"
15 "github.com/kujtimiihoxha/termai/internal/llm/tools"
16 "github.com/kujtimiihoxha/termai/internal/message"
17)
18
// anthropicProvider implements Provider on top of the Anthropic Messages API,
// optionally routed through AWS Bedrock. Instances are created by
// NewAnthropicProvider and configured with AnthropicOption values.
type anthropicProvider struct {
	// client is the SDK client built by NewAnthropicProvider.
	client anthropic.Client
	// model supplies the identifier (model.APIModel) sent with each request.
	model models.Model
	// maxTokens is the max_tokens sent with each request; NewAnthropicProvider defaults it to 1024.
	maxTokens int64
	// apiKey, when non-empty, is handed to the client via option.WithAPIKey.
	apiKey string
	// systemMessage is sent as the system prompt; NewAnthropicProvider rejects an empty one.
	systemMessage string
	// useBedrock makes NewAnthropicProvider configure the client for AWS Bedrock.
	useBedrock bool
	// disableCache suppresses the ephemeral cache-control entries added to tools and messages.
	disableCache bool
}
28
// AnthropicOption configures an anthropicProvider. NewAnthropicProvider applies
// the options it receives in order, so a later option overrides an earlier one.
type AnthropicOption func(*anthropicProvider)
30
31func WithAnthropicSystemMessage(message string) AnthropicOption {
32 return func(a *anthropicProvider) {
33 a.systemMessage = message
34 }
35}
36
37func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
38 return func(a *anthropicProvider) {
39 a.maxTokens = maxTokens
40 }
41}
42
43func WithAnthropicModel(model models.Model) AnthropicOption {
44 return func(a *anthropicProvider) {
45 a.model = model
46 }
47}
48
49func WithAnthropicKey(apiKey string) AnthropicOption {
50 return func(a *anthropicProvider) {
51 a.apiKey = apiKey
52 }
53}
54
55func WithAnthropicBedrock() AnthropicOption {
56 return func(a *anthropicProvider) {
57 a.useBedrock = true
58 }
59}
60
61func WithAnthropicDisableCache() AnthropicOption {
62 return func(a *anthropicProvider) {
63 a.disableCache = true
64 }
65}
66
67func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
68 provider := &anthropicProvider{
69 maxTokens: 1024,
70 }
71
72 for _, opt := range opts {
73 opt(provider)
74 }
75
76 if provider.systemMessage == "" {
77 return nil, errors.New("system message is required")
78 }
79
80 anthropicOptions := []option.RequestOption{}
81
82 if provider.apiKey != "" {
83 anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
84 }
85 if provider.useBedrock {
86 anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
87 }
88
89 provider.client = anthropic.NewClient(anthropicOptions...)
90 return provider, nil
91}
92
93func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
94 anthropicMessages := a.convertToAnthropicMessages(messages)
95 anthropicTools := a.convertToAnthropicTools(tools)
96
97 response, err := a.client.Messages.New(
98 ctx,
99 anthropic.MessageNewParams{
100 Model: anthropic.Model(a.model.APIModel),
101 MaxTokens: a.maxTokens,
102 Temperature: anthropic.Float(0),
103 Messages: anthropicMessages,
104 Tools: anthropicTools,
105 System: []anthropic.TextBlockParam{
106 {
107 Text: a.systemMessage,
108 CacheControl: anthropic.CacheControlEphemeralParam{
109 Type: "ephemeral",
110 },
111 },
112 },
113 },
114 )
115 if err != nil {
116 return nil, err
117 }
118
119 content := ""
120 for _, block := range response.Content {
121 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
122 content += text.Text
123 }
124 }
125
126 toolCalls := a.extractToolCalls(response.Content)
127 tokenUsage := a.extractTokenUsage(response.Usage)
128
129 return &ProviderResponse{
130 Content: content,
131 ToolCalls: toolCalls,
132 Usage: tokenUsage,
133 }, nil
134}
135
136func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
137 anthropicMessages := a.convertToAnthropicMessages(messages)
138 anthropicTools := a.convertToAnthropicTools(tools)
139
140 var thinkingParam anthropic.ThinkingConfigParamUnion
141 lastMessage := messages[len(messages)-1]
142 temperature := anthropic.Float(0)
143 if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
144 thinkingParam = anthropic.ThinkingConfigParamUnion{
145 OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
146 BudgetTokens: int64(float64(a.maxTokens) * 0.8),
147 Type: "enabled",
148 },
149 }
150 temperature = anthropic.Float(1)
151 }
152
153 eventChan := make(chan ProviderEvent)
154
155 go func() {
156 defer close(eventChan)
157
158 const maxRetries = 8
159 attempts := 0
160
161 for {
162
163 attempts++
164
165 stream := a.client.Messages.NewStreaming(
166 ctx,
167 anthropic.MessageNewParams{
168 Model: anthropic.Model(a.model.APIModel),
169 MaxTokens: a.maxTokens,
170 Temperature: temperature,
171 Messages: anthropicMessages,
172 Tools: anthropicTools,
173 Thinking: thinkingParam,
174 System: []anthropic.TextBlockParam{
175 {
176 Text: a.systemMessage,
177 CacheControl: anthropic.CacheControlEphemeralParam{
178 Type: "ephemeral",
179 },
180 },
181 },
182 },
183 )
184
185 accumulatedMessage := anthropic.Message{}
186
187 for stream.Next() {
188 event := stream.Current()
189 err := accumulatedMessage.Accumulate(event)
190 if err != nil {
191 eventChan <- ProviderEvent{Type: EventError, Error: err}
192 return // Don't retry on accumulation errors
193 }
194
195 switch event := event.AsAny().(type) {
196 case anthropic.ContentBlockStartEvent:
197 eventChan <- ProviderEvent{Type: EventContentStart}
198
199 case anthropic.ContentBlockDeltaEvent:
200 if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
201 eventChan <- ProviderEvent{
202 Type: EventThinkingDelta,
203 Thinking: event.Delta.Thinking,
204 }
205 } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
206 eventChan <- ProviderEvent{
207 Type: EventContentDelta,
208 Content: event.Delta.Text,
209 }
210 }
211
212 case anthropic.ContentBlockStopEvent:
213 eventChan <- ProviderEvent{Type: EventContentStop}
214
215 case anthropic.MessageStopEvent:
216 content := ""
217 for _, block := range accumulatedMessage.Content {
218 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
219 content += text.Text
220 }
221 }
222
223 toolCalls := a.extractToolCalls(accumulatedMessage.Content)
224 tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
225
226 eventChan <- ProviderEvent{
227 Type: EventComplete,
228 Response: &ProviderResponse{
229 Content: content,
230 ToolCalls: toolCalls,
231 Usage: tokenUsage,
232 FinishReason: string(accumulatedMessage.StopReason),
233 },
234 }
235 }
236 }
237
238 err := stream.Err()
239 if err == nil {
240 return
241 }
242
243 var apierr *anthropic.Error
244 if !errors.As(err, &apierr) {
245 eventChan <- ProviderEvent{Type: EventError, Error: err}
246 return
247 }
248
249 if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
250 eventChan <- ProviderEvent{Type: EventError, Error: err}
251 return
252 }
253
254 if attempts > maxRetries {
255 eventChan <- ProviderEvent{
256 Type: EventError,
257 Error: errors.New("maximum retry attempts reached for rate limit (429)"),
258 }
259 return
260 }
261
262 retryMs := 0
263 retryAfterValues := apierr.Response.Header.Values("Retry-After")
264 if len(retryAfterValues) > 0 {
265 var retryAfterSec int
266 if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
267 retryMs = retryAfterSec * 1000
268 eventChan <- ProviderEvent{
269 Type: EventWarning,
270 Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
271 }
272 }
273 } else {
274 eventChan <- ProviderEvent{
275 Type: EventWarning,
276 Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
277 }
278
279 backoffMs := 2000 * (1 << (attempts - 1))
280 jitterMs := int(float64(backoffMs) * 0.2)
281 retryMs = backoffMs + jitterMs
282 }
283 select {
284 case <-ctx.Done():
285 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
286 return
287 case <-time.After(time.Duration(retryMs) * time.Millisecond):
288 continue
289 }
290
291 }
292 }()
293
294 return eventChan, nil
295}
296
297func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
298 var toolCalls []message.ToolCall
299
300 for _, block := range content {
301 switch variant := block.AsAny().(type) {
302 case anthropic.ToolUseBlock:
303 toolCall := message.ToolCall{
304 ID: variant.ID,
305 Name: variant.Name,
306 Input: string(variant.Input),
307 Type: string(variant.Type),
308 }
309 toolCalls = append(toolCalls, toolCall)
310 }
311 }
312
313 return toolCalls
314}
315
316func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
317 return TokenUsage{
318 InputTokens: usage.InputTokens,
319 OutputTokens: usage.OutputTokens,
320 CacheCreationTokens: usage.CacheCreationInputTokens,
321 CacheReadTokens: usage.CacheReadInputTokens,
322 }
323}
324
325func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
326 anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
327
328 for i, tool := range tools {
329 info := tool.Info()
330 toolParam := anthropic.ToolParam{
331 Name: info.Name,
332 Description: anthropic.String(info.Description),
333 InputSchema: anthropic.ToolInputSchemaParam{
334 Properties: info.Parameters,
335 },
336 }
337
338 if i == len(tools)-1 && !a.disableCache {
339 toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
340 Type: "ephemeral",
341 }
342 }
343
344 anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
345 }
346
347 return anthropicTools
348}
349
350func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
351 anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
352 cachedBlocks := 0
353
354 for _, msg := range messages {
355 switch msg.Role {
356 case message.User:
357 content := anthropic.NewTextBlock(msg.Content().String())
358 if cachedBlocks < 2 && !a.disableCache {
359 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
360 Type: "ephemeral",
361 }
362 cachedBlocks++
363 }
364 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
365
366 case message.Assistant:
367 blocks := []anthropic.ContentBlockParamUnion{}
368 if msg.Content().String() != "" {
369 content := anthropic.NewTextBlock(msg.Content().String())
370 if cachedBlocks < 2 && !a.disableCache {
371 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
372 Type: "ephemeral",
373 }
374 cachedBlocks++
375 }
376 blocks = append(blocks, content)
377 }
378
379 for _, toolCall := range msg.ToolCalls() {
380 var inputMap map[string]any
381 err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
382 if err != nil {
383 continue
384 }
385 blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
386 }
387
388 if len(blocks) > 0 {
389 anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
390 }
391
392 case message.Tool:
393 results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
394 for i, toolResult := range msg.ToolResults() {
395 results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
396 }
397 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
398 }
399 }
400
401 return anthropicMessages
402}