provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6	"maps"
  7	"os"
  8
  9	"github.com/charmbracelet/crush/internal/llm/models"
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12)
 13
// EventType identifies what a ProviderEvent carries; it is the type of
// ProviderEvent.Type.
type EventType string

// maxRetries is not referenced in this part of the file.
// NOTE(review): presumably the retry limit the provider clients apply to
// failed requests — confirm against the client implementations.
const maxRetries = 8

// Event types that can appear in ProviderEvent.Type on the channel returned
// by Provider.StreamResponse. Their exact emission order and which
// ProviderEvent fields accompany each one are decided by the individual
// clients, not by this file — verify there before relying on specifics.
const (
	EventContentStart  EventType = "content_start"
	EventToolUseStart  EventType = "tool_use_start"
	EventToolUseDelta  EventType = "tool_use_delta"
	EventToolUseStop   EventType = "tool_use_stop"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
	EventWarning       EventType = "warning"
)
 30
// TokenUsage holds the token counts attached to a provider response (see
// ProviderResponse.Usage). InputTokens and OutputTokens count prompt and
// completion tokens; CacheCreationTokens and CacheReadTokens presumably count
// tokens written to / served from a provider-side prompt cache — confirm
// against the clients that fill them in.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
 37
// ProviderResponse is the result of a provider call. It is returned directly
// by Provider.SendMessages and carried in ProviderEvent.Response for streamed
// responses. Content is the text produced by the model, ToolCalls the tool
// invocations it requested, Usage the token accounting, and FinishReason why
// generation stopped.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
 44
// ProviderEvent is one item delivered on the channel returned by
// Provider.StreamResponse. Type names the kind of event.
// NOTE(review): presumably only the fields relevant to Type are populated
// (e.g. Content for content deltas, Thinking for thinking deltas, ToolCall
// for tool-use events, Response for EventComplete, Error for EventError) —
// confirm against the clients that emit events.
type ProviderEvent struct {
	Type EventType

	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}
// Provider is the backend-agnostic interface for talking to an LLM.
// NewProvider returns implementations of it.
type Provider interface {
	// SendMessages sends the conversation together with the tools available
	// to the model and returns the complete response.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse takes the same inputs as SendMessages but delivers the
	// response as a sequence of ProviderEvent values on the returned channel.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model this provider was configured with.
	Model() models.Model
}
 61
// providerClientOptions is the configuration handed to every provider client
// constructor. NewProvider starts from the zero value, applies each
// ProviderClientOption in order, and may then adjust baseURL and extraHeaders
// for specific providers.
//
// Fields: baseURL is the API endpoint (WithBaseURL); apiKey the credential
// (WithAPIKey); model the selected model (WithModel); disableCache is set by
// WithDisableCache (presumably disables prompt caching in the client —
// confirm); maxTokens is set by WithMaxTokens; systemMessage by
// WithSystemMessage; extraHeaders are additional request headers, merged in
// by WithExtraHeaders.
type providerClientOptions struct {
	baseURL       string
	apiKey        string
	model         models.Model
	disableCache  bool
	maxTokens     int64
	systemMessage string
	extraHeaders  map[string]string
}

// ProviderClientOption mutates a providerClientOptions. Options are applied
// in the order they are passed to NewProvider.
type ProviderClientOption func(*providerClientOptions)
 73
// ProviderClient is the backend-specific implementation wrapped by
// baseProvider. send performs a request and returns the whole response;
// stream performs the same request and returns a channel of ProviderEvent
// values. AnthropicClient, OpenAIClient, GeminiClient, BedrockClient,
// AzureClient and VertexAIClient are used with it in NewProvider.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}

