1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "strings"
9 "time"
10
11 "github.com/anthropics/anthropic-sdk-go"
12 "github.com/anthropics/anthropic-sdk-go/option"
13 "github.com/kujtimiihoxha/termai/internal/llm/models"
14 "github.com/kujtimiihoxha/termai/internal/llm/tools"
15 "github.com/kujtimiihoxha/termai/internal/message"
16)
17
// anthropicProvider bundles the Anthropic SDK client with the settings that
// the AnthropicOption functions populate before the client is created.
type anthropicProvider struct {
	client        anthropic.Client // created in NewAnthropicProvider from apiKey
	model         models.Model     // its APIModel field is sent as the request model
	maxTokens     int64            // sent as MaxTokens; NewAnthropicProvider defaults it to 1024
	apiKey        string           // handed to option.WithAPIKey when the client is built
	systemMessage string           // system prompt attached to every request; must be non-empty
}
25
// AnthropicOption is a functional option: NewAnthropicProvider calls each one
// with the provider under construction so it can set a field on it.
type AnthropicOption func(*anthropicProvider)
27
28func WithAnthropicSystemMessage(message string) AnthropicOption {
29 return func(a *anthropicProvider) {
30 a.systemMessage = message
31 }
32}
33
34func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
35 return func(a *anthropicProvider) {
36 a.maxTokens = maxTokens
37 }
38}
39
40func WithAnthropicModel(model models.Model) AnthropicOption {
41 return func(a *anthropicProvider) {
42 a.model = model
43 }
44}
45
46func WithAnthropicKey(apiKey string) AnthropicOption {
47 return func(a *anthropicProvider) {
48 a.apiKey = apiKey
49 }
50}
51
52func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
53 provider := &anthropicProvider{
54 maxTokens: 1024,
55 }
56
57 for _, opt := range opts {
58 opt(provider)
59 }
60
61 if provider.systemMessage == "" {
62 return nil, errors.New("system message is required")
63 }
64
65 provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
66 return provider, nil
67}
68
69func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
70 anthropicMessages := a.convertToAnthropicMessages(messages)
71 anthropicTools := a.convertToAnthropicTools(tools)
72
73 response, err := a.client.Messages.New(
74 ctx,
75 anthropic.MessageNewParams{
76 Model: anthropic.Model(a.model.APIModel),
77 MaxTokens: a.maxTokens,
78 Temperature: anthropic.Float(0),
79 Messages: anthropicMessages,
80 Tools: anthropicTools,
81 System: []anthropic.TextBlockParam{
82 {
83 Text: a.systemMessage,
84 CacheControl: anthropic.CacheControlEphemeralParam{
85 Type: "ephemeral",
86 },
87 },
88 },
89 },
90 )
91 if err != nil {
92 return nil, err
93 }
94
95 content := ""
96 for _, block := range response.Content {
97 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
98 content += text.Text
99 }
100 }
101
102 toolCalls := a.extractToolCalls(response.Content)
103 tokenUsage := a.extractTokenUsage(response.Usage)
104
105 return &ProviderResponse{
106 Content: content,
107 ToolCalls: toolCalls,
108 Usage: tokenUsage,
109 }, nil
110}
111
112func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
113 anthropicMessages := a.convertToAnthropicMessages(messages)
114 anthropicTools := a.convertToAnthropicTools(tools)
115
116 var thinkingParam anthropic.ThinkingConfigParamUnion
117 lastMessage := messages[len(messages)-1]
118 temperature := anthropic.Float(0)
119 if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
120 thinkingParam = anthropic.ThinkingConfigParamUnion{
121 OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
122 BudgetTokens: int64(float64(a.maxTokens) * 0.8),
123 Type: "enabled",
124 },
125 }
126 temperature = anthropic.Float(1)
127 }
128
129 eventChan := make(chan ProviderEvent)
130
131 go func() {
132 defer close(eventChan)
133
134 const maxRetries = 8
135 attempts := 0
136
137 for {
138
139 attempts++
140
141 stream := a.client.Messages.NewStreaming(
142 ctx,
143 anthropic.MessageNewParams{
144 Model: anthropic.Model(a.model.APIModel),
145 MaxTokens: a.maxTokens,
146 Temperature: temperature,
147 Messages: anthropicMessages,
148 Tools: anthropicTools,
149 Thinking: thinkingParam,
150 System: []anthropic.TextBlockParam{
151 {
152 Text: a.systemMessage,
153 CacheControl: anthropic.CacheControlEphemeralParam{
154 Type: "ephemeral",
155 },
156 },
157 },
158 },
159 )
160
161 accumulatedMessage := anthropic.Message{}
162
163 for stream.Next() {
164 event := stream.Current()
165 err := accumulatedMessage.Accumulate(event)
166 if err != nil {
167 eventChan <- ProviderEvent{Type: EventError, Error: err}
168 return // Don't retry on accumulation errors
169 }
170
171 switch event := event.AsAny().(type) {
172 case anthropic.ContentBlockStartEvent:
173 eventChan <- ProviderEvent{Type: EventContentStart}
174
175 case anthropic.ContentBlockDeltaEvent:
176 if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
177 eventChan <- ProviderEvent{
178 Type: EventThinkingDelta,
179 Thinking: event.Delta.Thinking,
180 }
181 } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
182 eventChan <- ProviderEvent{
183 Type: EventContentDelta,
184 Content: event.Delta.Text,
185 }
186 }
187
188 case anthropic.ContentBlockStopEvent:
189 eventChan <- ProviderEvent{Type: EventContentStop}
190
191 case anthropic.MessageStopEvent:
192 content := ""
193 for _, block := range accumulatedMessage.Content {
194 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
195 content += text.Text
196 }
197 }
198
199 toolCalls := a.extractToolCalls(accumulatedMessage.Content)
200 tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
201
202 eventChan <- ProviderEvent{
203 Type: EventComplete,
204 Response: &ProviderResponse{
205 Content: content,
206 ToolCalls: toolCalls,
207 Usage: tokenUsage,
208 FinishReason: string(accumulatedMessage.StopReason),
209 },
210 }
211 }
212 }
213
214 err := stream.Err()
215 if err == nil {
216 return
217 }
218
219 var apierr *anthropic.Error
220 if !errors.As(err, &apierr) {
221 eventChan <- ProviderEvent{Type: EventError, Error: err}
222 return
223 }
224
225 if apierr.StatusCode != 429 && apierr.StatusCode != 529 {
226 eventChan <- ProviderEvent{Type: EventError, Error: err}
227 return
228 }
229
230 if attempts > maxRetries {
231 eventChan <- ProviderEvent{
232 Type: EventError,
233 Error: errors.New("maximum retry attempts reached for rate limit (429)"),
234 }
235 return
236 }
237
238 retryMs := 0
239 retryAfterValues := apierr.Response.Header.Values("Retry-After")
240 if len(retryAfterValues) > 0 {
241 var retryAfterSec int
242 if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
243 retryMs = retryAfterSec * 1000
244 eventChan <- ProviderEvent{
245 Type: EventWarning,
246 Info: fmt.Sprintf("[Rate limited: waiting %d seconds as specified by API]", retryAfterSec),
247 }
248 }
249 } else {
250 eventChan <- ProviderEvent{
251 Type: EventWarning,
252 Info: fmt.Sprintf("[Retrying due to rate limit... attempt %d of %d]", attempts, maxRetries),
253 }
254
255 backoffMs := 2000 * (1 << (attempts - 1))
256 jitterMs := int(float64(backoffMs) * 0.2)
257 retryMs = backoffMs + jitterMs
258 }
259 select {
260 case <-ctx.Done():
261 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
262 return
263 case <-time.After(time.Duration(retryMs) * time.Millisecond):
264 continue
265 }
266
267 }
268 }()
269
270 return eventChan, nil
271}
272
273func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
274 var toolCalls []message.ToolCall
275
276 for _, block := range content {
277 switch variant := block.AsAny().(type) {
278 case anthropic.ToolUseBlock:
279 toolCall := message.ToolCall{
280 ID: variant.ID,
281 Name: variant.Name,
282 Input: string(variant.Input),
283 Type: string(variant.Type),
284 }
285 toolCalls = append(toolCalls, toolCall)
286 }
287 }
288
289 return toolCalls
290}
291
292func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
293 return TokenUsage{
294 InputTokens: usage.InputTokens,
295 OutputTokens: usage.OutputTokens,
296 CacheCreationTokens: usage.CacheCreationInputTokens,
297 CacheReadTokens: usage.CacheReadInputTokens,
298 }
299}
300
301func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
302 anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
303
304 for i, tool := range tools {
305 info := tool.Info()
306 toolParam := anthropic.ToolParam{
307 Name: info.Name,
308 Description: anthropic.String(info.Description),
309 InputSchema: anthropic.ToolInputSchemaParam{
310 Properties: info.Parameters,
311 },
312 }
313
314 if i == len(tools)-1 {
315 toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
316 Type: "ephemeral",
317 }
318 }
319
320 anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
321 }
322
323 return anthropicTools
324}
325
326func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
327 anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
328 cachedBlocks := 0
329
330 for _, msg := range messages {
331 switch msg.Role {
332 case message.User:
333 content := anthropic.NewTextBlock(msg.Content().String())
334 if cachedBlocks < 2 {
335 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
336 Type: "ephemeral",
337 }
338 cachedBlocks++
339 }
340 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
341
342 case message.Assistant:
343 blocks := []anthropic.ContentBlockParamUnion{}
344 if msg.Content().String() != "" {
345 content := anthropic.NewTextBlock(msg.Content().String())
346 if cachedBlocks < 2 {
347 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
348 Type: "ephemeral",
349 }
350 cachedBlocks++
351 }
352 blocks = append(blocks, content)
353 }
354
355 for _, toolCall := range msg.ToolCalls() {
356 var inputMap map[string]any
357 err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
358 if err != nil {
359 continue
360 }
361 blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
362 }
363
364 if len(blocks) > 0 {
365 anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
366 }
367
368 case message.Tool:
369 results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
370 for i, toolResult := range msg.ToolResults() {
371 results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
372 }
373 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
374 }
375 }
376
377 return anthropicMessages
378}