provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6	"os"
  7
  8	"github.com/charmbracelet/crush/internal/llm/models"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
 13type EventType string
 14
 15const maxRetries = 8
 16
 17const (
 18	EventContentStart  EventType = "content_start"
 19	EventToolUseStart  EventType = "tool_use_start"
 20	EventToolUseDelta  EventType = "tool_use_delta"
 21	EventToolUseStop   EventType = "tool_use_stop"
 22	EventContentDelta  EventType = "content_delta"
 23	EventThinkingDelta EventType = "thinking_delta"
 24	EventContentStop   EventType = "content_stop"
 25	EventComplete      EventType = "complete"
 26	EventError         EventType = "error"
 27	EventWarning       EventType = "warning"
 28)
 29
 30type TokenUsage struct {
 31	InputTokens         int64
 32	OutputTokens        int64
 33	CacheCreationTokens int64
 34	CacheReadTokens     int64
 35}
 36
 37type ProviderResponse struct {
 38	Content      string
 39	ToolCalls    []message.ToolCall
 40	Usage        TokenUsage
 41	FinishReason message.FinishReason
 42}
 43
 44type ProviderEvent struct {
 45	Type EventType
 46
 47	Content  string
 48	Thinking string
 49	Response *ProviderResponse
 50	ToolCall *message.ToolCall
 51	Error    error
 52}
 53type Provider interface {
 54	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 55
 56	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 57
 58	Model() models.Model
 59}
 60
 61type providerClientOptions struct {
 62	apiKey        string
 63	model         models.Model
 64	maxTokens     int64
 65	systemMessage string
 66
 67	anthropicOptions []AnthropicOption
 68	openaiOptions    []OpenAIOption
 69	geminiOptions    []GeminiOption
 70	bedrockOptions   []BedrockOption
 71}
 72
 73type ProviderClientOption func(*providerClientOptions)
 74
 75type ProviderClient interface {
 76	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 77	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 78}
 79
 80type baseProvider[C ProviderClient] struct {
 81	options providerClientOptions
 82	client  C
 83}
 84
 85func NewProvider(providerName models.ModelProvider, opts ...ProviderClientOption) (Provider, error) {
 86	clientOptions := providerClientOptions{}
 87	for _, o := range opts {
 88		o(&clientOptions)
 89	}
 90	switch providerName {
 91	case models.ProviderAnthropic:
 92		return &baseProvider[AnthropicClient]{
 93			options: clientOptions,
 94			client:  newAnthropicClient(clientOptions),
 95		}, nil
 96	case models.ProviderOpenAI:
 97		return &baseProvider[OpenAIClient]{
 98			options: clientOptions,
 99			client:  newOpenAIClient(clientOptions),
100		}, nil
101	case models.ProviderGemini:
102		return &baseProvider[GeminiClient]{
103			options: clientOptions,
104			client:  newGeminiClient(clientOptions),
105		}, nil
106	case models.ProviderBedrock:
107		return &baseProvider[BedrockClient]{
108			options: clientOptions,
109			client:  newBedrockClient(clientOptions),
110		}, nil
111	case models.ProviderGROQ:
112		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
113			WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
114		)
115		return &baseProvider[OpenAIClient]{
116			options: clientOptions,
117			client:  newOpenAIClient(clientOptions),
118		}, nil
119	case models.ProviderAzure:
120		return &baseProvider[AzureClient]{
121			options: clientOptions,
122			client:  newAzureClient(clientOptions),
123		}, nil
124	case models.ProviderVertexAI:
125		return &baseProvider[VertexAIClient]{
126			options: clientOptions,
127			client:  newVertexAIClient(clientOptions),
128		}, nil
129	case models.ProviderOpenRouter:
130		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
131			WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
132			WithOpenAIExtraHeaders(map[string]string{
133				"HTTP-Referer": "crush.charm.land",
134				"X-Title":      "Crush",
135			}),
136		)
137		return &baseProvider[OpenAIClient]{
138			options: clientOptions,
139			client:  newOpenAIClient(clientOptions),
140		}, nil
141	case models.ProviderXAI:
142		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
143			WithOpenAIBaseURL("https://api.x.ai/v1"),
144		)
145		return &baseProvider[OpenAIClient]{
146			options: clientOptions,
147			client:  newOpenAIClient(clientOptions),
148		}, nil
149	case models.ProviderLocal:
150		clientOptions.openaiOptions = append(clientOptions.openaiOptions,
151			WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
152		)
153		return &baseProvider[OpenAIClient]{
154			options: clientOptions,
155			client:  newOpenAIClient(clientOptions),
156		}, nil
157	case models.ProviderMock:
158		// TODO: implement mock client for test
159		panic("not implemented")
160	}
161	return nil, fmt.Errorf("provider not supported: %s", providerName)
162}
163
164func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
165	for _, msg := range messages {
166		// The message has no content
167		if len(msg.Parts) == 0 {
168			continue
169		}
170		cleaned = append(cleaned, msg)
171	}
172	return
173}
174
175func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
176	messages = p.cleanMessages(messages)
177	return p.client.send(ctx, messages, tools)
178}
179
180func (p *baseProvider[C]) Model() models.Model {
181	return p.options.model
182}
183
184func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
185	messages = p.cleanMessages(messages)
186	return p.client.stream(ctx, messages, tools)
187}
188
189func WithAPIKey(apiKey string) ProviderClientOption {
190	return func(options *providerClientOptions) {
191		options.apiKey = apiKey
192	}
193}
194
195func WithModel(model models.Model) ProviderClientOption {
196	return func(options *providerClientOptions) {
197		options.model = model
198	}
199}
200
201func WithMaxTokens(maxTokens int64) ProviderClientOption {
202	return func(options *providerClientOptions) {
203		options.maxTokens = maxTokens
204	}
205}
206
207func WithSystemMessage(systemMessage string) ProviderClientOption {
208	return func(options *providerClientOptions) {
209		options.systemMessage = systemMessage
210	}
211}
212
213func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
214	return func(options *providerClientOptions) {
215		options.anthropicOptions = anthropicOptions
216	}
217}
218
219func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
220	return func(options *providerClientOptions) {
221		options.openaiOptions = openaiOptions
222	}
223}
224
225func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
226	return func(options *providerClientOptions) {
227		options.geminiOptions = geminiOptions
228	}
229}
230
231func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
232	return func(options *providerClientOptions) {
233		options.bedrockOptions = bedrockOptions
234	}
235}