// baseProvider adapts a ProviderClient to the Provider interface. It keeps
// the options the client was built with (Model reads options.model) and drops
// messages that have no parts before delegating to the client.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
 83
 84func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
 85	clientOptions := providerClientOptions{}
 86	for _, o := range opts {
 87		o(&clientOptions)
 88	}
 89	switch providerName {
 90	case models.ProviderAnthropic:
 91		return &baseProvider[AnthropicClient]{
 92			options: clientOptions,
 93			client:  newAnthropicClient(clientOptions, false),
 94		}, nil
 95	case models.ProviderOpenAI:
 96		return &baseProvider[OpenAIClient]{
 97			options: clientOptions,
 98			client:  newOpenAIClient(clientOptions),
 99		}, nil
100	case models.ProviderGemini:
101		return &baseProvider[GeminiClient]{
102			options: clientOptions,
103			client:  newGeminiClient(clientOptions),
104		}, nil
105	case models.ProviderBedrock:
106		return &baseProvider[BedrockClient]{
107			options: clientOptions,
108			client:  newBedrockClient(clientOptions),
109		}, nil
110	case models.ProviderGROQ:
111		clientOptions.baseURL = "https://api.groq.com/openai/v1"
112		return &baseProvider[OpenAIClient]{
113			options: clientOptions,
114			client:  newOpenAIClient(clientOptions),
115		}, nil
116	case models.ProviderAzure:
117		return &baseProvider[AzureClient]{
118			options: clientOptions,
119			client:  newAzureClient(clientOptions),
120		}, nil
121	case models.ProviderVertexAI:
122		return &baseProvider[VertexAIClient]{
123			options: clientOptions,
124			client:  newVertexAIClient(clientOptions),
125		}, nil
126	case models.ProviderOpenRouter:
127		clientOptions.baseURL = "https://openrouter.ai/api/v1"
128		clientOptions.extraHeaders = map[string]string{
129			"HTTP-Referer": "crush.charm.land",
130			"X-Title":      "Crush",
131		}
132		return &baseProvider[OpenAIClient]{
133			options: clientOptions,
134			client:  newOpenAIClient(clientOptions),
135		}, nil
136	case models.ProviderXAI:
137		clientOptions.baseURL = "https://api.x.ai/v1"
138		return &baseProvider[OpenAIClient]{
139			options: clientOptions,
140			client:  newOpenAIClient(clientOptions),
141		}, nil
142	case models.ProviderLocal:
143		clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
144		return &baseProvider[OpenAIClient]{
145			options: clientOptions,
146			client:  newOpenAIClient(clientOptions),
147		}, nil
148	case models.ProviderMock:
149		// TODO: implement mock client for test
150		panic("not implemented")
151	}
152	return nil, fmt.Errorf("provider not supported: %s", providerName)
153}
154
155func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
156	for _, msg := range messages {
157		// The message has no content
158		if len(msg.Parts) == 0 {
159			continue
160		}
161		cleaned = append(cleaned, msg)
162	}
163	return
164}
165
166func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
167	messages = p.cleanMessages(messages)
168	return p.client.send(ctx, messages, tools)
169}
170
171func (p *baseProvider[C]) Model() models.Model {
172	return p.options.model
173}
174
175func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
176	messages = p.cleanMessages(messages)
177	return p.client.stream(ctx, messages, tools)
178}
179
180func WithBaseURL(baseURL string) ProviderClientOption {
181	return func(options *providerClientOptions) {
182		options.baseURL = baseURL
183	}
184}
185
186func WithAPIKey(apiKey string) ProviderClientOption {
187	return func(options *providerClientOptions) {
188		options.apiKey = apiKey
189	}
190}
191
192func WithModel(model models.Model) ProviderClientOption {
193	return func(options *providerClientOptions) {
194		options.model = model
195	}
196}
197
198func WithDisableCache(disableCache bool) ProviderClientOption {
199	return func(options *providerClientOptions) {
200		options.disableCache = disableCache
201	}
202}
203
204func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption {
205	return func(options *providerClientOptions) {
206		if options.extraHeaders == nil {
207			options.extraHeaders = make(map[string]string)
208		}
209		maps.Copy(options.extraHeaders, extraHeaders)
210	}
211}
212
213func WithMaxTokens(maxTokens int64) ProviderClientOption {
214	return func(options *providerClientOptions) {
215		options.maxTokens = maxTokens
216	}
217}
218
219func WithSystemMessage(systemMessage string) ProviderClientOption {
220	return func(options *providerClientOptions) {
221		options.systemMessage = systemMessage
222	}
223}