1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/catwalk/pkg/catwalk"
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
 13type EventType string
 14
 15const maxRetries = 8
 16
 17const (
 18	EventContentStart   EventType = "content_start"
 19	EventToolUseStart   EventType = "tool_use_start"
 20	EventToolUseDelta   EventType = "tool_use_delta"
 21	EventToolUseStop    EventType = "tool_use_stop"
 22	EventContentDelta   EventType = "content_delta"
 23	EventThinkingDelta  EventType = "thinking_delta"
 24	EventSignatureDelta EventType = "signature_delta"
 25	EventContentStop    EventType = "content_stop"
 26	EventComplete       EventType = "complete"
 27	EventError          EventType = "error"
 28	EventWarning        EventType = "warning"
 29)
 30
 31type TokenUsage struct {
 32	InputTokens         int64
 33	OutputTokens        int64
 34	CacheCreationTokens int64
 35	CacheReadTokens     int64
 36}
 37
// ProviderResponse is the result type returned by Provider.SendMessages and
// referenced from ProviderEvent.Response.
type ProviderResponse struct {
	// Content is the textual part of the response.
	Content string
	// ToolCalls lists the tool calls attached to the response.
	ToolCalls []message.ToolCall
	// Usage carries the token counters for the request.
	Usage TokenUsage
	// FinishReason is the message.FinishReason reported for the response.
	FinishReason message.FinishReason
}
 44
// ProviderEvent is the element type of the channel returned by
// Provider.StreamResponse. Type identifies the event; which of the remaining
// fields are populated for a given Type is decided by the provider clients.
type ProviderEvent struct {
	// Type discriminates the event (see the EventType constants).
	Type EventType

	// Content, Thinking and Signature are string payloads.
	// NOTE(review): presumably filled for the content, thinking and signature
	// delta events respectively — confirm against the clients.
	Content   string
	Thinking  string
	Signature string
	// Response is an optional pointer to a full ProviderResponse.
	Response *ProviderResponse
	// ToolCall is an optional pointer to the tool call the event refers to.
	ToolCall *message.ToolCall
	// Error is the error associated with the event, if any.
	Error error
}
// Provider is the interface returned by NewProvider for talking to a
// configured model provider.
type Provider interface {
	// SendMessages sends messages together with the available tools and
	// returns the complete response or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the same kind of request but returns a receive-only
	// channel of ProviderEvent values instead of a single response.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk.Model associated with this provider.
	Model() catwalk.Model
}
 62
// providerClientOptions is the settings bundle handed to the concrete client
// constructors. NewProvider seeds it from a config.ProviderConfig and then
// applies any ProviderClientOption values on top.
type providerClientOptions struct {
	// baseURL is copied from the provider config's BaseURL by NewProvider.
	baseURL string
	// config is the provider configuration NewProvider was called with.
	config config.ProviderConfig
	// apiKey is the provider config's APIKey after config.Get().Resolve.
	apiKey string
	// modelType is set through WithModel.
	modelType config.SelectedModelType
	// model maps a selected model type to its catwalk.Model; NewProvider
	// installs a lookup backed by config.Get().GetModelByType.
	model func(config.SelectedModelType) catwalk.Model
	// disableCache is set through WithDisableCache.
	disableCache bool
	// systemMessage is set through WithSystemMessage.
	systemMessage string
	// systemPromptPrefix is copied from the provider config by NewProvider.
	systemPromptPrefix string
	// maxTokens is set through WithMaxTokens.
	maxTokens int64
	// extraHeaders holds the provider config's ExtraHeaders with every value
	// passed through config.Get().Resolve.
	extraHeaders map[string]string
	// extraBody is copied from the provider config's ExtraBody by NewProvider.
	extraBody map[string]any
	// extraParams is not assigned by NewProvider or by any option in this file.
	// NOTE(review): confirm whether it should be populated from the config.
	extraParams map[string]string
}
 77
// ProviderClientOption is a functional option: it mutates the
// providerClientOptions it is given. NewProvider applies the options it
// receives in argument order.
type ProviderClientOption func(*providerClientOptions)
 79
