provider.go

  1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/catwalk/pkg/catwalk"
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
 13type EventType string
 14
 15const maxRetries = 8
 16
 17const (
 18	EventContentStart   EventType = "content_start"
 19	EventToolUseStart   EventType = "tool_use_start"
 20	EventToolUseDelta   EventType = "tool_use_delta"
 21	EventToolUseStop    EventType = "tool_use_stop"
 22	EventContentDelta   EventType = "content_delta"
 23	EventThinkingDelta  EventType = "thinking_delta"
 24	EventSignatureDelta EventType = "signature_delta"
 25	EventContentStop    EventType = "content_stop"
 26	EventComplete       EventType = "complete"
 27	EventError          EventType = "error"
 28	EventWarning        EventType = "warning"
 29	EventRetry          EventType = "retry"
 30	EventRetrying       EventType = "retrying"
 31)
 32
// TokenUsage records the token counts reported for a request.
// NOTE(review): the cache counters presumably apply only to providers that
// report prompt-cache usage — confirm per client.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}

// ProviderResponse holds a model response. It is returned by
// Provider.SendMessages and carried in ProviderEvent.Response.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}

// ProviderEvent is a single item on the channel returned by
// Provider.StreamResponse. Type identifies the kind of event; the remaining
// fields form its payload. NOTE(review): presumably only the fields relevant
// to Type are populated — confirm against the client implementations.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Retry     int64
	Error     error
}
// Provider is the exported API for exchanging a conversation with a model
// backend. NewProvider returns an implementation of it.
type Provider interface {
	// SendMessages submits messages, together with the available tools, and
	// returns the resulting response (or an error).
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse submits messages, together with the available tools, and
	// returns a channel of ProviderEvent values describing the response.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk.Model describing the model in use.
	Model() catwalk.Model
}
 65
// providerClientOptions holds the settings a provider client is built from.
// NewProvider populates it from a config.ProviderConfig and then applies each
// ProviderClientOption in order, so options override those defaults.
type providerClientOptions struct {
	// Copied from the provider config by NewProvider. apiKey has already been
	// passed through the config's Resolve method.
	baseURL string
	config  config.ProviderConfig
	apiKey  string
	// modelType is set by WithModel. model maps a SelectedModelType to its
	// catwalk.Model; NewProvider installs a lookup into the global config.
	modelType config.SelectedModelType
	model     func(config.SelectedModelType) catwalk.Model
	// Set by WithDisableCache and WithSystemMessage respectively.
	disableCache  bool
	systemMessage string
	// Copied from cfg.SystemPromptPrefix by NewProvider.
	systemPromptPrefix string
	// Set by WithMaxTokens.
	maxTokens int64
	// extraHeaders values have already been resolved by NewProvider;
	// extraBody and extraParams are copied from the config as-is.
	extraHeaders map[string]string
	extraBody    map[string]any
	extraParams  map[string]string
}

// ProviderClientOption customizes providerClientOptions. NewProvider applies
// options after filling in the config-derived defaults.
type ProviderClientOption func(*providerClientOptions)
 82
// ProviderClient is implemented by each backend-specific client. Because send
// and stream are unexported, only types declared in this package (or types
// embedding one) can satisfy it. baseProvider wraps a ProviderClient to expose
// the Provider API.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	Model() catwalk.Model
}

// baseProvider adapts a ProviderClient to the Provider interface. The type
// parameter C is the concrete client type chosen in NewProvider.
type baseProvider[C ProviderClient] struct {
	// options is the configuration the client was built with. It is stored
	// here but not read by the methods in this file.
	options providerClientOptions
	client  C
}
 94
 95func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 96	for _, msg := range messages {
 97		// The message has no content
 98		if len(msg.Parts) == 0 {
 99			continue
100		}
101		cleaned = append(cleaned, msg)
102	}
103	return
104}
105
106func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
107	messages = p.cleanMessages(messages)
108	return p.client.send(ctx, messages, tools)
109}
110
111func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
112	messages = p.cleanMessages(messages)
113	return p.client.stream(ctx, messages, tools)
114}
115
116func (p *baseProvider[C]) Model() catwalk.Model {
117	return p.client.Model()
118}
119
120func WithModel(model config.SelectedModelType) ProviderClientOption {
121	return func(options *providerClientOptions) {
122		options.modelType = model
123	}
124}
125
126func WithDisableCache(disableCache bool) ProviderClientOption {
127	return func(options *providerClientOptions) {
128		options.disableCache = disableCache
129	}
130}
131
132func WithSystemMessage(systemMessage string) ProviderClientOption {
133	return func(options *providerClientOptions) {
134		options.systemMessage = systemMessage
135	}
136}
137
138func WithMaxTokens(maxTokens int64) ProviderClientOption {
139	return func(options *providerClientOptions) {
140		options.maxTokens = maxTokens
141	}
142}
143
144func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
145	resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
146	if err != nil {
147		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
148	}
149
150	// Resolve extra headers
151	resolvedExtraHeaders := make(map[string]string)
152	for key, value := range cfg.ExtraHeaders {
153		resolvedValue, err := config.Get().Resolve(value)
154		if err != nil {
155			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
156		}
157		resolvedExtraHeaders[key] = resolvedValue
158	}
159
160	clientOptions := providerClientOptions{
161		baseURL:            cfg.BaseURL,
162		config:             cfg,
163		apiKey:             resolvedAPIKey,
164		extraHeaders:       resolvedExtraHeaders,
165		extraBody:          cfg.ExtraBody,
166		extraParams:        cfg.ExtraParams,
167		systemPromptPrefix: cfg.SystemPromptPrefix,
168		model: func(tp config.SelectedModelType) catwalk.Model {
169			return *config.Get().GetModelByType(tp)
170		},
171	}
172	for _, o := range opts {
173		o(&clientOptions)
174	}
175	switch cfg.Type {
176	case catwalk.TypeAnthropic:
177		return &baseProvider[AnthropicClient]{
178			options: clientOptions,
179			client:  newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
180		}, nil
181	case catwalk.TypeOpenAI:
182		return &baseProvider[OpenAIClient]{
183			options: clientOptions,
184			client:  newOpenAIClient(clientOptions),
185		}, nil
186	case catwalk.TypeGemini:
187		return &baseProvider[GeminiClient]{
188			options: clientOptions,
189			client:  newGeminiClient(clientOptions),
190		}, nil
191	case catwalk.TypeBedrock:
192		return &baseProvider[BedrockClient]{
193			options: clientOptions,
194			client:  newBedrockClient(clientOptions),
195		}, nil
196	case catwalk.TypeAzure:
197		return &baseProvider[AzureClient]{
198			options: clientOptions,
199			client:  newAzureClient(clientOptions),
200		}, nil
201	case catwalk.TypeVertexAI:
202		return &baseProvider[VertexAIClient]{
203			options: clientOptions,
204			client:  newVertexAIClient(clientOptions),
205		}, nil
206	}
207	return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
208}