anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"strings"
  8
  9	"github.com/anthropics/anthropic-sdk-go"
 10	"github.com/anthropics/anthropic-sdk-go/option"
 11	"github.com/kujtimiihoxha/termai/internal/llm/models"
 12	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 13	"github.com/kujtimiihoxha/termai/internal/message"
 14)
 15
 16type anthropicProvider struct {
 17	client        anthropic.Client
 18	model         models.Model
 19	maxTokens     int64
 20	apiKey        string
 21	systemMessage string
 22}
 23
 24type AnthropicOption func(*anthropicProvider)
 25
 26func WithAnthropicSystemMessage(message string) AnthropicOption {
 27	return func(a *anthropicProvider) {
 28		a.systemMessage = message
 29	}
 30}
 31
 32func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 33	return func(a *anthropicProvider) {
 34		a.maxTokens = maxTokens
 35	}
 36}
 37
 38func WithAnthropicModel(model models.Model) AnthropicOption {
 39	return func(a *anthropicProvider) {
 40		a.model = model
 41	}
 42}
 43
 44func WithAnthropicKey(apiKey string) AnthropicOption {
 45	return func(a *anthropicProvider) {
 46		a.apiKey = apiKey
 47	}
 48}
 49
 50func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 51	provider := &anthropicProvider{
 52		maxTokens: 1024,
 53	}
 54
 55	for _, opt := range opts {
 56		opt(provider)
 57	}
 58
 59	if provider.systemMessage == "" {
 60		return nil, errors.New("system message is required")
 61	}
 62
 63	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
 64	return provider, nil
 65}
 66
 67func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 68	anthropicMessages := a.convertToAnthropicMessages(messages)
 69	anthropicTools := a.convertToAnthropicTools(tools)
 70
 71	response, err := a.client.Messages.New(ctx, anthropic.MessageNewParams{
 72		Model:       anthropic.Model(a.model.APIModel),
 73		MaxTokens:   a.maxTokens,
 74		Temperature: anthropic.Float(0),
 75		Messages:    anthropicMessages,
 76		Tools:       anthropicTools,
 77		System: []anthropic.TextBlockParam{
 78			{
 79				Text: a.systemMessage,
 80				CacheControl: anthropic.CacheControlEphemeralParam{
 81					Type: "ephemeral",
 82				},
 83			},
 84		},
 85	})
 86	if err != nil {
 87		return nil, err
 88	}
 89
 90	content := ""
 91	for _, block := range response.Content {
 92		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
 93			content += text.Text
 94		}
 95	}
 96
 97	toolCalls := a.extractToolCalls(response.Content)
 98	tokenUsage := a.extractTokenUsage(response.Usage)
 99
