provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
// EventType is the string-typed discriminator stored in ProviderEvent.Type.
type EventType string
 14
// maxRetries is a retry limit. It is not referenced in this part of the file;
// presumably the concrete provider clients consume it — TODO confirm.
const maxRetries = 8
 16
// Known EventType values. Each constant's string value is its wire/identifier
// form. Which client emits which event, and in what order, is decided by the
// stream implementations and is not visible here.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
 30
// TokenUsage holds the token counters attached to a ProviderResponse.
type TokenUsage struct {
	InputTokens  int64
	OutputTokens int64
	// NOTE(review): the two cache counters presumably reflect provider-side
	// prompt caching (tokens written to / read from the cache) — confirm
	// against the client implementations.
	CacheCreationTokens int64
	CacheReadTokens     int64
}
 37
// ProviderResponse is the result type returned by Provider.SendMessages and
// referenced from ProviderEvent.Response.
type ProviderResponse struct {
	// Content is the textual content of the response.
	Content string
	// ToolCalls lists the tool calls carried by the response.
	ToolCalls []message.ToolCall
	// Usage holds the token counters for the response.
	Usage TokenUsage
	// FinishReason is the finish reason recorded for the response.
	FinishReason message.FinishReason
}
 44
// ProviderEvent is the element type of the channel returned by
// Provider.StreamResponse. Type says which kind of event this is.
//
// NOTE(review): which of the remaining fields is populated for a given Type is
// decided by the client implementations and is not visible here; the field
// names suggest Content/Thinking/Signature carry delta text, Response the
// final result, ToolCall tool-use data and Error the failure — confirm.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the interface returned by NewProvider. In this file it is
// implemented by baseProvider, which delegates to a ProviderClient.
type Provider interface {
	// SendMessages sends messages (with the given tools available) and returns
	// a single response or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends messages (with the given tools available) and
	// returns a receive-only channel of ProviderEvent values.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model associated with this provider.
	Model() provider.Model
}
 62
// providerClientOptions is the settings bundle handed to the concrete client
// constructors. NewProvider fills it from a config.ProviderConfig and then
// applies any ProviderClientOption values on top.
type providerClientOptions struct {
	// baseURL is initialized from the provider config's BaseURL.
	baseURL string
	// config is the provider config the options were built from.
	config config.ProviderConfig
	// apiKey is the provider config's APIKey after resolution.
	apiKey string
	// modelType is set via WithModel.
	modelType config.SelectedModelType
	// model maps a selected model type to a provider.Model.
	model func(config.SelectedModelType) provider.Model
	// disableCache is set via WithDisableCache.
	disableCache bool
	// systemMessage is set via WithSystemMessage.
	systemMessage string
	// maxTokens is set via WithMaxTokens.
	maxTokens int64
	// extraHeaders is initialized from the provider config's ExtraHeaders.
	extraHeaders map[string]string
	// extraParams is not assigned anywhere in this part of the file.
	extraParams map[string]string
}
 75
// ProviderClientOption mutates a providerClientOptions value in place. Values
// of this type are passed to NewProvider and applied in argument order.
type ProviderClientOption func(*providerClientOptions)
 77
// ProviderClient is the interface that backend clients satisfy so they can be
// wrapped by baseProvider. send and stream are unexported, so implementations
// must live in this package.
type ProviderClient interface {
	// send is called by baseProvider.SendMessages.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream is called by baseProvider.StreamResponse.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model the client is using.
	Model() provider.Model
}
 84
// baseProvider adapts a ProviderClient of type C to the Provider interface.
type baseProvider[C ProviderClient] struct {
	// options is the option set the provider was constructed with.
	options providerClientOptions
	// client is the backend that send/stream/Model calls are delegated to.
	client C
}
 89
 90func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 91	for _, msg := range messages {
 92		// The message has no content
 93		if len(msg.Parts) == 0 {
 94			continue
 95		}
 96		cleaned = append(cleaned, msg)
 97	}
 98	return
 99}
100
101func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
102	messages = p.cleanMessages(messages)
103	return p.client.send(ctx, messages, tools)
104}
105
106func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
107	messages = p.cleanMessages(messages)
108	return p.client.stream(ctx, messages, tools)
109}
110
// Model returns the model reported by the underlying client.
func (p *baseProvider[C]) Model() provider.Model {
	return p.client.Model()
}
114
115func WithModel(model config.SelectedModelType) ProviderClientOption {
116	return func(options *providerClientOptions) {
117		options.modelType = model
118	}
119}
120
121func WithDisableCache(disableCache bool) ProviderClientOption {
122	return func(options *providerClientOptions) {
123		options.disableCache = disableCache
124	}
125}
126
127func WithSystemMessage(systemMessage string) ProviderClientOption {
128	return func(options *providerClientOptions) {
129		options.systemMessage = systemMessage
130	}
131}
132
133func WithMaxTokens(maxTokens int64) ProviderClientOption {
134	return func(options *providerClientOptions) {
135		options.maxTokens = maxTokens
136	}
137}
138
139func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
140	resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
141	if err != nil {
142		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
143	}
144
145	clientOptions := providerClientOptions{
146		baseURL:      cfg.BaseURL,
147		config:       cfg,
148		apiKey:       resolvedAPIKey,
149		extraHeaders: cfg.ExtraHeaders,
150		model: func(tp config.SelectedModelType) provider.Model {
151			return *config.Get().GetModelByType(tp)
152		},
153	}
154	for _, o := range opts {
155		o(&clientOptions)
156	}
157	switch cfg.Type {
158	case provider.TypeAnthropic:
159		return &baseProvider[AnthropicClient]{
160			options: clientOptions,
161			client:  newAnthropicClient(clientOptions, false),
162		}, nil
163	case provider.TypeOpenAI:
164		return &baseProvider[OpenAIClient]{
165			options: clientOptions,
166			client:  newOpenAIClient(clientOptions),
167		}, nil
168	case provider.TypeGemini:
169		return &baseProvider[GeminiClient]{
170			options: clientOptions,
171			client:  newGeminiClient(clientOptions),
172		}, nil
173	case provider.TypeBedrock:
174		return &baseProvider[BedrockClient]{
175			options: clientOptions,
176			client:  newBedrockClient(clientOptions),
177		}, nil
178	case provider.TypeAzure:
179		return &baseProvider[AzureClient]{
180			options: clientOptions,
181			client:  newAzureClient(clientOptions),
182		}, nil
183	case provider.TypeVertexAI:
184		return &baseProvider[VertexAIClient]{
185			options: clientOptions,
186			client:  newVertexAIClient(clientOptions),
187		}, nil
188	case provider.TypeXAI:
189		clientOptions.baseURL = "https://api.x.ai/v1"
190		return &baseProvider[OpenAIClient]{
191			options: clientOptions,
192			client:  newOpenAIClient(clientOptions),
193		}, nil
194	}
195	return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
196}