1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
 13type EventType string
 14
 15const maxRetries = 8
 16
 17const (
 18	EventContentStart  EventType = "content_start"
 19	EventToolUseStart  EventType = "tool_use_start"
 20	EventToolUseDelta  EventType = "tool_use_delta"
 21	EventToolUseStop   EventType = "tool_use_stop"
 22	EventContentDelta  EventType = "content_delta"
 23	EventThinkingDelta EventType = "thinking_delta"
 24	EventContentStop   EventType = "content_stop"
 25	EventComplete      EventType = "complete"
 26	EventError         EventType = "error"
 27	EventWarning       EventType = "warning"
 28)
 29
 30type TokenUsage struct {
 31	InputTokens         int64
 32	OutputTokens        int64
 33	CacheCreationTokens int64
 34	CacheReadTokens     int64
 35}
 36
 37type ProviderResponse struct {
 38	Content      string
 39	ToolCalls    []message.ToolCall
 40	Usage        TokenUsage
 41	FinishReason message.FinishReason
 42}
 43
 44type ProviderEvent struct {
 45	Type EventType
 46
 47	Content  string
 48	Thinking string
 49	Response *ProviderResponse
 50	ToolCall *message.ToolCall
 51	Error    error
 52}
 53type Provider interface {
 54	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 55
 56	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 57
 58	Model() provider.Model
 59}
 60
 61type providerClientOptions struct {
 62	baseURL       string
 63	config        config.ProviderConfig
 64	apiKey        string
 65	modelType     config.SelectedModelType
 66	model         func(config.SelectedModelType) provider.Model
 67	disableCache  bool
 68	systemMessage string
 69	maxTokens     int64
 70	extraHeaders  map[string]string
 71	extraParams   map[string]string
 72}
 73
 74type ProviderClientOption func(*providerClientOptions)
 75
 76type ProviderClient interface {
 77	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 78	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 79
 80	Model() provider.Model
 81}
 82
 83type baseProvider[C ProviderClient] struct {
 84	options providerClientOptions
 85	client  C
 86}
 87
 88func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 89	for _, msg := range messages {
 90		// The message has no content
 91		if len(msg.Parts) == 0 {
 92			continue
 93		}
 94		cleaned = append(cleaned, msg)
 95	}
 96	return
 97}
 98
 99func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
100	messages = p.cleanMessages(messages)
101	return p.client.send(ctx, messages, tools)
102}
103
104func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
105	messages = p.cleanMessages(messages)
106	return p.client.stream(ctx, messages, tools)
107}
108
109func (p *baseProvider[C]) Model() provider.Model {
110	return p.client.Model()
111}
112
113func WithModel(model config.SelectedModelType) ProviderClientOption {
114	return func(options *providerClientOptions) {
115		options.modelType = model
116	}
117}
118
119func WithDisableCache(disableCache bool) ProviderClientOption {
120	return func(options *providerClientOptions) {
121		options.disableCache = disableCache
122	}
123}
124
125func WithSystemMessage(systemMessage string) ProviderClientOption {
126	return func(options *providerClientOptions) {
127		options.systemMessage = systemMessage
128	}
129}
130
131func WithMaxTokens(maxTokens int64) ProviderClientOption {
132	return func(options *providerClientOptions) {
133		options.maxTokens = maxTokens
134	}
135}
136
137func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
138	resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
139	if err != nil {
140		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
141	}
142
143	clientOptions := providerClientOptions{
144		baseURL:      cfg.BaseURL,
145		config:       cfg,
146		apiKey:       resolvedAPIKey,
147		extraHeaders: cfg.ExtraHeaders,
148		model: func(tp config.SelectedModelType) provider.Model {
149			return *config.Get().GetModelByType(tp)
150		},
151	}
152	for _, o := range opts {
153		o(&clientOptions)
154	}
155	switch cfg.Type {
156	case provider.TypeAnthropic:
157		return &baseProvider[AnthropicClient]{
158			options: clientOptions,
159			client:  newAnthropicClient(clientOptions, false),
160		}, nil
161	case provider.TypeOpenAI:
162		return &baseProvider[OpenAIClient]{
163			options: clientOptions,
164			client:  newOpenAIClient(clientOptions),
165		}, nil
166	case provider.TypeGemini:
167		return &baseProvider[GeminiClient]{
168			options: clientOptions,
169			client:  newGeminiClient(clientOptions),
170		}, nil
171	case provider.TypeBedrock:
172		return &baseProvider[BedrockClient]{
173			options: clientOptions,
174			client:  newBedrockClient(clientOptions),
175		}, nil
176	case provider.TypeAzure:
177		return &baseProvider[AzureClient]{
178			options: clientOptions,
179			client:  newAzureClient(clientOptions),
180		}, nil
181	case provider.TypeVertexAI:
182		return &baseProvider[VertexAIClient]{
183			options: clientOptions,
184			client:  newVertexAIClient(clientOptions),
185		}, nil
186	case provider.TypeXAI:
187		clientOptions.baseURL = "https://api.x.ai/v1"
188		return &baseProvider[OpenAIClient]{
189			options: clientOptions,
190			client:  newOpenAIClient(clientOptions),
191		}, nil
192	}
193	return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
194}