// ProviderClient is the contract a concrete backend client fulfils so that it
// can be wrapped by baseProvider. send and stream are unexported and mirror
// Provider.SendMessages and Provider.StreamResponse.
type ProviderClient interface {
	// send performs a request and returns the complete response or an error.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a request and returns a receive-only channel of events.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk.Model the client is bound to.
	Model() catwalk.Model
}
 86
// baseProvider adapts a ProviderClient of type C to the Provider interface.
// Its SendMessages and StreamResponse methods drop messages without parts
// before delegating to the client.
type baseProvider[C ProviderClient] struct {
	// options is the option set the client was constructed with.
	options providerClientOptions
	// client is the concrete backend client that performs the requests.
	client C
}
 91
 92func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 93	for _, msg := range messages {
 94		// The message has no content
 95		if len(msg.Parts) == 0 {
 96			continue
 97		}
 98		cleaned = append(cleaned, msg)
 99	}
100	return
101}
102
103func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
104	messages = p.cleanMessages(messages)
105	return p.client.send(ctx, messages, tools)
106}
107
108func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
109	messages = p.cleanMessages(messages)
110	return p.client.stream(ctx, messages, tools)
111}
112
// Model returns the catwalk.Model reported by the wrapped client.
func (p *baseProvider[C]) Model() catwalk.Model {
	return p.client.Model()
}
116
117func WithModel(model config.SelectedModelType) ProviderClientOption {
118	return func(options *providerClientOptions) {
119		options.modelType = model
120	}
121}
122
123func WithDisableCache(disableCache bool) ProviderClientOption {
124	return func(options *providerClientOptions) {
125		options.disableCache = disableCache
126	}
127}
128
129func WithSystemMessage(systemMessage string) ProviderClientOption {
130	return func(options *providerClientOptions) {
131		options.systemMessage = systemMessage
132	}
133}
134
135func WithMaxTokens(maxTokens int64) ProviderClientOption {
136	return func(options *providerClientOptions) {
137		options.maxTokens = maxTokens
138	}
139}
140
141func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
142	resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
143	if err != nil {
144		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
145	}
146
147	// Resolve extra headers
148	resolvedExtraHeaders := make(map[string]string)
149	for key, value := range cfg.ExtraHeaders {
150		resolvedValue, err := config.Get().Resolve(value)
151		if err != nil {
152			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
153		}
154		resolvedExtraHeaders[key] = resolvedValue
155	}
156
157	clientOptions := providerClientOptions{
158		baseURL:            cfg.BaseURL,
159		config:             cfg,
160		apiKey:             resolvedAPIKey,
161		extraHeaders:       resolvedExtraHeaders,
162		extraBody:          cfg.ExtraBody,
163		systemPromptPrefix: cfg.SystemPromptPrefix,
164		model: func(tp config.SelectedModelType) catwalk.Model {
165			return *config.Get().GetModelByType(tp)
166		},
167	}
168	for _, o := range opts {
169		o(&clientOptions)
170	}
171	switch cfg.Type {
172	case catwalk.TypeAnthropic:
173		return &baseProvider[AnthropicClient]{
174			options: clientOptions,
175			client:  newAnthropicClient(clientOptions, false),
176		}, nil
177	case catwalk.TypeOpenAI:
178		return &baseProvider[OpenAIClient]{
179			options: clientOptions,
180			client:  newOpenAIClient(clientOptions),
181		}, nil
182	case catwalk.TypeGemini:
183		return &baseProvider[GeminiClient]{
184			options: clientOptions,
185			client:  newGeminiClient(clientOptions),
186		}, nil
187	case catwalk.TypeBedrock:
188		return &baseProvider[BedrockClient]{
189			options: clientOptions,
190			client:  newBedrockClient(clientOptions),
191		}, nil
192	case catwalk.TypeAzure:
193		return &baseProvider[AzureClient]{
194			options: clientOptions,
195			client:  newAzureClient(clientOptions),
196		}, nil
197	case catwalk.TypeVertexAI:
198		return &baseProvider[VertexAIClient]{
199			options: clientOptions,
200			client:  newVertexAIClient(clientOptions),
201		}, nil
202	}
203	return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
204}