1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/catwalk/pkg/catwalk"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
// EventType identifies the kind of a ProviderEvent; it is the type of
// ProviderEvent.Type.
type EventType string

// maxRetries is, judging by its name, the retry limit for provider requests.
// It is not referenced in this file; presumably the concrete clients use it —
// confirm before changing.
const maxRetries = 3

// Event types carried in ProviderEvent.Type on the channel returned by
// StreamResponse. Which ProviderEvent payload fields are populated for each
// type is decided by the concrete ProviderClient implementations, not by this
// file.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
31
// TokenUsage holds the token counts attached to a ProviderResponse (its Usage
// field). The Cache* counters presumably report prompt-cache writes and reads
// for providers that support caching — confirm against the clients that fill
// them in.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
38
// ProviderResponse is the result of a request to a provider. SendMessages
// returns it directly, and streaming events can carry one in
// ProviderEvent.Response.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
45
// ProviderEvent is a single item sent on the channel returned by
// StreamResponse. Type tells the receiver what kind of event it is. The
// remaining fields are optional payloads; which of them are set for a given
// Type is decided by the concrete client implementations, not in this file.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the interface callers use to talk to a model provider.
// NewProvider returns an implementation selected by the provider type.
type Provider interface {
	// SendMessages sends the conversation and the available tools and returns
	// the provider's response, or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the conversation and the available tools and
	// returns a channel of ProviderEvents describing the response as it
	// arrives.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk model this provider is using.
	Model() catwalk.Model
}
63
// providerClientOptions is the configuration NewProvider hands to each
// concrete client constructor.
//
// Where the fields are set in this file:
//   - cfg, config, baseURL, systemPromptPrefix, extraBody, extraParams: taken
//     from NewProvider's arguments as-is. baseURL, extraBody and extraParams
//     are not passed through the resolver.
//   - model: set by NewProvider to look up a model by type in cfg.
//   - resolver: set by WithResolver; NewProvider defaults it to
//     config.OsShellResolver when unset.
//   - apiKey, extraHeaders: pcfg.APIKey and pcfg.ExtraHeaders after resolution
//     through resolver.
//   - modelType, disableCache, systemMessage, maxTokens: set by WithModel,
//     WithDisableCache, WithSystemMessage and WithMaxTokens respectively.
type providerClientOptions struct {
	cfg *config.Config

	resolver config.VariableResolver

	baseURL            string
	config             config.ProviderConfig
	apiKey             string
	modelType          config.SelectedModelType
	model              func(config.SelectedModelType) catwalk.Model
	disableCache       bool
	systemMessage      string
	systemPromptPrefix string
	maxTokens          int64
	extraHeaders       map[string]string
	extraBody          map[string]any
	extraParams        map[string]string
}
82
// ProviderClientOption mutates a providerClientOptions value. The With*
// functions return values of this type; NewProvider applies them in the order
// given.
type ProviderClientOption func(*providerClientOptions)
84
// ProviderClient is implemented by each backend-specific client wrapped by
// baseProvider. baseProvider.SendMessages and baseProvider.StreamResponse
// forward to send and stream after dropping messages that have no parts.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	Model() catwalk.Model
}
91
// baseProvider implements Provider on top of a concrete ProviderClient C. It
// keeps the options the client was constructed from alongside the client.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
96
97func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
98 for _, msg := range messages {
99 // The message has no content
100 if len(msg.Parts) == 0 {
101 continue
102 }
103 cleaned = append(cleaned, msg)
104 }
105 return cleaned
106}
107
108func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
109 messages = p.cleanMessages(messages)
110 return p.client.send(ctx, messages, tools)
111}
112
113func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
114 messages = p.cleanMessages(messages)
115 return p.client.stream(ctx, messages, tools)
116}
117
// Model returns the model reported by the wrapped client.
func (p *baseProvider[C]) Model() catwalk.Model {
	return p.client.Model()
}
121
122func WithModel(model config.SelectedModelType) ProviderClientOption {
123 return func(options *providerClientOptions) {
124 options.modelType = model
125 }
126}
127
128func WithDisableCache(disableCache bool) ProviderClientOption {
129 return func(options *providerClientOptions) {
130 options.disableCache = disableCache
131 }
132}
133
134func WithSystemMessage(systemMessage string) ProviderClientOption {
135 return func(options *providerClientOptions) {
136 options.systemMessage = systemMessage
137 }
138}
139
140func WithMaxTokens(maxTokens int64) ProviderClientOption {
141 return func(options *providerClientOptions) {
142 options.maxTokens = maxTokens
143 }
144}
145
146func WithResolver(resolver config.VariableResolver) ProviderClientOption {
147 return func(options *providerClientOptions) {
148 options.resolver = resolver
149 }
150}
151
152func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
153 clientOptions := providerClientOptions{
154 cfg: cfg,
155 baseURL: pcfg.BaseURL,
156 config: pcfg,
157 extraBody: pcfg.ExtraBody,
158 extraParams: pcfg.ExtraParams,
159 systemPromptPrefix: pcfg.SystemPromptPrefix,
160 model: func(tp config.SelectedModelType) catwalk.Model {
161 return *cfg.GetModelByType(tp)
162 },
163 }
164 for _, o := range opts {
165 o(&clientOptions)
166 }
167 if clientOptions.resolver == nil {
168 clientOptions.resolver = config.OsShellResolver
169 }
170
171 restore := config.PushPopCrushEnv()
172 defer restore()
173 resolvedAPIKey, err := clientOptions.resolver.ResolveValue(pcfg.APIKey)
174 if err != nil {
175 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
176 }
177
178 // Resolve extra headers
179 resolvedExtraHeaders := make(map[string]string)
180 for key, value := range pcfg.ExtraHeaders {
181 resolvedValue, err := clientOptions.resolver.ResolveValue(value)
182 if err != nil {
183 return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
184 }
185 resolvedExtraHeaders[key] = resolvedValue
186 }
187
188 clientOptions.apiKey = resolvedAPIKey
189 clientOptions.extraHeaders = resolvedExtraHeaders
190
191 switch pcfg.Type {
192 case catwalk.TypeAnthropic:
193 return &baseProvider[AnthropicClient]{
194 options: clientOptions,
195 client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
196 }, nil
197 case catwalk.TypeOpenAI:
198 return &baseProvider[OpenAIClient]{
199 options: clientOptions,
200 client: newOpenAIClient(clientOptions),
201 }, nil
202 case catwalk.TypeGemini:
203 return &baseProvider[GeminiClient]{
204 options: clientOptions,
205 client: newGeminiClient(clientOptions),
206 }, nil
207 case catwalk.TypeBedrock:
208 return &baseProvider[BedrockClient]{
209 options: clientOptions,
210 client: newBedrockClient(clientOptions),
211 }, nil
212 case catwalk.TypeAzure:
213 return &baseProvider[AzureClient]{
214 options: clientOptions,
215 client: newAzureClient(clientOptions),
216 }, nil
217 case catwalk.TypeVertexAI:
218 return &baseProvider[VertexAIClient]{
219 options: clientOptions,
220 client: newVertexAIClient(clientOptions),
221 }, nil
222 }
223 return nil, fmt.Errorf("provider not supported: %s", pcfg.Type)
224}