anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"strings"
  8
  9	"github.com/anthropics/anthropic-sdk-go"
 10	"github.com/anthropics/anthropic-sdk-go/option"
 11	"github.com/kujtimiihoxha/termai/internal/llm/models"
 12	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 13	"github.com/kujtimiihoxha/termai/internal/message"
 14)
 15
// anthropicProvider holds the SDK client and the per-provider settings used
// to build Anthropic Messages API requests. NewAnthropicProvider returns it
// as a Provider.
type anthropicProvider struct {
	client        anthropic.Client // SDK client, created in NewAnthropicProvider
	model         models.Model     // its APIModel field is sent as the request model
	maxTokens     int64            // sent as MaxTokens; NewAnthropicProvider defaults it to 1024
	apiKey        string           // passed to option.WithAPIKey when the client is created
	systemMessage string           // system prompt; NewAnthropicProvider rejects an empty value
}
 23
// AnthropicOption configures an anthropicProvider while it is being
// constructed. NewAnthropicProvider applies options in the order given.
type AnthropicOption func(*anthropicProvider)
 25
 26func WithAnthropicSystemMessage(message string) AnthropicOption {
 27	return func(a *anthropicProvider) {
 28		a.systemMessage = message
 29	}
 30}
 31
 32func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 33	return func(a *anthropicProvider) {
 34		a.maxTokens = maxTokens
 35	}
 36}
 37
 38func WithAnthropicModel(model models.Model) AnthropicOption {
 39	return func(a *anthropicProvider) {
 40		a.model = model
 41	}
 42}
 43
 44func WithAnthropicKey(apiKey string) AnthropicOption {
 45	return func(a *anthropicProvider) {
 46		a.apiKey = apiKey
 47	}
 48}
 49
 50func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 51	provider := &anthropicProvider{
 52		maxTokens: 1024,
 53	}
 54
 55	for _, opt := range opts {
 56		opt(provider)
 57	}
 58
 59	if provider.systemMessage == "" {
 60		return nil, errors.New("system message is required")
 61	}
 62
 63	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
 64	return provider, nil
 65}
 66
 67func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 68	anthropicMessages := a.convertToAnthropicMessages(messages)
 69	anthropicTools := a.convertToAnthropicTools(tools)
 70
 71	response, err := a.client.Messages.New(ctx, anthropic.MessageNewParams{
 72		Model:       anthropic.Model(a.model.APIModel),
 73		MaxTokens:   a.maxTokens,
 74		Temperature: anthropic.Float(0),
 75		Messages:    anthropicMessages,
 76		Tools:       anthropicTools,
 77		System: []anthropic.TextBlockParam{
 78			{
 79				Text: a.systemMessage,
 80				CacheControl: anthropic.CacheControlEphemeralParam{
 81					Type: "ephemeral",
 82				},
 83			},
 84		},
 85	})
 86	if err != nil {
 87		return nil, err
 88	}
 89
 90	content := ""
 91	for _, block := range response.Content {
 92		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
 93			content += text.Text
 94		}
 95	}
 96
 97	toolCalls := a.extractToolCalls(response.Content)
 98	tokenUsage := a.extractTokenUsage(response.Usage)
 99
