provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
 13type EventType string
 14
 15const maxRetries = 8
 16
 17const (
 18	EventContentStart  EventType = "content_start"
 19	EventToolUseStart  EventType = "tool_use_start"
 20	EventToolUseDelta  EventType = "tool_use_delta"
 21	EventToolUseStop   EventType = "tool_use_stop"
 22	EventContentDelta  EventType = "content_delta"
 23	EventThinkingDelta EventType = "thinking_delta"
 24	EventContentStop   EventType = "content_stop"
 25	EventComplete      EventType = "complete"
 26	EventError         EventType = "error"
 27	EventWarning       EventType = "warning"
 28)
 29
 30type TokenUsage struct {
 31	InputTokens         int64
 32	OutputTokens        int64
 33	CacheCreationTokens int64
 34	CacheReadTokens     int64
 35}
 36
// ProviderResponse is the value returned by Provider.SendMessages and the
// payload referenced by ProviderEvent.Response.
//
// Fields:
//   - Content: response text.
//   - ToolCalls: tool calls attached to the response.
//   - Usage: token counters for the response.
//   - FinishReason: a message.FinishReason value.
//     NOTE(review): presumably why generation stopped — confirm in the
//     message package.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
 43
// ProviderEvent is the element type of the channel returned by
// Provider.StreamResponse. Type identifies the event; the remaining fields
// are optional payloads (text, thinking text, a response, a tool call, an
// error).
//
// NOTE(review): which payload fields are populated for a given Type is
// decided by the client implementations, which are not shown in this file.
type ProviderEvent struct {
	Type EventType

	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}
// Provider is the public interface returned by NewProviderV2. In this file it
// is implemented by baseProvider.
type Provider interface {
	// SendMessages submits messages and tools and returns a single
	// ProviderResponse or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse submits messages and tools and returns a receive-only
	// channel of ProviderEvent values.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the config.Model associated with this provider.
	Model() config.Model
}
 60
// providerClientOptions is the settings bundle handed to every client
// constructor by NewProviderV2.
//
// Within this file:
//   - baseURL, apiKey and extraHeaders are copied from config.ProviderConfig
//     (baseURL is overwritten for the xAI provider type).
//   - model is set to a function that wraps config.GetModel.
//   - modelType, disableCache and systemMessage are set only through the
//     WithModel, WithDisableCache and WithSystemMessage options.
//   - extraParams is never assigned here.
//     NOTE(review): confirm whether it is populated elsewhere.
type providerClientOptions struct {
	baseURL       string
	apiKey        string
	modelType     config.ModelType
	model         func(config.ModelType) config.Model
	disableCache  bool
	systemMessage string
	extraHeaders  map[string]string
	extraParams   map[string]string
}
 71
// ProviderClientOption mutates a providerClientOptions value. NewProviderV2
// applies the options it is given in argument order after seeding the options
// from the provider config.
type ProviderClientOption func(*providerClientOptions)
 73
// ProviderClient is the constraint satisfied by the concrete clients that
// baseProvider wraps. Its unexported methods mirror the exported Provider
// methods that baseProvider forwards to them.
type ProviderClient interface {
	// send is called by baseProvider.SendMessages with the filtered messages.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream is called by baseProvider.StreamResponse with the filtered messages.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model is called by baseProvider.Model.
	Model() config.Model
}
 80
// baseProvider adapts a concrete ProviderClient of type C to the Provider
// interface: its methods filter the incoming messages with cleanMessages and
// then delegate to the client.
//
// options holds the settings the client was constructed with; none of the
// methods in this file read it. client is the wrapped implementation.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
 85
 86func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 87	for _, msg := range messages {
 88		// The message has no content
 89		if len(msg.Parts) == 0 {
 90			continue
 91		}
 92		cleaned = append(cleaned, msg)
 93	}
 94	return
 95}
 96
 97func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 98	messages = p.cleanMessages(messages)
 99	return p.client.send(ctx, messages, tools)
100}
101
102func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
103	messages = p.cleanMessages(messages)
104	return p.client.stream(ctx, messages, tools)
105}
106
// Model returns the config.Model reported by the wrapped client.
func (p *baseProvider[C]) Model() config.Model {
	return p.client.Model()
}
110
111func WithModel(model config.ModelType) ProviderClientOption {
112	return func(options *providerClientOptions) {
113		options.modelType = model
114	}
115}
116
117func WithDisableCache(disableCache bool) ProviderClientOption {
118	return func(options *providerClientOptions) {
119		options.disableCache = disableCache
120	}
121}
122
123func WithSystemMessage(systemMessage string) ProviderClientOption {
124	return func(options *providerClientOptions) {
125		options.systemMessage = systemMessage
126	}
127}
128
129func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
130	clientOptions := providerClientOptions{
131		baseURL:      cfg.BaseURL,
132		apiKey:       cfg.APIKey,
133		extraHeaders: cfg.ExtraHeaders,
134		model: func(tp config.ModelType) config.Model {
135			return config.GetModel(tp)
136		},
137	}
138	for _, o := range opts {
139		o(&clientOptions)
140	}
141	switch cfg.ProviderType {
142	case provider.TypeAnthropic:
143		return &baseProvider[AnthropicClient]{
144			options: clientOptions,
145			client:  newAnthropicClient(clientOptions, false),
146		}, nil
147	case provider.TypeOpenAI:
148		return &baseProvider[OpenAIClient]{
149			options: clientOptions,
150			client:  newOpenAIClient(clientOptions),
151		}, nil
152	case provider.TypeGemini:
153		return &baseProvider[GeminiClient]{
154			options: clientOptions,
155			client:  newGeminiClient(clientOptions),
156		}, nil
157	case provider.TypeBedrock:
158		return &baseProvider[BedrockClient]{
159			options: clientOptions,
160			client:  newBedrockClient(clientOptions),
161		}, nil
162	case provider.TypeAzure:
163		return &baseProvider[AzureClient]{
164			options: clientOptions,
165			client:  newAzureClient(clientOptions),
166		}, nil
167	case provider.TypeVertexAI:
168		return &baseProvider[VertexAIClient]{
169			options: clientOptions,
170			client:  newVertexAIClient(clientOptions),
171		}, nil
172	case provider.TypeXAI:
173		clientOptions.baseURL = "https://api.x.ai/v1"
174		return &baseProvider[OpenAIClient]{
175			options: clientOptions,
176			client:  newOpenAIClient(clientOptions),
177		}, nil
178	}
179	return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
180}