provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/catwalk/pkg/catwalk"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12)
 13
 14type EventType string
 15
 16const maxRetries = 3
 17
 18const (
 19	EventContentStart   EventType = "content_start"
 20	EventToolUseStart   EventType = "tool_use_start"
 21	EventToolUseDelta   EventType = "tool_use_delta"
 22	EventToolUseStop    EventType = "tool_use_stop"
 23	EventContentDelta   EventType = "content_delta"
 24	EventThinkingDelta  EventType = "thinking_delta"
 25	EventSignatureDelta EventType = "signature_delta"
 26	EventContentStop    EventType = "content_stop"
 27	EventComplete       EventType = "complete"
 28	EventError          EventType = "error"
 29	EventWarning        EventType = "warning"
 30)
 31
 32type TokenUsage struct {
 33	InputTokens         int64
 34	OutputTokens        int64
 35	CacheCreationTokens int64
 36	CacheReadTokens     int64
 37}
 38
 39type ProviderResponse struct {
 40	Content      string
 41	ToolCalls    []message.ToolCall
 42	Usage        TokenUsage
 43	FinishReason message.FinishReason
 44}
 45
 46type ProviderEvent struct {
 47	Type EventType
 48
 49	Content   string
 50	Thinking  string
 51	Signature string
 52	Response  *ProviderResponse
 53	ToolCall  *message.ToolCall
 54	Error     error
 55}
 56type Provider interface {
 57	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 58
 59	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 60
 61	Model() catwalk.Model
 62}
 63
 64type providerClientOptions struct {
 65	cfg *config.Config
 66
 67	resolver config.VariableResolver
 68
 69	baseURL            string
 70	config             config.ProviderConfig
 71	apiKey             string
 72	modelType          config.SelectedModelType
 73	model              func(config.SelectedModelType) catwalk.Model
 74	disableCache       bool
 75	systemMessage      string
 76	systemPromptPrefix string
 77	maxTokens          int64
 78	extraHeaders       map[string]string
 79	extraBody          map[string]any
 80	extraParams        map[string]string
 81}
 82
 83type ProviderClientOption func(*providerClientOptions)
 84
 85type ProviderClient interface {
 86	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 87	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 88
 89	Model() catwalk.Model
 90}
 91
 92type baseProvider[C ProviderClient] struct {
 93	options providerClientOptions
 94	client  C
 95}
 96
 97func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 98	for _, msg := range messages {
 99		// The message has no content
100		if len(msg.Parts) == 0 {
101			continue
102		}
103		cleaned = append(cleaned, msg)
104	}
105	return cleaned
106}
107
108func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
109	messages = p.cleanMessages(messages)
110	return p.client.send(ctx, messages, tools)
111}
112
113func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
114	messages = p.cleanMessages(messages)
115	return p.client.stream(ctx, messages, tools)
116}
117
118func (p *baseProvider[C]) Model() catwalk.Model {
119	return p.client.Model()
120}
121
122func WithModel(model config.SelectedModelType) ProviderClientOption {
123	return func(options *providerClientOptions) {
124		options.modelType = model
125	}
126}
127
128func WithDisableCache(disableCache bool) ProviderClientOption {
129	return func(options *providerClientOptions) {
130		options.disableCache = disableCache
131	}
132}
133
134func WithSystemMessage(systemMessage string) ProviderClientOption {
135	return func(options *providerClientOptions) {
136		options.systemMessage = systemMessage
137	}
138}
139
140func WithMaxTokens(maxTokens int64) ProviderClientOption {
141	return func(options *providerClientOptions) {
142		options.maxTokens = maxTokens
143	}
144}
145
146func WithResolver(resolver config.VariableResolver) ProviderClientOption {
147	return func(options *providerClientOptions) {
148		options.resolver = resolver
149	}
150}
151
152func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
153	clientOptions := providerClientOptions{
154		cfg:                cfg,
155		baseURL:            pcfg.BaseURL,
156		config:             pcfg,
157		extraBody:          pcfg.ExtraBody,
158		extraParams:        pcfg.ExtraParams,
159		systemPromptPrefix: pcfg.SystemPromptPrefix,
160		model: func(tp config.SelectedModelType) catwalk.Model {
161			return *cfg.GetModelByType(tp)
162		},
163	}
164	for _, o := range opts {
165		o(&clientOptions)
166	}
167	if clientOptions.resolver == nil {
168		clientOptions.resolver = config.OsShellResolver
169	}
170
171	restore := config.PushPopCrushEnv()
172	defer restore()
173	resolvedAPIKey, err := clientOptions.resolver.ResolveValue(pcfg.APIKey)
174	if err != nil {
175		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
176	}
177
178	// Resolve extra headers
179	resolvedExtraHeaders := make(map[string]string)
180	for key, value := range pcfg.ExtraHeaders {
181		resolvedValue, err := clientOptions.resolver.ResolveValue(value)
182		if err != nil {
183			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
184		}
185		resolvedExtraHeaders[key] = resolvedValue
186	}
187
188	clientOptions.apiKey = resolvedAPIKey
189	clientOptions.extraHeaders = resolvedExtraHeaders
190
191	switch pcfg.Type {
192	case catwalk.TypeAnthropic:
193		return &baseProvider[AnthropicClient]{
194			options: clientOptions,
195			client:  newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
196		}, nil
197	case catwalk.TypeOpenAI:
198		return &baseProvider[OpenAIClient]{
199			options: clientOptions,
200			client:  newOpenAIClient(clientOptions),
201		}, nil
202	case catwalk.TypeGemini:
203		return &baseProvider[GeminiClient]{
204			options: clientOptions,
205			client:  newGeminiClient(clientOptions),
206		}, nil
207	case catwalk.TypeBedrock:
208		return &baseProvider[BedrockClient]{
209			options: clientOptions,
210			client:  newBedrockClient(clientOptions),
211		}, nil
212	case catwalk.TypeAzure:
213		return &baseProvider[AzureClient]{
214			options: clientOptions,
215			client:  newAzureClient(clientOptions),
216		}, nil
217	case catwalk.TypeVertexAI:
218		return &baseProvider[VertexAIClient]{
219			options: clientOptions,
220			client:  newVertexAIClient(clientOptions),
221		}, nil
222	}
223	return nil, fmt.Errorf("provider not supported: %s", pcfg.Type)
224}