100	return &ProviderResponse{
101		Content:   content,
102		ToolCalls: toolCalls,
103		Usage:     tokenUsage,
104	}, nil
105}
106
107func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
108	anthropicMessages := a.convertToAnthropicMessages(messages)
109	anthropicTools := a.convertToAnthropicTools(tools)
110
111	var thinkingParam anthropic.ThinkingConfigParamUnion
112	lastMessage := messages[len(messages)-1]
113	temperature := anthropic.Float(0)
114	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content), "think") {
115		thinkingParam = anthropic.ThinkingConfigParamUnion{
116			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
117				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
118				Type:         "enabled",
119			},
120		}
121		temperature = anthropic.Float(1)
122	}
123
124	stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
125		Model:       anthropic.Model(a.model.APIModel),
126		MaxTokens:   a.maxTokens,
127		Temperature: temperature,
128		Messages:    anthropicMessages,
129		Tools:       anthropicTools,
130		Thinking:    thinkingParam,
131		System: []anthropic.TextBlockParam{
132			{
133				Text: a.systemMessage,
134				CacheControl: anthropic.CacheControlEphemeralParam{
135					Type: "ephemeral",
136				},
137			},
138		},
139	})
140
141	eventChan := make(chan ProviderEvent)
142
143	go func() {
144		defer close(eventChan)
145
146		accumulatedMessage := anthropic.Message{}
147
148		for stream.Next() {
149			event := stream.Current()
150			err := accumulatedMessage.Accumulate(event)
151			if err != nil {
152				eventChan <- ProviderEvent{Type: EventError, Error: err}
153				return
154			}
155
156			switch event := event.AsAny().(type) {
157			case anthropic.ContentBlockStartEvent:
158				eventChan <- ProviderEvent{Type: EventContentStart}
159
160			case anthropic.ContentBlockDeltaEvent:
161				if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
162					eventChan <- ProviderEvent{
163						Type:     EventThinkingDelta,
164						Thinking: event.Delta.Thinking,
165					}
166				} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
167					eventChan <- ProviderEvent{
168						Type:    EventContentDelta,
169						Content: event.Delta.Text,
170					}
171				}
172
173			case anthropic.ContentBlockStopEvent:
174				eventChan <- ProviderEvent{Type: EventContentStop}
175
176			case anthropic.MessageStopEvent:
177				content := ""
178				for _, block := range accumulatedMessage.Content {
179					if text, ok := block.AsAny().(anthropic.TextBlock); ok {
180						content += text.Text
181					}
182				}
183
184				toolCalls := a.extractToolCalls(accumulatedMessage.Content)
185				tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
186
187				eventChan <- ProviderEvent{
188					Type: EventComplete,
189					Response: &ProviderResponse{
190						Content:   content,
191						ToolCalls: toolCalls,
192						Usage:     tokenUsage,
193					},
194				}
195			}
196		}
197
198		if stream.Err() != nil {
199			eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
200		}
201	}()
202
203	return eventChan, nil
204}
205
206func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
207	var toolCalls []message.ToolCall
208
209	for _, block := range content {
210		switch variant := block.AsAny().(type) {
211		case anthropic.ToolUseBlock:
212			toolCall := message.ToolCall{
213				ID:    variant.ID,
214				Name:  variant.Name,
215				Input: string(variant.Input),
216				Type:  string(variant.Type),
217			}
218			toolCalls = append(toolCalls, toolCall)
219		}
220	}
221
222	return toolCalls
223}
224
225func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
226	return TokenUsage{
227		InputTokens:         usage.InputTokens,
228		OutputTokens:        usage.OutputTokens,
229		CacheCreationTokens: usage.CacheCreationInputTokens,
230		CacheReadTokens:     usage.CacheReadInputTokens,
231	}
232}
233
234func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
235	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
236
237	for i, tool := range tools {
238		info := tool.Info()
239		toolParam := anthropic.ToolParam{
240			Name:        info.Name,
241			Description: anthropic.String(info.Description),
242			InputSchema: anthropic.ToolInputSchemaParam{
243				Properties: info.Parameters,
244			},
245		}
246
247		if i == len(tools)-1 {
248			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
249				Type: "ephemeral",
250			}
251		}
252
253		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
254	}
255
256	return anthropicTools
257}
258
259func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
260	anthropicMessages := make([]anthropic.MessageParam, len(messages))
261	cachedBlocks := 0
262
263	for i, msg := range messages {
264		switch msg.Role {
265		case message.User:
266			content := anthropic.NewTextBlock(msg.Content)
267			if cachedBlocks < 2 {
268				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
269					Type: "ephemeral",
270				}
271				cachedBlocks++
272			}
273			anthropicMessages[i] = anthropic.NewUserMessage(content)
274
275		case message.Assistant:
276			blocks := []anthropic.ContentBlockParamUnion{}
277			if msg.Content != "" {
278				content := anthropic.NewTextBlock(msg.Content)
279				if cachedBlocks < 2 {
280					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
281						Type: "ephemeral",
282					}
283					cachedBlocks++
284				}
285				blocks = append(blocks, content)
286			}
287
288			for _, toolCall := range msg.ToolCalls {
289				var inputMap map[string]any
290				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
291				if err != nil {
292					continue
293				}
294				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
295			}
296
297			anthropicMessages[i] = anthropic.NewAssistantMessage(blocks...)
298
299		case message.Tool:
300			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults))
301			for i, toolResult := range msg.ToolResults {
302				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
303			}
304			anthropicMessages[i] = anthropic.NewUserMessage(results...)
305		}
306	}
307
308	return anthropicMessages
309}