100	return &ProviderResponse{
101		Content:   content,
102		ToolCalls: toolCalls,
103		Usage:     tokenUsage,
104	}, nil
105}
106
107func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
108	anthropicMessages := a.convertToAnthropicMessages(messages)
109	anthropicTools := a.convertToAnthropicTools(tools)
110
111	var thinkingParam anthropic.ThinkingConfigParamUnion
112	lastMessage := messages[len(messages)-1]
113	temperature := anthropic.Float(0)
114	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
115		thinkingParam = anthropic.ThinkingConfigParamUnion{
116			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
117				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
118				Type:         "enabled",
119			},
120		}
121		temperature = anthropic.Float(1)
122	}
123
124	stream := a.client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{
125		Model:       anthropic.Model(a.model.APIModel),
126		MaxTokens:   a.maxTokens,
127		Temperature: temperature,
128		Messages:    anthropicMessages,
129		Tools:       anthropicTools,
130		Thinking:    thinkingParam,
131		System: []anthropic.TextBlockParam{
132			{
133				Text: a.systemMessage,
134				CacheControl: anthropic.CacheControlEphemeralParam{
135					Type: "ephemeral",
136				},
137			},
138		},
139	})
140
141	eventChan := make(chan ProviderEvent)
142
143	go func() {
144		defer close(eventChan)
145
146		accumulatedMessage := anthropic.Message{}
147
148		for stream.Next() {
149			event := stream.Current()
150			err := accumulatedMessage.Accumulate(event)
151			if err != nil {
152				eventChan <- ProviderEvent{Type: EventError, Error: err}
153				return
154			}
155
156			switch event := event.AsAny().(type) {
157			case anthropic.ContentBlockStartEvent:
158				eventChan <- ProviderEvent{Type: EventContentStart}
159
160			case anthropic.ContentBlockDeltaEvent:
161				if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
162					eventChan <- ProviderEvent{
163						Type:     EventThinkingDelta,
164						Thinking: event.Delta.Thinking,
165					}
166				} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
167					eventChan <- ProviderEvent{
168						Type:    EventContentDelta,
169						Content: event.Delta.Text,
170					}
171				}
172
173			case anthropic.ContentBlockStopEvent:
174				eventChan <- ProviderEvent{Type: EventContentStop}
175
176			case anthropic.MessageStopEvent:
177				content := ""
178				for _, block := range accumulatedMessage.Content {
179					if text, ok := block.AsAny().(anthropic.TextBlock); ok {
180						content += text.Text
181					}
182				}
183
184				toolCalls := a.extractToolCalls(accumulatedMessage.Content)
185				tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
186
187				eventChan <- ProviderEvent{
188					Type: EventComplete,
189					Response: &ProviderResponse{
190						Content:      content,
191						ToolCalls:    toolCalls,
192						Usage:        tokenUsage,
193						FinishReason: string(accumulatedMessage.StopReason),
194					},
195				}
196			}
197		}
198
199		if stream.Err() != nil {
200			eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
201		}
202	}()
203
204	return eventChan, nil
205}
206
207func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
208	var toolCalls []message.ToolCall
209
210	for _, block := range content {
211		switch variant := block.AsAny().(type) {
212		case anthropic.ToolUseBlock:
213			toolCall := message.ToolCall{
214				ID:    variant.ID,
215				Name:  variant.Name,
216				Input: string(variant.Input),
217				Type:  string(variant.Type),
218			}
219			toolCalls = append(toolCalls, toolCall)
220		}
221	}
222
223	return toolCalls
224}
225
226func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
227	return TokenUsage{
228		InputTokens:         usage.InputTokens,
229		OutputTokens:        usage.OutputTokens,
230		CacheCreationTokens: usage.CacheCreationInputTokens,
231		CacheReadTokens:     usage.CacheReadInputTokens,
232	}
233}
234
235func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
236	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
237
238	for i, tool := range tools {
239		info := tool.Info()
240		toolParam := anthropic.ToolParam{
241			Name:        info.Name,
242			Description: anthropic.String(info.Description),
243			InputSchema: anthropic.ToolInputSchemaParam{
244				Properties: info.Parameters,
245			},
246		}
247
248		if i == len(tools)-1 {
249			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
250				Type: "ephemeral",
251			}
252		}
253
254		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
255	}
256
257	return anthropicTools
258}
259
260func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
261	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
262	cachedBlocks := 0
263
264	for _, msg := range messages {
265		switch msg.Role {
266		case message.User:
267			content := anthropic.NewTextBlock(msg.Content().String())
268			if cachedBlocks < 2 {
269				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
270					Type: "ephemeral",
271				}
272				cachedBlocks++
273			}
274			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
275
276		case message.Assistant:
277			blocks := []anthropic.ContentBlockParamUnion{}
278			if msg.Content().String() != "" {
279				content := anthropic.NewTextBlock(msg.Content().String())
280				if cachedBlocks < 2 {
281					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
282						Type: "ephemeral",
283					}
284					cachedBlocks++
285				}
286				blocks = append(blocks, content)
287			}
288
289			for _, toolCall := range msg.ToolCalls() {
290				var inputMap map[string]any
291				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
292				if err != nil {
293					continue
294				}
295				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
296			}
297
298			// Skip empty assistant messages completely
299			if len(blocks) > 0 {
300				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
301			}
302
303		case message.Tool:
304			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
305			for i, toolResult := range msg.ToolResults() {
306				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
307			}
308			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
309		}
310	}
311
312	return anthropicMessages
313}