1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	configv2 "github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/fur/provider"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
// EventType identifies the kind of a ProviderEvent delivered on the channel
// returned by Provider.StreamResponse.
type EventType string
 14
// maxRetries presumably caps the retry attempts made by the provider client
// implementations. It is not referenced in this section, so verify its exact
// meaning (attempts vs. retries after the first try) at the use sites.
const maxRetries = 8
 16
// Event kinds carried in ProviderEvent.Type.
//
// This section does not show which ProviderEvent fields accompany each kind.
// Judging by the names, presumably Content goes with content deltas, Thinking
// with thinking deltas, ToolCall with the tool-use events, Response with
// EventComplete and Error with EventError. TODO: confirm against the client
// implementations that emit them.
const (
	EventContentStart  EventType = "content_start"
	EventToolUseStart  EventType = "tool_use_start"
	EventToolUseDelta  EventType = "tool_use_delta"
	EventToolUseStop   EventType = "tool_use_stop"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
	EventWarning       EventType = "warning"
)
 29
// TokenUsage holds the token counts attached to a ProviderResponse
// (ProviderResponse.Usage).
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	// Presumably tokens written to a provider-side prompt cache; the exact
	// meaning depends on what each client maps into it. Confirm per provider.
	CacheCreationTokens int64
	// Presumably tokens served from a provider-side prompt cache. Confirm per
	// provider.
	CacheReadTokens     int64
}
 36
// ProviderResponse is the result of one provider call. Provider.SendMessages
// returns it directly; streaming calls carry it in ProviderEvent.Response.
type ProviderResponse struct {
	// Content is the text produced by the model.
	Content      string
	// ToolCalls lists the tool invocations requested in this response.
	ToolCalls    []message.ToolCall
	// Usage reports the token counts for this call.
	Usage        TokenUsage
	// FinishReason records why the response ended.
	FinishReason message.FinishReason
}
 43
// ProviderEvent is a single item read from the channel returned by
// Provider.StreamResponse. Type says what kind of event it is. Presumably only
// the fields relevant to that kind are populated; confirm against the clients
// that emit the events.
type ProviderEvent struct {
	Type EventType

	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}

// Provider is the public interface to a model backend (Anthropic, OpenAI,
// Gemini, ...) as constructed by NewProviderV2.
type Provider interface {
	// SendMessages sends the conversation, together with the tools offered to
	// the model, and returns the complete response in one value.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse is the streaming counterpart of SendMessages: the
	// response is delivered incrementally as ProviderEvent values on the
	// returned channel.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model this provider was configured with.
	Model() configv2.Model
}
 60
// providerClientOptions is the configuration NewProviderV2 passes to every
// client constructor. baseURL, apiKey and extraHeaders are seeded from the
// ProviderConfig. model, disableCache, maxTokens and systemMessage are set
// through ProviderClientOption functions. extraParams has no setter in this
// section.
type providerClientOptions struct {
	baseURL       string
	apiKey        string
	model         configv2.Model
	// Presumably turns off provider-side prompt caching; interpretation is up
	// to each client.
	disableCache  bool
	// Presumably the per-response output token limit; passed through
	// unvalidated.
	maxTokens     int64
	systemMessage string
	extraHeaders  map[string]string
	extraParams   map[string]string
}
 71
// ProviderClientOption mutates providerClientOptions. NewProviderV2 applies the
// options in the order given, after seeding values from the ProviderConfig, so
// a later option overrides an earlier one.
type ProviderClientOption func(*providerClientOptions)
 73
// ProviderClient is the backend-specific half of a provider, wrapped by
// baseProvider. Its methods are unexported, so only types in this package can
// implement it.
type ProviderClient interface {
	// send performs a single, non-streaming request.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a request whose result arrives as ProviderEvent values.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}
 78
// baseProvider implements Provider on top of a ProviderClient. It filters the
// message list with cleanMessages and then delegates to client. The client is
// a type parameter so each backend is stored by its concrete type.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
 83
 84func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 85	for _, msg := range messages {
 86		// The message has no content
 87		if len(msg.Parts) == 0 {
 88			continue
 89		}
 90		cleaned = append(cleaned, msg)
 91	}
 92	return
 93}
 94
 95func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
 96	messages = p.cleanMessages(messages)
 97	return p.client.send(ctx, messages, tools)
 98}
 99
100func (p *baseProvider[C]) Model() configv2.Model {
101	return p.options.model
102}
103
104func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
105	messages = p.cleanMessages(messages)
106	return p.client.stream(ctx, messages, tools)
107}
108
109func WithModel(model configv2.Model) ProviderClientOption {
110	return func(options *providerClientOptions) {
111		options.model = model
112	}
113}
114
115func WithDisableCache(disableCache bool) ProviderClientOption {
116	return func(options *providerClientOptions) {
117		options.disableCache = disableCache
118	}
119}
120
121func WithMaxTokens(maxTokens int64) ProviderClientOption {
122	return func(options *providerClientOptions) {
123		options.maxTokens = maxTokens
124	}
125}
126
127func WithSystemMessage(systemMessage string) ProviderClientOption {
128	return func(options *providerClientOptions) {
129		options.systemMessage = systemMessage
130	}
131}
132
133func NewProviderV2(cfg configv2.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
134	clientOptions := providerClientOptions{
135		baseURL:      cfg.BaseURL,
136		apiKey:       cfg.APIKey,
137		extraHeaders: cfg.ExtraHeaders,
138	}
139	for _, o := range opts {
140		o(&clientOptions)
141	}
142	switch cfg.ProviderType {
143	case provider.TypeAnthropic:
144		return &baseProvider[AnthropicClient]{
145			options: clientOptions,
146			client:  newAnthropicClient(clientOptions, false),
147		}, nil
148	case provider.TypeOpenAI:
149		return &baseProvider[OpenAIClient]{
150			options: clientOptions,
151			client:  newOpenAIClient(clientOptions),
152		}, nil
153	case provider.TypeGemini:
154		return &baseProvider[GeminiClient]{
155			options: clientOptions,
156			client:  newGeminiClient(clientOptions),
157		}, nil
158	case provider.TypeBedrock:
159		return &baseProvider[BedrockClient]{
160			options: clientOptions,
161			client:  newBedrockClient(clientOptions),
162		}, nil
163	case provider.TypeAzure:
164		return &baseProvider[AzureClient]{
165			options: clientOptions,
166			client:  newAzureClient(clientOptions),
167		}, nil
168	case provider.TypeVertexAI:
169		return &baseProvider[VertexAIClient]{
170			options: clientOptions,
171			client:  newVertexAIClient(clientOptions),
172		}, nil
173	case provider.TypeXAI:
174		clientOptions.baseURL = "https://api.x.ai/v1"
175		return &baseProvider[OpenAIClient]{
176			options: clientOptions,
177			client:  newOpenAIClient(clientOptions),
178		}, nil
179	}
180	return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
181}