1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/catwalk/pkg/catwalk"
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType identifies the kind of a ProviderEvent, i.e. the items delivered
// on the channel returned by StreamResponse.
type EventType string

// maxRetries is a package-level retry limit. It is not referenced by the code
// in this file; presumably the concrete clients use it to bound retries —
// confirm against their implementations.
const maxRetries = 8

// Values for ProviderEvent.Type. Each constant's string form is the literal
// shown, which is what callers see if they print or compare the raw value.
const (
	EventContentStart   = EventType("content_start")
	EventToolUseStart   = EventType("tool_use_start")
	EventToolUseDelta   = EventType("tool_use_delta")
	EventToolUseStop    = EventType("tool_use_stop")
	EventContentDelta   = EventType("content_delta")
	EventThinkingDelta  = EventType("thinking_delta")
	EventSignatureDelta = EventType("signature_delta")
	EventContentStop    = EventType("content_stop")
	EventComplete       = EventType("complete")
	EventError          = EventType("error")
	EventWarning        = EventType("warning")
)
30
// TokenUsage holds token counts, broken down by category, as stored in
// ProviderResponse.Usage. How each count is derived (for example whether
// cache tokens are also included in InputTokens) is decided by the individual
// client implementations, which are not shown here.
type TokenUsage struct {
	InputTokens, OutputTokens            int64
	CacheCreationTokens, CacheReadTokens int64
}
37
// ProviderResponse is the value SendMessages returns, and the type that
// ProviderEvent.Response points to. Field meanings follow their names; exactly
// how each field is filled (for example how Content is assembled) is up to
// the individual client implementations, which are not visible here.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
44
45type ProviderEvent struct {
46 Type EventType
47
48 Content string
49 Thinking string
50 Signature string
51 Response *ProviderResponse
52 ToolCall *message.ToolCall
53 Error error
54}
55type Provider interface {
56 SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
57
58 StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
59
60 Model() catwalk.Model
61}
62
// providerClientOptions is the configuration passed to each concrete client
// constructor (newAnthropicClient, newOpenAIClient, newGeminiClient, ...).
//
// NewProvider fills the following fields from the provider config:
//   - baseURL (which the xAI case may replace)
//   - config
//   - apiKey (resolved through the global config)
//   - extraHeaders (each value resolved through the global config)
//   - extraBody
//   - systemPromptPrefix
//   - model
//
// modelType, disableCache, systemMessage and maxTokens are set only through
// the With* options. extraParams is not populated anywhere in the code
// visible here.
//
// model looks the model up in the global config each time it is called, so it
// reflects the config at call time rather than at construction time.
type providerClientOptions struct {
	baseURL            string
	config             config.ProviderConfig
	apiKey             string
	modelType          config.SelectedModelType
	model              func(config.SelectedModelType) catwalk.Model
	disableCache       bool
	systemMessage      string
	systemPromptPrefix string
	maxTokens          int64
	extraHeaders       map[string]string
	extraBody          map[string]any
	extraParams        map[string]string
}
77
// ProviderClientOption mutates a providerClientOptions. NewProvider applies
// the options it is given, in order, after filling in the defaults derived
// from the provider config, so an option overrides those defaults.
type ProviderClientOption func(*providerClientOptions)
79
80type ProviderClient interface {
81 send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
82 stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
83
84 Model() catwalk.Model
85}
86
// baseProvider adapts a concrete ProviderClient C to the Provider interface.
// Its SendMessages and StreamResponse filter out messages with no parts
// before delegating to the client's send/stream. options holds the options
// the client was constructed with; no method shown here reads it.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
91
92func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
93 for _, msg := range messages {
94 // The message has no content
95 if len(msg.Parts) == 0 {
96 continue
97 }
98 cleaned = append(cleaned, msg)
99 }
100 return
101}
102
103func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
104 messages = p.cleanMessages(messages)
105 return p.client.send(ctx, messages, tools)
106}
107
108func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
109 messages = p.cleanMessages(messages)
110 return p.client.stream(ctx, messages, tools)
111}
112
113func (p *baseProvider[C]) Model() catwalk.Model {
114 return p.client.Model()
115}
116
117func WithModel(model config.SelectedModelType) ProviderClientOption {
118 return func(options *providerClientOptions) {
119 options.modelType = model
120 }
121}
122
123func WithDisableCache(disableCache bool) ProviderClientOption {
124 return func(options *providerClientOptions) {
125 options.disableCache = disableCache
126 }
127}
128
129func WithSystemMessage(systemMessage string) ProviderClientOption {
130 return func(options *providerClientOptions) {
131 options.systemMessage = systemMessage
132 }
133}
134
135func WithMaxTokens(maxTokens int64) ProviderClientOption {
136 return func(options *providerClientOptions) {
137 options.maxTokens = maxTokens
138 }
139}
140
141func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
142 resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
143 if err != nil {
144 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
145 }
146
147 // Resolve extra headers
148 resolvedExtraHeaders := make(map[string]string)
149 for key, value := range cfg.ExtraHeaders {
150 resolvedValue, err := config.Get().Resolve(value)
151 if err != nil {
152 return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
153 }
154 resolvedExtraHeaders[key] = resolvedValue
155 }
156
157 clientOptions := providerClientOptions{
158 baseURL: cfg.BaseURL,
159 config: cfg,
160 apiKey: resolvedAPIKey,
161 extraHeaders: resolvedExtraHeaders,
162 extraBody: cfg.ExtraBody,
163 systemPromptPrefix: cfg.SystemPromptPrefix,
164 model: func(tp config.SelectedModelType) catwalk.Model {
165 return *config.Get().GetModelByType(tp)
166 },
167 }
168 for _, o := range opts {
169 o(&clientOptions)
170 }
171 switch cfg.Type {
172 case catwalk.TypeAnthropic:
173 return &baseProvider[AnthropicClient]{
174 options: clientOptions,
175 client: newAnthropicClient(clientOptions, false),
176 }, nil
177 case catwalk.TypeOpenAI:
178 return &baseProvider[OpenAIClient]{
179 options: clientOptions,
180 client: newOpenAIClient(clientOptions),
181 }, nil
182 case catwalk.TypeGemini:
183 return &baseProvider[GeminiClient]{
184 options: clientOptions,
185 client: newGeminiClient(clientOptions),
186 }, nil
187 case catwalk.TypeBedrock:
188 return &baseProvider[BedrockClient]{
189 options: clientOptions,
190 client: newBedrockClient(clientOptions),
191 }, nil
192 case catwalk.TypeAzure:
193 return &baseProvider[AzureClient]{
194 options: clientOptions,
195 client: newAzureClient(clientOptions),
196 }, nil
197 case catwalk.TypeVertexAI:
198 return &baseProvider[VertexAIClient]{
199 options: clientOptions,
200 client: newVertexAIClient(clientOptions),
201 }, nil
202 case catwalk.TypeXAI:
203 clientOptions.baseURL = "https://api.x.ai/v1"
204 return &baseProvider[OpenAIClient]{
205 options: clientOptions,
206 client: newOpenAIClient(clientOptions),
207 }, nil
208 }
209 return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
210}