1package provider
  2
  3import (
  4	"context"
  5	"fmt"
  6
  7	"github.com/charmbracelet/catwalk/pkg/catwalk"
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/message"
 11)
 12
// EventType identifies the kind of a ProviderEvent (see ProviderEvent.Type).
type EventType string

// maxRetries is the retry limit shared by this package.
// NOTE(review): not referenced in this chunk; presumably consumed by the
// per-backend clients' retry logic — confirm.
const maxRetries = 8

// Event kinds carried in ProviderEvent.Type. Judging by the names, the
// *_delta kinds carry incremental output (text, thinking, signature), the
// tool_use_* kinds bracket a tool call, and complete/error/warning report
// the outcome of a stream.
// NOTE(review): this mapping is inferred from the names — confirm against the
// client implementations that emit these events.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
 30
// TokenUsage holds the token counts attached to a ProviderResponse (see
// ProviderResponse.Usage): input and output tokens, plus tokens written to
// (CacheCreationTokens) and read from (CacheReadTokens) a prompt cache.
// NOTE(review): the cache fields' exact meaning depends on how each backend
// client fills them in — confirm per provider.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
 37
// ProviderResponse is the result of a request to a provider: the returned
// text Content, any ToolCalls requested by the model, the token Usage, and
// the FinishReason. It is returned by Provider.SendMessages and may be
// attached to a ProviderEvent via its Response field.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
 44
// ProviderEvent is a single item delivered on the channel returned by
// Provider.StreamResponse. Type says what kind of event it is; the remaining
// fields are payload slots.
// NOTE(review): presumably only the field(s) relevant to Type are populated
// (e.g. Content for content deltas, ToolCall for tool-use events, Response
// for completion, Error for errors) — confirm against the emitting clients.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the public interface to an LLM backend. Instances are built by
// NewProvider.
type Provider interface {
	// SendMessages sends the conversation (with the available tools) and
	// returns the resulting ProviderResponse, or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the conversation (with the available tools) and
	// returns a channel on which ProviderEvents are delivered.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk model this provider is using.
	Model() catwalk.Model
}
 62
// providerClientOptions is the configuration handed to every backend client
// constructor. NewProvider seeds baseURL, config, apiKey, extraHeaders,
// extraBody and model from the ProviderConfig. It then applies the caller's
// ProviderClientOptions, which set modelType (WithModel), disableCache
// (WithDisableCache), systemMessage (WithSystemMessage) and maxTokens
// (WithMaxTokens). The model field resolves the catwalk.Model for a given
// config.SelectedModelType.
// NOTE(review): extraParams is not assigned anywhere in this chunk — confirm
// where (or whether) it is populated.
type providerClientOptions struct {
	baseURL       string
	config        config.ProviderConfig
	apiKey        string
	modelType     config.SelectedModelType
	model         func(config.SelectedModelType) catwalk.Model
	disableCache  bool
	systemMessage string
	maxTokens     int64
	extraHeaders  map[string]string
	extraBody     map[string]any
	extraParams   map[string]string
}
 76
// ProviderClientOption mutates providerClientOptions; NewProvider applies the
// options it receives in order, after filling in the config-derived defaults.
type ProviderClientOption func(*providerClientOptions)
 78
// ProviderClient is the package-internal contract each backend client
// (AnthropicClient, OpenAIClient, GeminiClient, ...) satisfies so that it can
// be wrapped by baseProvider. Because send and stream are unexported, only
// types in this package can implement it.
type ProviderClient interface {
	// send performs a request and returns its response (used by SendMessages).
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a request and returns its event channel (used by
	// StreamResponse).
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk model the client is using.
	Model() catwalk.Model
}
 85
// baseProvider adapts a ProviderClient to the public Provider interface. It
// removes messages without parts (cleanMessages) before delegating to the
// client's send/stream, and forwards Model to the client.
// options holds the same providerClientOptions the client was built from
// (see NewProvider); the methods in this chunk do not read it.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
 90
 91func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
 92	for _, msg := range messages {
 93		// The message has no content
 94		if len(msg.Parts) == 0 {
 95			continue
 96		}
 97		cleaned = append(cleaned, msg)
 98	}
 99	return
100}
101
102func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
103	messages = p.cleanMessages(messages)
104	return p.client.send(ctx, messages, tools)
105}
106
107func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
108	messages = p.cleanMessages(messages)
109	return p.client.stream(ctx, messages, tools)
110}
111
112func (p *baseProvider[C]) Model() catwalk.Model {
113	return p.client.Model()
114}
115
116func WithModel(model config.SelectedModelType) ProviderClientOption {
117	return func(options *providerClientOptions) {
118		options.modelType = model
119	}
120}
121
122func WithDisableCache(disableCache bool) ProviderClientOption {
123	return func(options *providerClientOptions) {
124		options.disableCache = disableCache
125	}
126}
127
128func WithSystemMessage(systemMessage string) ProviderClientOption {
129	return func(options *providerClientOptions) {
130		options.systemMessage = systemMessage
131	}
132}
133
134func WithMaxTokens(maxTokens int64) ProviderClientOption {
135	return func(options *providerClientOptions) {
136		options.maxTokens = maxTokens
137	}
138}
139
140func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
141	resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
142	if err != nil {
143		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
144	}
145
146	clientOptions := providerClientOptions{
147		baseURL:      cfg.BaseURL,
148		config:       cfg,
149		apiKey:       resolvedAPIKey,
150		extraHeaders: cfg.ExtraHeaders,
151		extraBody:    cfg.ExtraBody,
152		model: func(tp config.SelectedModelType) catwalk.Model {
153			return *config.Get().GetModelByType(tp)
154		},
155	}
156	for _, o := range opts {
157		o(&clientOptions)
158	}
159	switch cfg.Type {
160	case catwalk.TypeAnthropic:
161		return &baseProvider[AnthropicClient]{
162			options: clientOptions,
163			client:  newAnthropicClient(clientOptions, false),
164		}, nil
165	case catwalk.TypeOpenAI:
166		return &baseProvider[OpenAIClient]{
167			options: clientOptions,
168			client:  newOpenAIClient(clientOptions),
169		}, nil
170	case catwalk.TypeGemini:
171		return &baseProvider[GeminiClient]{
172			options: clientOptions,
173			client:  newGeminiClient(clientOptions),
174		}, nil
175	case catwalk.TypeBedrock:
176		return &baseProvider[BedrockClient]{
177			options: clientOptions,
178			client:  newBedrockClient(clientOptions),
179		}, nil
180	case catwalk.TypeAzure:
181		return &baseProvider[AzureClient]{
182			options: clientOptions,
183			client:  newAzureClient(clientOptions),
184		}, nil
185	case catwalk.TypeVertexAI:
186		return &baseProvider[VertexAIClient]{
187			options: clientOptions,
188			client:  newVertexAIClient(clientOptions),
189		}, nil
190	case catwalk.TypeXAI:
191		clientOptions.baseURL = "https://api.x.ai/v1"
192		return &baseProvider[OpenAIClient]{
193			options: clientOptions,
194			client:  newOpenAIClient(clientOptions),
195		}, nil
196	}
197	return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
198}