1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/catwalk/pkg/catwalk"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
// EventType identifies the kind of a ProviderEvent; it is the type of
// ProviderEvent.Type, and ProviderEvents are delivered on the channel
// returned by StreamResponse.
type EventType string
15
// maxRetries is an upper bound on retry attempts.
// NOTE(review): not referenced in this part of the file; presumably used by
// the backend client implementations to bound request retries — confirm.
const maxRetries = 8
17
// Event kinds that may appear in ProviderEvent.Type.
//
// NOTE(review): the following reading is inferred from the names and from
// the ProviderEvent fields, and should be confirmed against the clients
// that emit these events:
//   - the *_start / *_delta / *_stop kinds bracket streamed text or
//     tool-use content;
//   - EventComplete accompanies the final ProviderEvent.Response;
//   - EventError accompanies ProviderEvent.Error.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
31
// TokenUsage holds the token counts attached to a ProviderResponse.
// NOTE(review): exactly how each backend fills these (e.g. whether
// InputTokens already includes cached tokens) is decided by the client
// implementations — confirm there before relying on sums of these fields.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
38
// ProviderResponse is the result of a provider call: the generated text,
// any tool calls, token usage, and the reason generation finished.
// SendMessages returns it directly; streamed calls carry it in
// ProviderEvent.Response.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
45
// ProviderEvent is a single item on the channel returned by StreamResponse.
// Type says which kind of event it is.
// NOTE(review): presumably only the fields relevant to that kind are set
// (e.g. Content for content deltas, Error for EventError) — confirm against
// the emitting clients.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the backend-agnostic interface for talking to an LLM
// provider. NewProvider returns an implementation chosen by the provider
// configuration's Type.
type Provider interface {
	// SendMessages sends the conversation, together with the available
	// tools, and returns the complete response or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the conversation, together with the available
	// tools, and returns a receive-only channel of ProviderEvents.
	// NOTE(review): channel termination/close semantics are defined by the
	// backend clients — confirm there.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk.Model associated with this provider.
	Model() catwalk.Model
}
63
// providerClientOptions carries everything a backend client needs.
// NewProvider builds it from the provider configuration, and
// ProviderClientOption functions then adjust it.
type providerClientOptions struct {
	// cfg is the full application configuration passed to NewProvider.
	cfg *config.Config

	// baseURL is copied from the provider config's BaseURL.
	baseURL string
	// config is the provider configuration entry itself.
	config config.ProviderConfig
	// apiKey is the provider API key after cfg.Resolve has been applied.
	apiKey string
	// modelType selects which configured model to use; set by WithModel.
	modelType config.SelectedModelType
	// model looks up the catwalk.Model configured for a model type.
	model func(config.SelectedModelType) catwalk.Model
	// disableCache is set by WithDisableCache.
	disableCache bool
	// systemMessage is set by WithSystemMessage.
	systemMessage string
	// systemPromptPrefix is copied from the provider config.
	systemPromptPrefix string
	// maxTokens is set by WithMaxTokens.
	maxTokens int64
	// extraHeaders holds the provider's extra headers, with values resolved
	// through cfg.Resolve.
	extraHeaders map[string]string
	// extraBody is copied (by reference) from the provider config.
	extraBody map[string]any
	// extraParams is copied (by reference) from the provider config.
	extraParams map[string]string
}
80
// ProviderClientOption mutates providerClientOptions. NewProvider applies
// options in the order given, after filling in the values derived from the
// provider config, so a later option overrides an earlier one that sets the
// same field.
type ProviderClientOption func(*providerClientOptions)
82
// ProviderClient is the backend-specific half of a provider. Each supported
// backend (Anthropic, OpenAI, Gemini, Bedrock, Azure, VertexAI) supplies a
// client that satisfies it. baseProvider wraps a client to expose the
// public Provider interface.
type ProviderClient interface {
	// send performs a non-streaming request.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a streaming request and returns its event channel.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk.Model this client targets.
	Model() catwalk.Model
}
89
// baseProvider adapts a backend-specific ProviderClient to the Provider
// interface. It drops messages that have no parts before delegating to the
// client.
type baseProvider[C ProviderClient] struct {
	// options are the options the client was constructed with.
	options providerClientOptions
	// client is the backend implementation that performs the requests.
	client C
}
94
95func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
96 for _, msg := range messages {
97 // The message has no content
98 if len(msg.Parts) == 0 {
99 continue
100 }
101 cleaned = append(cleaned, msg)
102 }
103 return cleaned
104}
105
106func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
107 messages = p.cleanMessages(messages)
108 return p.client.send(ctx, messages, tools)
109}
110
111func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
112 messages = p.cleanMessages(messages)
113 return p.client.stream(ctx, messages, tools)
114}
115
// Model reports the catwalk.Model of the wrapped backend client.
// The call goes straight to the client field (not a copy), so a client
// type with pointer-receiver methods is called on the stored value.
func (p *baseProvider[C]) Model() catwalk.Model {
	return p.client.Model()
}
119
120func WithModel(model config.SelectedModelType) ProviderClientOption {
121 return func(options *providerClientOptions) {
122 options.modelType = model
123 }
124}
125
126func WithDisableCache(disableCache bool) ProviderClientOption {
127 return func(options *providerClientOptions) {
128 options.disableCache = disableCache
129 }
130}
131
132func WithSystemMessage(systemMessage string) ProviderClientOption {
133 return func(options *providerClientOptions) {
134 options.systemMessage = systemMessage
135 }
136}
137
138func WithMaxTokens(maxTokens int64) ProviderClientOption {
139 return func(options *providerClientOptions) {
140 options.maxTokens = maxTokens
141 }
142}
143
144func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
145 restore := config.PushPopCrushEnv()
146 defer restore()
147 resolvedAPIKey, err := cfg.Resolve(pcfg.APIKey)
148 if err != nil {
149 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
150 }
151
152 // Resolve extra headers
153 resolvedExtraHeaders := make(map[string]string)
154 for key, value := range pcfg.ExtraHeaders {
155 resolvedValue, err := cfg.Resolve(value)
156 if err != nil {
157 return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
158 }
159 resolvedExtraHeaders[key] = resolvedValue
160 }
161
162 clientOptions := providerClientOptions{
163 cfg: cfg,
164 baseURL: pcfg.BaseURL,
165 config: pcfg,
166 apiKey: resolvedAPIKey,
167 extraHeaders: resolvedExtraHeaders,
168 extraBody: pcfg.ExtraBody,
169 extraParams: pcfg.ExtraParams,
170 systemPromptPrefix: pcfg.SystemPromptPrefix,
171 model: func(tp config.SelectedModelType) catwalk.Model {
172 return *cfg.GetModelByType(tp)
173 },
174 }
175 for _, o := range opts {
176 o(&clientOptions)
177 }
178 switch pcfg.Type {
179 case catwalk.TypeAnthropic:
180 return &baseProvider[AnthropicClient]{
181 options: clientOptions,
182 client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
183 }, nil
184 case catwalk.TypeOpenAI:
185 return &baseProvider[OpenAIClient]{
186 options: clientOptions,
187 client: newOpenAIClient(clientOptions),
188 }, nil
189 case catwalk.TypeGemini:
190 return &baseProvider[GeminiClient]{
191 options: clientOptions,
192 client: newGeminiClient(clientOptions),
193 }, nil
194 case catwalk.TypeBedrock:
195 return &baseProvider[BedrockClient]{
196 options: clientOptions,
197 client: newBedrockClient(clientOptions),
198 }, nil
199 case catwalk.TypeAzure:
200 return &baseProvider[AzureClient]{
201 options: clientOptions,
202 client: newAzureClient(clientOptions),
203 }, nil
204 case catwalk.TypeVertexAI:
205 return &baseProvider[VertexAIClient]{
206 options: clientOptions,
207 client: newVertexAIClient(clientOptions),
208 }, nil
209 }
210 return nil, fmt.Errorf("provider not supported: %s", pcfg.Type)
211}