provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/catwalk/pkg/catwalk"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11	"github.com/charmbracelet/crush/internal/message"
 12)
 13
 14type EventType string
 15
 16const maxRetries = 8
 17
 18const (
 19	EventContentStart   EventType = "content_start"
 20	EventToolUseStart   EventType = "tool_use_start"
 21	EventToolUseDelta   EventType = "tool_use_delta"
 22	EventToolUseStop    EventType = "tool_use_stop"
 23	EventContentDelta   EventType = "content_delta"
 24	EventThinkingDelta  EventType = "thinking_delta"
 25	EventSignatureDelta EventType = "signature_delta"
 26	EventContentStop    EventType = "content_stop"
 27	EventComplete       EventType = "complete"
 28	EventError          EventType = "error"
 29	EventWarning        EventType = "warning"
 30)
 31
 32type TokenUsage struct {
 33	InputTokens         int64
 34	OutputTokens        int64
 35	CacheCreationTokens int64
 36	CacheReadTokens     int64
 37}
 38
 39type ProviderResponse struct {
 40	Content      string
 41	ToolCalls    []message.ToolCall
 42	Usage        TokenUsage
 43	FinishReason message.FinishReason
 44}
 45
 46type ProviderEvent struct {
 47	Type EventType
 48
 49	Content   string
 50	Thinking  string
 51	Signature string
 52	Response  *ProviderResponse
 53	ToolCall  *message.ToolCall
 54	Error     error
 55}
 56type Provider interface {
 57	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 58
 59	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 60
 61	Model() catwalk.Model
 62}
 63
 64type providerClientOptions struct {
 65	baseURL            string
 66	config             config.ProviderConfig
 67	apiKey             string
 68	modelType          config.SelectedModelType
 69	model              func(config.SelectedModelType) catwalk.Model
 70	disableCache       bool
 71	systemMessage      string
 72	systemPromptPrefix string
 73	maxTokens          int64
 74	extraHeaders       map[string]string
 75	extraBody          map[string]any
 76	extraParams        map[string]string
 77}
 78
 79type ProviderClientOption func(*providerClientOptions)
 80
 81type ProviderClient interface {
 82	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 83	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 84
 85	Model() catwalk.Model
 86}
 87
 88type baseProvider[C ProviderClient] struct {
 89	options providerClientOptions
 90	client  C
 91}
 92
 93func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 94	for _, msg := range messages {
 95		// The message has no content
 96		if len(msg.Parts) == 0 {
 97			continue
 98		}
 99		cleaned = append(cleaned, msg)
100	}
101	return cleaned
102}
103
104func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
105	messages = p.cleanMessages(messages)
106	return p.client.send(ctx, messages, tools)
107}
108
109func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
110	messages = p.cleanMessages(messages)
111	return p.client.stream(ctx, messages, tools)
112}
113
114func (p *baseProvider[C]) Model() catwalk.Model {
115	return p.client.Model()
116}
117
118func WithModel(model config.SelectedModelType) ProviderClientOption {
119	return func(options *providerClientOptions) {
120		options.modelType = model
121	}
122}
123
124func WithDisableCache(disableCache bool) ProviderClientOption {
125	return func(options *providerClientOptions) {
126		options.disableCache = disableCache
127	}
128}
129
130func WithSystemMessage(systemMessage string) ProviderClientOption {
131	return func(options *providerClientOptions) {
132		options.systemMessage = systemMessage
133	}
134}
135
136func WithMaxTokens(maxTokens int64) ProviderClientOption {
137	return func(options *providerClientOptions) {
138		options.maxTokens = maxTokens
139	}
140}
141
142func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
143	restore := config.PushPopCrushEnv()
144	defer restore()
145	resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
146	if err != nil {
147		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
148	}
149
150	// Resolve extra headers
151	resolvedExtraHeaders := make(map[string]string)
152	for key, value := range cfg.ExtraHeaders {
153		resolvedValue, err := config.Get().Resolve(value)
154		if err != nil {
155			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
156		}
157		resolvedExtraHeaders[key] = resolvedValue
158	}
159
160	clientOptions := providerClientOptions{
161		baseURL:            cfg.BaseURL,
162		config:             cfg,
163		apiKey:             resolvedAPIKey,
164		extraHeaders:       resolvedExtraHeaders,
165		extraBody:          cfg.ExtraBody,
166		extraParams:        cfg.ExtraParams,
167		systemPromptPrefix: cfg.SystemPromptPrefix,
168		model: func(tp config.SelectedModelType) catwalk.Model {
169			return *config.Get().GetModelByType(tp)
170		},
171	}
172	for _, o := range opts {
173		o(&clientOptions)
174	}
175	switch cfg.Type {
176	case catwalk.TypeAnthropic:
177		return &baseProvider[AnthropicClient]{
178			options: clientOptions,
179			client:  newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
180		}, nil
181	case catwalk.TypeOpenAI:
182		return &baseProvider[OpenAIClient]{
183			options: clientOptions,
184			client:  newOpenAIClient(clientOptions),
185		}, nil
186	case catwalk.TypeGemini:
187		return &baseProvider[GeminiClient]{
188			options: clientOptions,
189			client:  newGeminiClient(clientOptions),
190		}, nil
191	case catwalk.TypeBedrock:
192		return &baseProvider[BedrockClient]{
193			options: clientOptions,
194			client:  newBedrockClient(clientOptions),
195		}, nil
196	case catwalk.TypeAzure:
197		return &baseProvider[AzureClient]{
198			options: clientOptions,
199			client:  newAzureClient(clientOptions),
200		}, nil
201	case catwalk.TypeVertexAI:
202		return &baseProvider[VertexAIClient]{
203			options: clientOptions,
204			client:  newVertexAIClient(clientOptions),
205		}, nil
206	}
207	return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
208}