provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/catwalk/pkg/catwalk"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12)
 13
 14type EventType string
 15
 16const maxRetries = 8
 17
 18const (
 19	EventContentStart   EventType = "content_start"
 20	EventToolUseStart   EventType = "tool_use_start"
 21	EventToolUseDelta   EventType = "tool_use_delta"
 22	EventToolUseStop    EventType = "tool_use_stop"
 23	EventContentDelta   EventType = "content_delta"
 24	EventThinkingDelta  EventType = "thinking_delta"
 25	EventSignatureDelta EventType = "signature_delta"
 26	EventContentStop    EventType = "content_stop"
 27	EventComplete       EventType = "complete"
 28	EventError          EventType = "error"
 29	EventWarning        EventType = "warning"
 30)
 31
 32type TokenUsage struct {
 33	InputTokens         int64
 34	OutputTokens        int64
 35	CacheCreationTokens int64
 36	CacheReadTokens     int64
 37}
 38
 39type ProviderResponse struct {
 40	Content      string
 41	ToolCalls    []message.ToolCall
 42	Usage        TokenUsage
 43	FinishReason message.FinishReason
 44}
 45
 46type ProviderEvent struct {
 47	Type EventType
 48
 49	Content   string
 50	Thinking  string
 51	Signature string
 52	Response  *ProviderResponse
 53	ToolCall  *message.ToolCall
 54	Error     error
 55}
 56type Provider interface {
 57	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 58
 59	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 60
 61	Model() catwalk.Model
 62}
 63
 64type providerClientOptions struct {
 65	cfg *config.Config
 66
 67	baseURL            string
 68	config             config.ProviderConfig
 69	apiKey             string
 70	modelType          config.SelectedModelType
 71	model              func(config.SelectedModelType) catwalk.Model
 72	disableCache       bool
 73	systemMessage      string
 74	systemPromptPrefix string
 75	maxTokens          int64
 76	extraHeaders       map[string]string
 77	extraBody          map[string]any
 78	extraParams        map[string]string
 79}
 80
 81type ProviderClientOption func(*providerClientOptions)
 82
 83type ProviderClient interface {
 84	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 85	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 86
 87	Model() catwalk.Model
 88}
 89
 90type baseProvider[C ProviderClient] struct {
 91	options providerClientOptions
 92	client  C
 93}
 94
 95func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 96	for _, msg := range messages {
 97		// The message has no content
 98		if len(msg.Parts) == 0 {
 99			continue
100		}
101		cleaned = append(cleaned, msg)
102	}
103	return cleaned
104}
105
106func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
107	messages = p.cleanMessages(messages)
108	return p.client.send(ctx, messages, tools)
109}
110
111func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
112	messages = p.cleanMessages(messages)
113	return p.client.stream(ctx, messages, tools)
114}
115
116func (p *baseProvider[C]) Model() catwalk.Model {
117	return p.client.Model()
118}
119
120func WithModel(model config.SelectedModelType) ProviderClientOption {
121	return func(options *providerClientOptions) {
122		options.modelType = model
123	}
124}
125
126func WithDisableCache(disableCache bool) ProviderClientOption {
127	return func(options *providerClientOptions) {
128		options.disableCache = disableCache
129	}
130}
131
132func WithSystemMessage(systemMessage string) ProviderClientOption {
133	return func(options *providerClientOptions) {
134		options.systemMessage = systemMessage
135	}
136}
137
138func WithMaxTokens(maxTokens int64) ProviderClientOption {
139	return func(options *providerClientOptions) {
140		options.maxTokens = maxTokens
141	}
142}
143
144func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
145	restore := config.PushPopCrushEnv()
146	defer restore()
147	resolvedAPIKey, err := cfg.Resolve(pcfg.APIKey)
148	if err != nil {
149		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
150	}
151
152	// Resolve extra headers
153	resolvedExtraHeaders := make(map[string]string)
154	for key, value := range pcfg.ExtraHeaders {
155		resolvedValue, err := cfg.Resolve(value)
156		if err != nil {
157			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
158		}
159		resolvedExtraHeaders[key] = resolvedValue
160	}
161
162	clientOptions := providerClientOptions{
163		cfg:                cfg,
164		baseURL:            pcfg.BaseURL,
165		config:             pcfg,
166		apiKey:             resolvedAPIKey,
167		extraHeaders:       resolvedExtraHeaders,
168		extraBody:          pcfg.ExtraBody,
169		extraParams:        pcfg.ExtraParams,
170		systemPromptPrefix: pcfg.SystemPromptPrefix,
171		model: func(tp config.SelectedModelType) catwalk.Model {
172			return *cfg.GetModelByType(tp)
173		},
174	}
175	for _, o := range opts {
176		o(&clientOptions)
177	}
178	switch pcfg.Type {
179	case catwalk.TypeAnthropic:
180		return &baseProvider[AnthropicClient]{
181			options: clientOptions,
182			client:  newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
183		}, nil
184	case catwalk.TypeOpenAI:
185		return &baseProvider[OpenAIClient]{
186			options: clientOptions,
187			client:  newOpenAIClient(clientOptions),
188		}, nil
189	case catwalk.TypeGemini:
190		return &baseProvider[GeminiClient]{
191			options: clientOptions,
192			client:  newGeminiClient(clientOptions),
193		}, nil
194	case catwalk.TypeBedrock:
195		return &baseProvider[BedrockClient]{
196			options: clientOptions,
197			client:  newBedrockClient(clientOptions),
198		}, nil
199	case catwalk.TypeAzure:
200		return &baseProvider[AzureClient]{
201			options: clientOptions,
202			client:  newAzureClient(clientOptions),
203		}, nil
204	case catwalk.TypeVertexAI:
205		return &baseProvider[VertexAIClient]{
206			options: clientOptions,
207			client:  newVertexAIClient(clientOptions),
208		}, nil
209	}
210	return nil, fmt.Errorf("provider not supported: %s", pcfg.Type)
211}