1package provider
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "time"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12 "github.com/charmbracelet/crush/internal/resolver"
13)
14
// EventType identifies the kind of a ProviderEvent.
type EventType string

// maxRetries is the shared retry limit for provider requests. It is not read
// in this file; presumably the concrete provider implementations use it.
const maxRetries = 8

// Event kinds carried by ProviderEvent.Type.
const (
	// Plain content stream.
	EventContentStart EventType = "content_start"
	EventContentDelta EventType = "content_delta"
	EventContentStop  EventType = "content_stop"

	// Reasoning stream.
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"

	// Tool-call stream.
	EventToolUseStart EventType = "tool_use_start"
	EventToolUseDelta EventType = "tool_use_delta"
	EventToolUseStop  EventType = "tool_use_stop"

	// Terminal and out-of-band events.
	EventComplete EventType = "complete"
	EventError    EventType = "error"
	EventWarning  EventType = "warning"
)
32
// TokenUsage records token counts reported for a single provider response,
// including counts for prompt-cache writes and reads.
type TokenUsage struct {
	InputTokens, OutputTokens            int64
	CacheCreationTokens, CacheReadTokens int64
}
39
40type ProviderResponse struct {
41 Content string
42 ToolCalls []message.ToolCall
43 Usage TokenUsage
44 FinishReason message.FinishReason
45}
46
47type ProviderEvent struct {
48 Type EventType
49
50 Content string
51 Thinking string
52 Signature string
53 Response *ProviderResponse
54 ToolCall *message.ToolCall
55 Error error
56}
57
58type Config struct {
59 // The provider's id.
60 ID string `json:"id,omitempty"`
61 // The provider's name, used for display purposes.
62 Name string `json:"name,omitempty"`
63 // The provider's API endpoint.
64 BaseURL string `json:"base_url,omitempty"`
65 // The provider type, e.g. "openai", "anthropic", etc. if empty it defaults to openai.
66 Type catwalk.Type `json:"type,omitempty"`
67 // The provider's API key.
68 APIKey string `json:"api_key,omitempty"`
69 // Marks the provider as disabled.
70 Disable bool `json:"disable,omitempty"`
71
72 // Custom system prompt prefix.
73 SystemPromptPrefix string `json:"system_prompt_prefix,omitempty"`
74
75 // Extra headers to send with each request to the provider.
76 ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
77 // Extra body
78 ExtraBody map[string]any `json:"extra_body,omitempty"`
79
80 // Used to pass extra parameters to the provider.
81 ExtraParams map[string]string `json:"-"`
82
83 // The provider models
84 Models []catwalk.Model `json:"models,omitempty"`
85}
86
87type Provider interface {
88 Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
89
90 Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
91
92 Model(modelID string) *catwalk.Model
93
94 SetDebug(debug bool)
95}
96
97type baseProvider struct {
98 baseURL string
99 debug bool
100 config Config
101 apiKey string
102 disableCache bool
103 systemMessage string
104 systemPromptPrefix string
105 maxTokens int64
106 think bool
107 reasoningEffort string
108 resolver resolver.Resolver
109 extraHeaders map[string]string
110 extraBody map[string]any
111 extraParams map[string]string
112}
113
114type Option func(*baseProvider)
115
116func WithDisableCache(disableCache bool) Option {
117 return func(options *baseProvider) {
118 options.disableCache = disableCache
119 }
120}
121
122func WithSystemMessage(systemMessage string) Option {
123 return func(options *baseProvider) {
124 options.systemMessage = systemMessage
125 }
126}
127
128func WithMaxTokens(maxTokens int64) Option {
129 return func(options *baseProvider) {
130 options.maxTokens = maxTokens
131 }
132}
133
134func WithThinking(think bool) Option {
135 return func(options *baseProvider) {
136 options.think = think
137 }
138}
139
140func WithReasoningEffort(reasoningEffort string) Option {
141 return func(options *baseProvider) {
142 options.reasoningEffort = reasoningEffort
143 }
144}
145
146func WithDebug(debug bool) Option {
147 return func(options *baseProvider) {
148 options.debug = debug
149 }
150}
151
152func WithResolver(resolver resolver.Resolver) Option {
153 return func(options *baseProvider) {
154 options.resolver = resolver
155 }
156}
157
// newBaseProvider builds the state shared by all concrete providers from cfg.
//
// Values from cfg are copied first, then opts are applied in order, so that
// WithResolver can replace the default resolver before anything is resolved.
// Next the API key, base URL and every extra header value are passed through
// the resolver.
//
// It returns an error if the API key or any extra header fails to resolve.
// A base URL that fails to resolve is treated as empty.
func newBaseProvider(cfg Config, opts ...Option) (*baseProvider, error) {
	provider := &baseProvider{
		baseURL:            cfg.BaseURL,
		config:             cfg,
		apiKey:             cfg.APIKey,
		extraHeaders:       cfg.ExtraHeaders,
		extraBody:          cfg.ExtraBody,
		// Previously never populated, so Config.ExtraParams was silently lost.
		extraParams:        cfg.ExtraParams,
		systemPromptPrefix: cfg.SystemPromptPrefix,
		resolver:           resolver.New(),
	}
	for _, o := range opts {
		o(provider)
	}

	resolvedAPIKey, err := provider.resolver.ResolveValue(cfg.APIKey)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
	}

	resolvedBaseURL, err := provider.resolver.ResolveValue(cfg.BaseURL)
	if err != nil {
		// Best-effort: fall back to an empty base URL.
		// NOTE(review): confirm that the concrete providers treat an empty base
		// URL as "use the vendor default". If so, a misconfigured custom
		// endpoint silently redirects requests (and the API key) to the
		// default endpoint.
		resolvedBaseURL = ""
	}

	// Header values may reference variables as well; resolve each one.
	resolvedExtraHeaders := make(map[string]string, len(cfg.ExtraHeaders))
	for key, value := range cfg.ExtraHeaders {
		resolvedValue, err := provider.resolver.ResolveValue(value)
		if err != nil {
			return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
		}
		resolvedExtraHeaders[key] = resolvedValue
	}

	provider.apiKey = resolvedAPIKey
	provider.baseURL = resolvedBaseURL
	provider.extraHeaders = resolvedExtraHeaders
	return provider, nil
}
196
// NewProvider builds the concrete Provider selected by cfg.Type, using the
// shared state built by newBaseProvider with opts applied.
//
// An empty cfg.Type is treated as OpenAI, as documented on Config.Type. Any
// other unknown type is rejected with an error. Errors from resolving the
// configuration are returned unchanged.
func NewProvider(cfg Config, opts ...Option) (Provider, error) {
	base, err := newBaseProvider(cfg, opts...)
	if err != nil {
		return nil, err
	}
	switch cfg.Type {
	case catwalk.TypeAnthropic:
		return NewAnthropicProvider(base, false), nil
	case catwalk.TypeOpenAI, "":
		// Config documents that an empty type defaults to OpenAI; previously
		// this fell through to "provider not supported".
		return NewOpenAIProvider(base), nil
	case catwalk.TypeGemini:
		return NewGeminiProvider(base), nil
	case catwalk.TypeBedrock:
		return NewBedrockProvider(base), nil
	case catwalk.TypeAzure:
		return NewAzureProvider(base), nil
	case catwalk.TypeVertexAI:
		return NewVertexAIProvider(base), nil
	default:
		return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
	}
}
218
219func (p *baseProvider) cleanMessages(messages []message.Message) (cleaned []message.Message) {
220 for _, msg := range messages {
221 // The message has no content
222 if len(msg.Parts) == 0 {
223 continue
224 }
225 cleaned = append(cleaned, msg)
226 }
227 return
228}
229
230func (o *baseProvider) Model(model string) *catwalk.Model {
231 for _, m := range o.config.Models {
232 if m.ID == model {
233 return &m
234 }
235 }
236 return nil
237}
238
239func (o *baseProvider) SetDebug(debug bool) {
240 o.debug = debug
241}
242
243func (c *Config) TestConnection(resolver resolver.Resolver) error {
244 testURL := ""
245 headers := make(map[string]string)
246 apiKey, _ := resolver.ResolveValue(c.APIKey)
247 switch c.Type {
248 case catwalk.TypeOpenAI:
249 baseURL, _ := resolver.ResolveValue(c.BaseURL)
250 if baseURL == "" {
251 baseURL = "https://api.openai.com/v1"
252 }
253 testURL = baseURL + "/models"
254 headers["Authorization"] = "Bearer " + apiKey
255 case catwalk.TypeAnthropic:
256 baseURL, _ := resolver.ResolveValue(c.BaseURL)
257 if baseURL == "" {
258 baseURL = "https://api.anthropic.com/v1"
259 }
260 testURL = baseURL + "/models"
261 headers["x-api-key"] = apiKey
262 headers["anthropic-version"] = "2023-06-01"
263 }
264 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
265 defer cancel()
266 client := &http.Client{}
267 req, err := http.NewRequestWithContext(ctx, "GET", testURL, nil)
268 if err != nil {
269 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
270 }
271 for k, v := range headers {
272 req.Header.Set(k, v)
273 }
274 for k, v := range c.ExtraHeaders {
275 req.Header.Set(k, v)
276 }
277 b, err := client.Do(req)
278 if err != nil {
279 return fmt.Errorf("failed to create request for provider %s: %w", c.ID, err)
280 }
281 if b.StatusCode != http.StatusOK {
282 return fmt.Errorf("failed to connect to provider %s: %s", c.ID, b.Status)
283 }
284 _ = b.Body.Close()
285 return nil
286}