1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "strings"
8
9 "github.com/anthropics/anthropic-sdk-go"
10 "github.com/anthropics/anthropic-sdk-go/option"
11 "github.com/kujtimiihoxha/termai/internal/llm/models"
12 "github.com/kujtimiihoxha/termai/internal/llm/tools"
13 "github.com/kujtimiihoxha/termai/internal/message"
14)
15
16type anthropicProvider struct {
17 client anthropic.Client
18 model models.Model
19 maxTokens int64
20 apiKey string
21 systemMessage string
22}
23
24type AnthropicOption func(*anthropicProvider)
25
26func WithAnthropicSystemMessage(message string) AnthropicOption {
27 return func(a *anthropicProvider) {
28 a.systemMessage = message
29 }
30}
31
32func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
33 return func(a *anthropicProvider) {
34 a.maxTokens = maxTokens
35 }
36}
37
38func WithAnthropicModel(model models.Model) AnthropicOption {
39 return func(a *anthropicProvider) {
40 a.model = model
41 }
42}
43
44func WithAnthropicKey(apiKey string) AnthropicOption {
45 return func(a *anthropicProvider) {
46 a.apiKey = apiKey
47 }
48}
49
50func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
51 provider := &anthropicProvider{
52 maxTokens: 1024,
53 }
54
55 for _, opt := range opts {
56 opt(provider)
57 }
58
59 if provider.systemMessage == "" {
60 return nil, errors.New("system message is required")
61 }
62
63 provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
64 return provider, nil
65}
66
67func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
68 anthropicMessages := a.convertToAnthropicMessages(messages)
69 anthropicTools := a.convertToAnthropicTools(tools)
70
71 response, err := a.client.Messages.New(ctx, anthropic.MessageNewParams{
72 Model: anthropic.Model(a.model.APIModel),
73 MaxTokens: a.maxTokens,
74 Temperature: anthropic.Float(0),
75 Messages: anthropicMessages,
76 Tools: anthropicTools,
77 System: []anthropic.TextBlockParam{
78 {
79 Text: a.systemMessage,
80 CacheControl: anthropic.CacheControlEphemeralParam{
81 Type: "ephemeral",
82 },
83 },
84 },
85 })
86 if err != nil {
87 return nil, err
88 }
89
90 content := ""
91 for _, block := range response.Content {
92 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
93 content += text.Text
94 }
95 }
96
97 toolCalls := a.extractToolCalls(response.Content)
98 tokenUsage := a.extractTokenUsage(response.Usage)
99
100 return &ProviderResponse{
101 Content: content,
102 ToolCalls: toolCalls,
103 Usage: tokenUsage,
104 }, nil
105}
106
107func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
108 anthropicMessages := a.convertToAnthropicMessages(messages)
109 anthropicTools := a.convertToAnthropicTools(tools)
110
111 var thinkingParam anthropic.ThinkingConfigParamUnion
112 lastMessage := messages[len(messages)-1]
113 temperature := anthropic.Float(0)
114 if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
115 thinkingParam = anthropic.ThinkingConfigParamUnion{
116 OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
117 BudgetTokens: int64(float64(a.maxTokens) * 0.8),
118 Type: "enabled",
119 },
120 }
121 temperature = anthropic.Float(1)
122 }
123
124 stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
125 Model: anthropic.Model(a.model.APIModel),
126 MaxTokens: a.maxTokens,
127 Temperature: temperature,
128 Messages: anthropicMessages,
129 Tools: anthropicTools,
130 Thinking: thinkingParam,
131 System: []anthropic.TextBlockParam{
132 {
133 Text: a.systemMessage,
134 CacheControl: anthropic.CacheControlEphemeralParam{
135 Type: "ephemeral",
136 },
137 },
138 },
139 })
140
141 eventChan := make(chan ProviderEvent)
142
143 go func() {
144 defer close(eventChan)
145
146 accumulatedMessage := anthropic.Message{}
147
148 for stream.Next() {
149 event := stream.Current()
150 err := accumulatedMessage.Accumulate(event)
151 if err != nil {
152 eventChan <- ProviderEvent{Type: EventError, Error: err}
153 return
154 }
155
156 switch event := event.AsAny().(type) {
157 case anthropic.ContentBlockStartEvent:
158 eventChan <- ProviderEvent{Type: EventContentStart}
159
160 case anthropic.ContentBlockDeltaEvent:
161 if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
162 eventChan <- ProviderEvent{
163 Type: EventThinkingDelta,
164 Thinking: event.Delta.Thinking,
165 }
166 } else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
167 eventChan <- ProviderEvent{
168 Type: EventContentDelta,
169 Content: event.Delta.Text,
170 }
171 }
172
173 case anthropic.ContentBlockStopEvent:
174 eventChan <- ProviderEvent{Type: EventContentStop}
175
176 case anthropic.MessageStopEvent:
177 content := ""
178 for _, block := range accumulatedMessage.Content {
179 if text, ok := block.AsAny().(anthropic.TextBlock); ok {
180 content += text.Text
181 }
182 }
183
184 toolCalls := a.extractToolCalls(accumulatedMessage.Content)
185 tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
186
187 eventChan <- ProviderEvent{
188 Type: EventComplete,
189 Response: &ProviderResponse{
190 Content: content,
191 ToolCalls: toolCalls,
192 Usage: tokenUsage,
193 },
194 }
195 }
196 }
197
198 if stream.Err() != nil {
199 eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
200 }
201 }()
202
203 return eventChan, nil
204}
205
206func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
207 var toolCalls []message.ToolCall
208
209 for _, block := range content {
210 switch variant := block.AsAny().(type) {
211 case anthropic.ToolUseBlock:
212 toolCall := message.ToolCall{
213 ID: variant.ID,
214 Name: variant.Name,
215 Input: string(variant.Input),
216 Type: string(variant.Type),
217 }
218 toolCalls = append(toolCalls, toolCall)
219 }
220 }
221
222 return toolCalls
223}
224
225func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
226 return TokenUsage{
227 InputTokens: usage.InputTokens,
228 OutputTokens: usage.OutputTokens,
229 CacheCreationTokens: usage.CacheCreationInputTokens,
230 CacheReadTokens: usage.CacheReadInputTokens,
231 }
232}
233
234func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
235 anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
236
237 for i, tool := range tools {
238 info := tool.Info()
239 toolParam := anthropic.ToolParam{
240 Name: info.Name,
241 Description: anthropic.String(info.Description),
242 InputSchema: anthropic.ToolInputSchemaParam{
243 Properties: info.Parameters,
244 },
245 }
246
247 if i == len(tools)-1 {
248 toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
249 Type: "ephemeral",
250 }
251 }
252
253 anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
254 }
255
256 return anthropicTools
257}
258
259func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
260 anthropicMessages := make([]anthropic.MessageParam, len(messages))
261 cachedBlocks := 0
262
263 for i, msg := range messages {
264 switch msg.Role {
265 case message.User:
266 content := anthropic.NewTextBlock(msg.Content)
267 if cachedBlocks < 2 {
268 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
269 Type: "ephemeral",
270 }
271 cachedBlocks++
272 }
273 anthropicMessages[i] = anthropic.NewUserMessage(content)
274
275 case message.Assistant:
276 blocks := []anthropic.ContentBlockParamUnion{}
277 if msg.Content != "" {
278 content := anthropic.NewTextBlock(msg.Content)
279 if cachedBlocks < 2 {
280 content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
281 Type: "ephemeral",
282 }
283 cachedBlocks++
284 }
285 blocks = append(blocks, content)
286 }
287
288 for _, toolCall := range msg.ToolCalls {
289 var inputMap map[string]any
290 err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
291 if err != nil {
292 continue
293 }
294 blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
295 }
296
297 anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
298
299 case message.Tool:
300 results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
301 for i, toolResult := range msg.ToolResults {
302 results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
303 }
304 anthropicMessages[i] = anthropic.NewUserMessage(results...)
305 }
306 }
307
308 return anthropicMessages
309}