1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
 13type EventType string
 14
 15const maxRetries = 8
 16
 17const (
 18	EventContentStart  EventType = "content_start"
 19	EventToolUseStart  EventType = "tool_use_start"
 20	EventToolUseDelta  EventType = "tool_use_delta"
 21	EventToolUseStop   EventType = "tool_use_stop"
 22	EventContentDelta  EventType = "content_delta"
 23	EventThinkingDelta EventType = "thinking_delta"
 24	EventContentStop   EventType = "content_stop"
 25	EventComplete      EventType = "complete"
 26	EventError         EventType = "error"
 27	EventWarning       EventType = "warning"
 28)
 29
 30type TokenUsage struct {
 31	InputTokens         int64
 32	OutputTokens        int64
 33	CacheCreationTokens int64
 34	CacheReadTokens     int64
 35}
 36
// ProviderResponse is the result returned by Provider.SendMessages; it is
// also carried by ProviderEvent.Response.
type ProviderResponse struct {
	// Content is the text content of the response.
	Content string

	// ToolCalls lists the tool calls contained in the response, if any.
	ToolCalls []message.ToolCall

	// Usage holds the token counts associated with the call.
	Usage TokenUsage

	// FinishReason records why the response ended.
	FinishReason message.FinishReason
}
 43
 44type ProviderEvent struct {
 45	Type EventType
 46
 47	Content  string
 48	Thinking string
 49	Response *ProviderResponse
 50	ToolCall *message.ToolCall
 51	Error    error
 52}
 53type Provider interface {
 54	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 55
 56	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 57
 58	Model() config.Model
 59}
 60
 61type providerClientOptions struct {
 62	baseURL       string
 63	apiKey        string
 64	modelType     config.ModelType
 65	model         func(config.ModelType) config.Model
 66	disableCache  bool
 67	maxTokens     int64
 68	systemMessage string
 69	extraHeaders  map[string]string
 70	extraParams   map[string]string
 71}
 72
// ProviderClientOption mutates a providerClientOptions. NewProviderV2 applies
// each option in order after seeding the options from the provider config,
// so later options override earlier values.
type ProviderClientOption func(*providerClientOptions)
 74
 75type ProviderClient interface {
 76	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
 77	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
 78
 79	Model() config.Model
 80}
 81
// baseProvider adapts a backend-specific ProviderClient to the Provider
// interface. Its SendMessages and StreamResponse drop messages that have no
// parts before delegating to client.
type baseProvider[C ProviderClient] struct {
	// options is the same configuration the client was constructed from.
	// NOTE(review): not read by any method visible in this file chunk.
	options providerClientOptions
	client  C
}
 86
 87func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 88	for _, msg := range messages {
 89		// The message has no content
 90		if len(msg.Parts) == 0 {
 91			continue
 92		}
 93		cleaned = append(cleaned, msg)
 94	}
 95	return
 96}
 97
 98func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 99	messages = p.cleanMessages(messages)
100	return p.client.send(ctx, messages, tools)
101}
102
103func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
104	messages = p.cleanMessages(messages)
105	return p.client.stream(ctx, messages, tools)
106}
107
// Model returns the model reported by the wrapped client.
func (p *baseProvider[C]) Model() config.Model {
	return p.client.Model()
}
111
112func WithModel(model config.ModelType) ProviderClientOption {
113	return func(options *providerClientOptions) {
114		options.modelType = model
115	}
116}
117
118func WithDisableCache(disableCache bool) ProviderClientOption {
119	return func(options *providerClientOptions) {
120		options.disableCache = disableCache
121	}
122}
123
124func WithMaxTokens(maxTokens int64) ProviderClientOption {
125	return func(options *providerClientOptions) {
126		options.maxTokens = maxTokens
127	}
128}
129
130func WithSystemMessage(systemMessage string) ProviderClientOption {
131	return func(options *providerClientOptions) {
132		options.systemMessage = systemMessage
133	}
134}
135
136func NewProviderV2(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
137	clientOptions := providerClientOptions{
138		baseURL:      cfg.BaseURL,
139		apiKey:       cfg.APIKey,
140		extraHeaders: cfg.ExtraHeaders,
141		model: func(tp config.ModelType) config.Model {
142			return config.GetModel(tp)
143		},
144	}
145	for _, o := range opts {
146		o(&clientOptions)
147	}
148	switch cfg.ProviderType {
149	case provider.TypeAnthropic:
150		return &baseProvider[AnthropicClient]{
151			options: clientOptions,
152			client:  newAnthropicClient(clientOptions, false),
153		}, nil
154	case provider.TypeOpenAI:
155		return &baseProvider[OpenAIClient]{
156			options: clientOptions,
157			client:  newOpenAIClient(clientOptions),
158		}, nil
159	case provider.TypeGemini:
160		return &baseProvider[GeminiClient]{
161			options: clientOptions,
162			client:  newGeminiClient(clientOptions),
163		}, nil
164	case provider.TypeBedrock:
165		return &baseProvider[BedrockClient]{
166			options: clientOptions,
167			client:  newBedrockClient(clientOptions),
168		}, nil
169	case provider.TypeAzure:
170		return &baseProvider[AzureClient]{
171			options: clientOptions,
172			client:  newAzureClient(clientOptions),
173		}, nil
174	case provider.TypeVertexAI:
175		return &baseProvider[VertexAIClient]{
176			options: clientOptions,
177			client:  newVertexAIClient(clientOptions),
178		}, nil
179	case provider.TypeXAI:
180		clientOptions.baseURL = "https://api.x.ai/v1"
181		return &baseProvider[OpenAIClient]{
182			options: clientOptions,
183			client:  newOpenAIClient(clientOptions),
184		}, nil
185	}
186	return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
187}