1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "strings"
8
9 "github.com/anthropics/anthropic-sdk-go"
10 "github.com/anthropics/anthropic-sdk-go/option"
11 "github.com/kujtimiihoxha/termai/internal/llm/models"
12 "github.com/kujtimiihoxha/termai/internal/llm/tools"
13 "github.com/kujtimiihoxha/termai/internal/message"
14)
15
16type anthropicProvider struct {
17 client anthropic.Client
18 model models.Model
19 maxTokens int64
20 apiKey string
21 systemMessage string
22}
23
24type AnthropicOption func(*anthropicProvider)
25
26func WithAnthropicSystemMessage(message string) AnthropicOption {
27 return func(a *anthropicProvider) {
28 a.systemMessage = message
29 }
30}
31
32func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
33 return func(a *anthropicProvider) {
34 a.maxTokens = maxTokens
35 }
36}
37
38func WithAnthropicModel(model models.Model) AnthropicOption {
39 return func(a *anthropicProvider) {
40 a.model = model
41 }
42}
43
44func WithAnthropicKey(apiKey string) AnthropicOption {
45 return func(a *anthropicProvider) {
46 a.apiKey = apiKey
47 }
48}
49
50func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
51 provider := &anthropicProvider{
52 maxTokens: 1024,
53 }
54
55 for _, opt := range opts {
56 opt(provider)
57 }
58
59 if provider.systemMessage == "" {
60 return nil, errors.New("system message is required")
61 }
62
63 provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
64 return provider, nil
65}
66
67func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
68 anthropicMessages := a.convertToAnthropicMessages(messages)
69 anthropicTools := a.convertToAnthropicTools(tools)
70
71 response, err := a.client.Messages.New(
72 ctx,
73 anthropic.MessageNewParams{
74 Model: anthropic.Model(a.model.APIModel),
75 MaxTokens: a.maxTokens,
76 Temperature: anthropic.Float(0),
77 Messages: anthropicMessages,
78 Tools: anthropicTools,
79 System: []anthropic.TextBlockParam{
80 {
81 Text: a.systemMessage,
82 CacheControl: anthropic.CacheControlEphemeralParam{
83 Type: "ephemeral",
84 },
85 },
86 },
87 },
88 option.WithMaxRetries(8),
89 )
90 if err != nil {
91 return nil, err
92 }
93
94 content := ""
95 for _, block := range response.Content {
96 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
97 content += text.Text
98 }
99 }
100
101 toolCalls := a.extractToolCalls(response.Content)
102 tokenUsage := a.extractTokenUsage(response.Usage)
103
104 return &ProviderResponse{
105 Content: content,
106 ToolCalls: toolCalls,
107 Usage: tokenUsage,
108 }, nil
109}
110
111func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
112 anthropicMessages := a.convertToAnthropicMessages(messages)
113 anthropicTools := a.convertToAnthropicTools(tools)
114
115 var thinkingParam anthropic.ThinkingConfigParamUnion
116 lastMessage := messages[len(messages)-1]
117 temperature := anthropic.Float(0)
118 if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
119 thinkingParam = anthropic.ThinkingConfigParamUnion{
120 OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
121 BudgetTokens: int64(float64(a.maxTokens) * 0.8),
122 Type: "enabled",
123 },
124 }
125 temperature = anthropic.Float(1)
126 }
127
128 stream := a.client.Messages.NewStreaming(
129 ctx,
130 anthropic.MessageNewParams{
131 Model: anthropic.Model(a.model.APIModel),
132 MaxTokens: a.maxTokens,
133 Temperature: temperature,
134 Messages: anthropicMessages,
135 Tools: anthropicTools,
136 Thinking: thinkingParam,
137 System: []anthropic.TextBlockParam{
138 {
139 Text: a.systemMessage,
140 CacheControl: anthropic.CacheControlEphemeralParam{
141 Type: "ephemeral",
142 },
143 },
144 },
145 },
146 option.WithMaxRetries(8),
147 )
148
149 eventChan := make(chan ProviderEvent)
150
151 go func() {
152 defer close(eventChan)
153
154 accumulatedMessage := anthropic.Message{}
155
156 for stream.Next() {
157 event := stream.Current()
158 err := accumulatedMessage.Accumulate(event)
159 if err != nil {
160 eventChan <- ProviderEvent{Type: EventError, Error: err}
161 return
162 }
163
164 switch event := event.AsAny().(type) {
165 case anthropic.ContentBlockStartEvent:
166 eventChan <- ProviderEvent{Type: EventContentStart}
167
168 case anthropic.ContentBlockDeltaEvent:
169 if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
170 eventChan <- ProviderEvent{
171 Type: EventThinkingDelta,
172 Thinking: event.Delta.Thinking,
173 }
174 } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
175 eventChan <- ProviderEvent{
176 Type: EventContentDelta,
177 Content: event.Delta.Text,
178 }
179 }
180
181 case anthropic.ContentBlockStopEvent:
182 eventChan <- ProviderEvent{Type: EventContentStop}
183
184 case anthropic.MessageStopEvent:
185 content := ""
186 for _, block := range accumulatedMessage.Content {
187 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
188 content += text.Text
189 }
190 }
191
192 toolCalls := a.extractToolCalls(accumulatedMessage.Content)
193 tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
194
195 eventChan <- ProviderEvent{
196 Type: EventComplete,
197 Response: &ProviderResponse{
198 Content: content,
199 ToolCalls: toolCalls,
200 Usage: tokenUsage,
201 FinishReason: string(accumulatedMessage.StopReason),
202 },
203 }
204 }
205 }
206
207 if stream.Err() != nil {
208 eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
209 }
210 }()
211
212 return eventChan, nil
213}
214
215func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
216 var toolCalls []message.ToolCall
217
218 for _, block := range content {
219 switch variant := block.AsAny().(type) {
220 case anthropic.ToolUseBlock:
221 toolCall := message.ToolCall{
222 ID: variant.ID,
223 Name: variant.Name,
224 Input: string(variant.Input),
225 Type: string(variant.Type),
226 }
227 toolCalls = append(toolCalls, toolCall)
228 }
229 }
230
231 return toolCalls
232}
233
234func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
235 return TokenUsage{
236 InputTokens: usage.InputTokens,
237 OutputTokens: usage.OutputTokens,
238 CacheCreationTokens: usage.CacheCreationInputTokens,
239 CacheReadTokens: usage.CacheReadInputTokens,
240 }
241}
242
243func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
244 anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
245
246 for i, tool := range tools {
247 info := tool.Info()
248 toolParam := anthropic.ToolParam{
249 Name: info.Name,
250 Description: anthropic.String(info.Description),
251 InputSchema: anthropic.ToolInputSchemaParam{
252 Properties: info.Parameters,
253 },
254 }
255
256 if i == len(tools)-1 {
257 toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
258 Type: "ephemeral",
259 }
260 }
261
262 anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
263 }
264
265 return anthropicTools
266}
267
268func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
269 anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
270 cachedBlocks := 0
271
272 for _, msg := range messages {
273 switch msg.Role {
274 case message.User:
275 content := anthropic.NewTextBlock(msg.Content().String())
276 if cachedBlocks < 2 {
277 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
278 Type: "ephemeral",
279 }
280 cachedBlocks++
281 }
282 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
283
284 case message.Assistant:
285 blocks := []anthropic.ContentBlockParamUnion{}
286 if msg.Content().String() != "" {
287 content := anthropic.NewTextBlock(msg.Content().String())
288 if cachedBlocks < 2 {
289 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
290 Type: "ephemeral",
291 }
292 cachedBlocks++
293 }
294 blocks = append(blocks, content)
295 }
296
297 for _, toolCall := range msg.ToolCalls() {
298 var inputMap map[string]any
299 err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
300 if err != nil {
301 continue
302 }
303 blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
304 }
305
306 // Skip empty assistant messages completely
307 if len(blocks) > 0 {
308 anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
309 }
310
311 case message.Tool:
312 results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
313 for i, toolResult := range msg.ToolResults() {
314 results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
315 }
316 anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
317 }
318 }
319
320 return anthropicMessages
321}