1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "strings"
10 "time"
11
12 "github.com/anthropics/anthropic-sdk-go"
13 "github.com/anthropics/anthropic-sdk-go/bedrock"
14 "github.com/anthropics/anthropic-sdk-go/option"
15 "github.com/kujtimiihoxha/termai/internal/llm/models"
16 "github.com/kujtimiihoxha/termai/internal/llm/tools"
17 "github.com/kujtimiihoxha/termai/internal/message"
18)
19
20type anthropicProvider struct {
21 client anthropic.Client
22 model models.Model
23 maxTokens int64
24 apiKey string
25 systemMessage string
26 useBedrock bool
27 disableCache bool
28}
29
30type AnthropicOption func(*anthropicProvider)
31
32func WithAnthropicSystemMessage(message string) AnthropicOption {
33 return func(a *anthropicProvider) {
34 a.systemMessage = message
35 }
36}
37
38func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
39 return func(a *anthropicProvider) {
40 a.maxTokens = maxTokens
41 }
42}
43
44func WithAnthropicModel(model models.Model) AnthropicOption {
45 return func(a *anthropicProvider) {
46 a.model = model
47 }
48}
49
50func WithAnthropicKey(apiKey string) AnthropicOption {
51 return func(a *anthropicProvider) {
52 a.apiKey = apiKey
53 }
54}
55
56func WithAnthropicBedrock() AnthropicOption {
57 return func(a *anthropicProvider) {
58 a.useBedrock = true
59 }
60}
61
62func WithAnthropicDisableCache() AnthropicOption {
63 return func(a *anthropicProvider) {
64 a.disableCache = true
65 }
66}
67
68func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
69 provider := &anthropicProvider{
70 maxTokens: 1024,
71 }
72
73 for _, opt := range opts {
74 opt(provider)
75 }
76
77 if provider.systemMessage == "" {
78 return nil, errors.New("system message is required")
79 }
80
81 anthropicOptions := []option.RequestOption{}
82
83 if provider.apiKey != "" {
84 anthropicOptions = append(anthropicOptions, option.WithAPIKey(provider.apiKey))
85 }
86 if provider.useBedrock {
87 anthropicOptions = append(anthropicOptions, bedrock.WithLoadDefaultConfig(context.Background()))
88 }
89
90 provider.client = anthropic.NewClient(anthropicOptions...)
91 return provider, nil
92}
93
94func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
95 anthropicMessages := a.convertToAnthropicMessages(messages)
96 anthropicTools := a.convertToAnthropicTools(tools)
97
98 response, err := a.client.Messages.New(
99 ctx,
100 anthropic.MessageNewParams{
101 Model: anthropic.Model(a.model.APIModel),
102 MaxTokens: a.maxTokens,
103 Temperature: anthropic.Float(0),
104 Messages: anthropicMessages,
105 Tools: anthropicTools,
106 System: []anthropic.TextBlockParam{
107 {
108 Text: a.systemMessage,
109 CacheControl: anthropic.CacheControlEphemeralParam{
110 Type: "ephemeral",
111 },
112 },
113 },
114 },
115 )
116 if err != nil {
117 return nil, err
118 }
119
120 content := ""
121 for _, block := range response.Content {
122 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
123 content += text.Text
124 }
125 }
126
127 toolCalls := a.extractToolCalls(response.Content)
128 tokenUsage := a.extractTokenUsage(response.Usage)
129
130 return &ProviderResponse{
131 Content: content,
132 ToolCalls: toolCalls,
133 Usage: tokenUsage,
134 }, nil
135}
136
137func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
138 anthropicMessages := a.convertToAnthropicMessages(messages)
139 anthropicTools := a.convertToAnthropicTools(tools)
140
141 var thinkingParam anthropic.ThinkingConfigParamUnion
142 lastMessage := messages[len(messages)-1]
143 temperature := anthropic.Float(0)
144 if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
145 thinkingParam = anthropic.ThinkingConfigParamUnion{
146 OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
147 BudgetTokens: int64(float64(a.maxTokens) * 0.8),
148 Type: "enabled",
149 },
150 }
151 temperature = anthropic.Float(1)
152 }
153
154 eventChan := make(chan ProviderEvent)
155
156 go func() {
157 defer close(eventChan)
158
159 const maxRetries = 8
160 attempts := 0
161
162 for {
163
164 attempts++
165
166 stream := a.client.Messages.NewStreaming(
167 ctx,
168 anthropic.MessageNewParams{
169 Model: anthropic.Model(a.model.APIModel),
170 MaxTokens: a.maxTokens,
171 Temperature: temperature,
172 Messages: anthropicMessages,
173 Tools: anthropicTools,
174 Thinking: thinkingParam,
175 System: []anthropic.TextBlockParam{
176 {
177 Text: a.systemMessage,
178 CacheControl: anthropic.CacheControlEphemeralParam{
179 Type: "ephemeral",
180 },
181 },
182 },
183 },
184 )
185
186 accumulatedMessage := anthropic.Message{}
187
188 for stream.Next() {
189 event := stream.Current()
190 err := accumulatedMessage.Accumulate(event)
191 if err != nil {
192 eventChan <- ProviderEvent{Type: EventError, Error: err}
193 return // Don't retry on accumulation errors
194 }
195
196 switch event := event.AsAny().(type) {
197 case anthropic.ContentBlockStartEvent:
198 eventChan <- ProviderEvent{Type: EventContentStart}
199
200 case anthropic.ContentBlockDeltaEvent:
201 if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
202 eventChan <- ProviderEvent{
203 Type: EventThinkingDelta,
204 Thinking: event.Delta.Thinking,
205 }
206 } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
207 eventChan <- ProviderEvent{
208 Type: EventContentDelta,
209 Content: event.Delta.Text,
210 }
211 }
212
213 case anthropic.ContentBlockStopEvent:
214 eventChan <- ProviderEvent{Type: EventContentStop}
215
216 case anthropic.MessageStopEvent:
217 content := ""
218 for _, block := range accumulatedMessage.Content {
219 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
220 content += text.Text
221 }
222 }
223
224 toolCalls := a.extractToolCalls(accumulatedMessage.Content)
225 tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
226
227 eventChan <- ProviderEvent{
228 Type: EventComplete,
229 Response: &ProviderResponse{
230 Content: content,
231 ToolCalls: toolCalls,
232 Usage: tokenUsage,
233 FinishReason: string(accumulatedMessage.StopReason),
234 },
235 }
236 }
237 }
238
239 err := stream.Err()
240 if err == nil || errors.Is(err, io.EOF) {
241 return
242 }
243
244 var apierr *anthropic.Error
245 if !errors.As(err, &apierr) {
246 eventChan <- ProviderEvent{Type: EventError, Error: err}
247 return
248 }
249
250 if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
251 eventChan <- ProviderEvent{Type: EventError, Error: err}
252 return
253 }
254
255 if attempts > maxRetries {
256 eventChan <- ProviderEvent{
257 Type: EventError,
258 Error: errors.New("maximum retry attempts reached for rate limit (429)"),
259 }
260 return
261 }
262
263 retryMs := 0
264 retryAfterValues := apierr.Response.Header.Values("Retry-After")
265 if len(retryAfterValues) > 0 {
266 var retryAfterSec int
267 if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
268 retryMs = retryAfterSec * 1000
269 eventChan <- ProviderEvent{
270 Type: EventWarning,
271 Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
272 }
273 }
274 } else {
275 eventChan <- ProviderEvent{
276 Type: EventWarning,
277 Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
278 }
279
280 backoffMs := 2000 * (1 << (attempts - 1))
281 jitterMs := int(float64(backoffMs) * 0.2)
282 retryMs = backoffMs + jitterMs
283 }
284 select {
285 case <-ctx.Done():
286 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
287 return
288 case <-time.After(time.Duration(retryMs) * time.Millisecond):
289 continue
290 }
291
292 }
293 }()
294
295 return eventChan, nil
296}
297
298func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
299 var toolCalls []message.ToolCall
300
301 for _, block := range content {
302 switch variant := block.AsAny().(type) {
303 case anthropic.ToolUseBlock:
304 toolCall := message.ToolCall{
305 ID: variant.ID,
306 Name: variant.Name,
307 Input: string(variant.Input),
308 Type: string(variant.Type),
309 }
310 toolCalls = append(toolCalls, toolCall)
311 }
312 }
313
314 return toolCalls
315}
316
317func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
318 return TokenUsage{
319 InputTokens: usage.InputTokens,
320 OutputTokens: usage.OutputTokens,
321 CacheCreationTokens: usage.CacheCreationInputTokens,
322 CacheReadTokens: usage.CacheReadInputTokens,
323 }
324}
325
326func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
327 anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
328
329 for i, tool := range tools {
330 info := tool.Info()
331 toolParam := anthropic.ToolParam{
332 Name: info.Name,
333 Description: anthropic.String(info.Description),
334 InputSchema: anthropic.ToolInputSchemaParam{
335 Properties: info.Parameters,
336 },
337 }
338
339 if i == len(tools)-1 && !a.disableCache {
340 toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
341 Type: "ephemeral",
342 }
343 }
344
345 anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
346 }
347
348 return anthropicTools
349}
350
351func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
352 anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
353 cachedBlocks := 0
354
355 for _, msg := range messages {
356 switch msg.Role {
357 case message.User:
358 content := anthropic.NewTextBlock(msg.Content().String())
359 if cachedBlocks < 2 && !a.disableCache {
360 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
361 Type: "ephemeral",
362 }
363 cachedBlocks++
364 }
365 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
366
367 case message.Assistant:
368 blocks := []anthropic.ContentBlockParamUnion{}
369 if msg.Content().String() != "" {
370 content := anthropic.NewTextBlock(msg.Content().String())
371 if cachedBlocks < 2 && !a.disableCache {
372 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
373 Type: "ephemeral",
374 }
375 cachedBlocks++
376 }
377 blocks = append(blocks, content)
378 }
379
380 for _, toolCall := range msg.ToolCalls() {
381 var inputMap map[string]any
382 err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
383 if err != nil {
384 continue
385 }
386 blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
387 }
388
389 if len(blocks) > 0 {
390 anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
391 }
392
393 case message.Tool:
394 results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
395 for i, toolResult := range msg.ToolResults() {
396 results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
397 }
398 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
399 }
400 }
401
402 return anthropicMessages
403}