anthropic.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"strings"
  8
  9	"github.com/anthropics/anthropic-sdk-go"
 10	"github.com/anthropics/anthropic-sdk-go/option"
 11	"github.com/kujtimiihoxha/termai/internal/llm/models"
 12	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 13	"github.com/kujtimiihoxha/termai/internal/message"
 14)
 15
 16type anthropicProvider struct {
 17	client        anthropic.Client
 18	model         models.Model
 19	maxTokens     int64
 20	apiKey        string
 21	systemMessage string
 22}
 23
 24type AnthropicOption func(*anthropicProvider)
 25
 26func WithAnthropicSystemMessage(message string) AnthropicOption {
 27	return func(a *anthropicProvider) {
 28		a.systemMessage = message
 29	}
 30}
 31
 32func WithAnthropicMaxTokens(maxTokens int64) AnthropicOption {
 33	return func(a *anthropicProvider) {
 34		a.maxTokens = maxTokens
 35	}
 36}
 37
 38func WithAnthropicModel(model models.Model) AnthropicOption {
 39	return func(a *anthropicProvider) {
 40		a.model = model
 41	}
 42}
 43
 44func WithAnthropicKey(apiKey string) AnthropicOption {
 45	return func(a *anthropicProvider) {
 46		a.apiKey = apiKey
 47	}
 48}
 49
 50func NewAnthropicProvider(opts ...AnthropicOption) (Provider, error) {
 51	provider := &anthropicProvider{
 52		maxTokens: 1024,
 53	}
 54
 55	for _, opt := range opts {
 56		opt(provider)
 57	}
 58
 59	if provider.systemMessage == "" {
 60		return nil, errors.New("system message is required")
 61	}
 62
 63	provider.client = anthropic.NewClient(option.WithAPIKey(provider.apiKey))
 64	return provider, nil
 65}
 66
 67func (a *anthropicProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 68	anthropicMessages := a.convertToAnthropicMessages(messages)
 69	anthropicTools := a.convertToAnthropicTools(tools)
 70
 71	response, err := a.client.Messages.New(
 72		ctx,
 73		anthropic.MessageNewParams{
 74			Model:       anthropic.Model(a.model.APIModel),
 75			MaxTokens:   a.maxTokens,
 76			Temperature: anthropic.Float(0),
 77			Messages:    anthropicMessages,
 78			Tools:       anthropicTools,
 79			System: []anthropic.TextBlockParam{
 80				{
 81					Text: a.systemMessage,
 82					CacheControl: anthropic.CacheControlEphemeralParam{
 83						Type: "ephemeral",
 84					},
 85				},
 86			},
 87		},
 88		option.WithMaxRetries(8),
 89	)
 90	if err != nil {
 91		return nil, err
 92	}
 93
 94	content := ""
 95	for _, block := range response.Content {
 96		if text, ok := block.AsAny().(anthropic.TextBlock); ok {
 97			content += text.Text
 98		}
 99	}
100
101	toolCalls := a.extractToolCalls(response.Content)
102	tokenUsage := a.extractTokenUsage(response.Usage)
103
104	return &ProviderResponse{
105		Content:   content,
106		ToolCalls: toolCalls,
107		Usage:     tokenUsage,
108	}, nil
109}
110
111func (a *anthropicProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
112	anthropicMessages := a.convertToAnthropicMessages(messages)
113	anthropicTools := a.convertToAnthropicTools(tools)
114
115	var thinkingParam anthropic.ThinkingConfigParamUnion
116	lastMessage := messages[len(messages)-1]
117	temperature := anthropic.Float(0)
118	if lastMessage.Role == message.User && strings.Contains(strings.ToLower(lastMessage.Content().String()), "think") {
119		thinkingParam = anthropic.ThinkingConfigParamUnion{
120			OfThinkingConfigEnabled: &anthropic.ThinkingConfigEnabledParam{
121				BudgetTokens: int64(float64(a.maxTokens) * 0.8),
122				Type:         "enabled",
123			},
124		}
125		temperature = anthropic.Float(1)
126	}
127
128	stream := a.client.Messages.NewStreaming(
129		ctx,
130		anthropic.MessageNewParams{
131			Model:       anthropic.Model(a.model.APIModel),
132			MaxTokens:   a.maxTokens,
133			Temperature: temperature,
134			Messages:    anthropicMessages,
135			Tools:       anthropicTools,
136			Thinking:    thinkingParam,
137			System: []anthropic.TextBlockParam{
138				{
139					Text: a.systemMessage,
140					CacheControl: anthropic.CacheControlEphemeralParam{
141						Type: "ephemeral",
142					},
143				},
144			},
145		},
146		option.WithMaxRetries(8),
147	)
148
149	eventChan := make(chan ProviderEvent)
150
151	go func() {
152		defer close(eventChan)
153
154		accumulatedMessage := anthropic.Message{}
155
156		for stream.Next() {
157			event := stream.Current()
158			err := accumulatedMessage.Accumulate(event)
159			if err != nil {
160				eventChan <- ProviderEvent{Type: EventError, Error: err}
161				return
162			}
163
164			switch event := event.AsAny().(type) {
165			case anthropic.ContentBlockStartEvent:
166				eventChan <- ProviderEvent{Type: EventContentStart}
167
168			case anthropic.ContentBlockDeltaEvent:
169				if event.Delta.Type == "thinking_delta" && event.Delta.Thinking != "" {
170					eventChan <- ProviderEvent{
171						Type:     EventThinkingDelta,
172						Thinking: event.Delta.Thinking,
173					}
174				} else if event.Delta.Type == "text_delta" && event.Delta.Text != "" {
175					eventChan <- ProviderEvent{
176						Type:    EventContentDelta,
177						Content: event.Delta.Text,
178					}
179				}
180
181			case anthropic.ContentBlockStopEvent:
182				eventChan <- ProviderEvent{Type: EventContentStop}
183
184			case anthropic.MessageStopEvent:
185				content := ""
186				for _, block := range accumulatedMessage.Content {
187					if text, ok := block.AsAny().(anthropic.TextBlock); ok {
188						content += text.Text
189					}
190				}
191
192				toolCalls := a.extractToolCalls(accumulatedMessage.Content)
193				tokenUsage := a.extractTokenUsage(accumulatedMessage.Usage)
194
195				eventChan <- ProviderEvent{
196					Type: EventComplete,
197					Response: &ProviderResponse{
198						Content:      content,
199						ToolCalls:    toolCalls,
200						Usage:        tokenUsage,
201						FinishReason: string(accumulatedMessage.StopReason),
202					},
203				}
204			}
205		}
206
207		if stream.Err() != nil {
208			eventChan <- ProviderEvent{Type: EventError, Error: stream.Err()}
209		}
210	}()
211
212	return eventChan, nil
213}
214
215func (a *anthropicProvider) extractToolCalls(content []anthropic.ContentBlockUnion) []message.ToolCall {
216	var toolCalls []message.ToolCall
217
218	for _, block := range content {
219		switch variant := block.AsAny().(type) {
220		case anthropic.ToolUseBlock:
221			toolCall := message.ToolCall{
222				ID:    variant.ID,
223				Name:  variant.Name,
224				Input: string(variant.Input),
225				Type:  string(variant.Type),
226			}
227			toolCalls = append(toolCalls, toolCall)
228		}
229	}
230
231	return toolCalls
232}
233
234func (a *anthropicProvider) extractTokenUsage(usage anthropic.Usage) TokenUsage {
235	return TokenUsage{
236		InputTokens:         usage.InputTokens,
237		OutputTokens:        usage.OutputTokens,
238		CacheCreationTokens: usage.CacheCreationInputTokens,
239		CacheReadTokens:     usage.CacheReadInputTokens,
240	}
241}
242
243func (a *anthropicProvider) convertToAnthropicTools(tools []tools.BaseTool) []anthropic.ToolUnionParam {
244	anthropicTools := make([]anthropic.ToolUnionParam, len(tools))
245
246	for i, tool := range tools {
247		info := tool.Info()
248		toolParam := anthropic.ToolParam{
249			Name:        info.Name,
250			Description: anthropic.String(info.Description),
251			InputSchema: anthropic.ToolInputSchemaParam{
252				Properties: info.Parameters,
253			},
254		}
255
256		if i == len(tools)-1 {
257			toolParam.CacheControl = anthropic.CacheControlEphemeralParam{
258				Type: "ephemeral",
259			}
260		}
261
262		anthropicTools[i] = anthropic.ToolUnionParam{OfTool: &toolParam}
263	}
264
265	return anthropicTools
266}
267
268func (a *anthropicProvider) convertToAnthropicMessages(messages []message.Message) []anthropic.MessageParam {
269	anthropicMessages := make([]anthropic.MessageParam, 0, len(messages))
270	cachedBlocks := 0
271
272	for _, msg := range messages {
273		switch msg.Role {
274		case message.User:
275			content := anthropic.NewTextBlock(msg.Content().String())
276			if cachedBlocks < 2 {
277				content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
278					Type: "ephemeral",
279				}
280				cachedBlocks++
281			}
282			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(content))
283
284		case message.Assistant:
285			blocks := []anthropic.ContentBlockParamUnion{}
286			if msg.Content().String() != "" {
287				content := anthropic.NewTextBlock(msg.Content().String())
288				if cachedBlocks < 2 {
289					content.OfRequestTextBlock.CacheControl = anthropic.CacheControlEphemeralParam{
290						Type: "ephemeral",
291					}
292					cachedBlocks++
293				}
294				blocks = append(blocks, content)
295			}
296
297			for _, toolCall := range msg.ToolCalls() {
298				var inputMap map[string]any
299				err := json.Unmarshal([]byte(toolCall.Input), &inputMap)
300				if err != nil {
301					continue
302				}
303				blocks = append(blocks, anthropic.ContentBlockParamOfRequestToolUseBlock(toolCall.ID, inputMap, toolCall.Name))
304			}
305
306			// Skip empty assistant messages completely
307			if len(blocks) > 0 {
308				anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
309			}
310
311		case message.Tool:
312			results := make([]anthropic.ContentBlockParamUnion, len(msg.ToolResults()))
313			for i, toolResult := range msg.ToolResults() {
314				results[i] = anthropic.NewToolResultBlock(toolResult.ToolCallID, toolResult.Content, toolResult.IsError)
315			}
316			anthropicMessages = append(anthropicMessages, anthropic.NewUserMessage(results...))
317		}
318	}
319
320	return anthropicMessages
321}