1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/catwalk/pkg/catwalk"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
// EventType identifies the kind of a ProviderEvent (it is the type of
// ProviderEvent.Type).
type EventType string

// maxRetries caps retry attempts.
// NOTE(review): not referenced in this section; presumably consumed by the
// per-backend clients' retry logic — confirm where it is used.
const maxRetries = 8

// Event types carried in ProviderEvent.Type.
//
// Which ProviderEvent fields accompany each type is decided by the client
// implementations that emit them (not visible here). NOTE(review): the names
// suggest the lifecycle of a streamed response — content blocks, tool calls,
// thinking/signature deltas, and terminal complete/error/warning events —
// confirm against the clients before relying on a particular ordering.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
31
// TokenUsage holds token counts for a single provider call; it is the type of
// ProviderResponse.Usage.
//
// NOTE(review): the counts are presumably as reported by the upstream API.
// Whether cache tokens are also included in InputTokens depends on the client
// that fills the struct in — confirm per backend.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
38
// ProviderResponse is the result of a model call. It is returned by
// Provider.SendMessages and is also the type of ProviderEvent.Response.
type ProviderResponse struct {
	// Content is the text content of the response.
	Content string
	// ToolCalls lists the tool invocations contained in the response.
	ToolCalls []message.ToolCall
	// Usage reports token counts for the call.
	Usage TokenUsage
	// FinishReason records why generation stopped.
	FinishReason message.FinishReason
}
45
// ProviderEvent is a single item delivered on the channel returned by
// Provider.StreamResponse. Type says what kind of event it is.
//
// NOTE(review): which of the remaining fields are populated for a given Type
// is defined by the client implementations (not visible here). Presumably
// Content for content deltas, Thinking/Signature for their deltas, ToolCall
// for tool-use events, Response for EventComplete and Error for EventError —
// confirm before relying on it.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the public interface for talking to a model backend.
// Implementations returned by NewProvider drop messages that have no parts
// before forwarding the request to the backend-specific client.
type Provider interface {
	// SendMessages performs a non-streaming request with the given
	// conversation and available tools, returning the complete response.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse starts a streaming request and returns a receive-only
	// channel of events.
	// NOTE(review): closing and error-delivery semantics are defined by the
	// client implementations; presumably failures arrive as EventError events
	// since there is no error return — confirm.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk model this provider is using.
	Model() catwalk.Model
}
63
// providerClientOptions collects everything a backend-specific client is
// built from. NewProvider fills it from the ProviderConfig and then applies
// any ProviderClientOption values on top.
type providerClientOptions struct {
	// baseURL is copied from ProviderConfig.BaseURL.
	baseURL string
	// config is the full provider configuration.
	config config.ProviderConfig
	// apiKey is ProviderConfig.APIKey after passing through config Resolve.
	apiKey string
	// modelType selects which configured model to use; set by WithModel.
	modelType config.SelectedModelType
	// model maps a model type to its catwalk model. NewProvider sets it to a
	// closure that queries the global config each time it is called.
	model func(config.SelectedModelType) catwalk.Model
	// disableCache is set by WithDisableCache.
	// NOTE(review): presumably disables provider-side prompt caching — confirm
	// in the clients.
	disableCache bool
	// systemMessage is set by WithSystemMessage.
	systemMessage string
	// systemPromptPrefix is copied from ProviderConfig.SystemPromptPrefix.
	systemPromptPrefix string
	// maxTokens is set by WithMaxTokens.
	// NOTE(review): 0 presumably means "use the model default" — confirm.
	maxTokens int64
	// extraHeaders holds ProviderConfig.ExtraHeaders with every value resolved.
	extraHeaders map[string]string
	// extraBody is ProviderConfig.ExtraBody, assigned directly (not cloned).
	extraBody map[string]any
	// extraParams is ProviderConfig.ExtraParams, assigned directly (not cloned).
	extraParams map[string]string
}
78
// ProviderClientOption mutates providerClientOptions. NewProvider applies
// options in argument order after filling in the values derived from the
// ProviderConfig, so an option can override those values.
type ProviderClientOption func(*providerClientOptions)
80
// ProviderClient is implemented by each backend-specific client that
// baseProvider wraps. Its send and stream methods are unexported, so the
// interface is effectively internal to this package.
type ProviderClient interface {
	// send performs a non-streaming request.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a streaming request, delivering events on the returned
	// channel.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model the client is configured for.
	Model() catwalk.Model
}
87
// baseProvider adapts a ProviderClient to the public Provider interface. It
// applies shared pre-processing to the message list (cleanMessages) and then
// delegates to the wrapped client.
type baseProvider[C ProviderClient] struct {
	// options is the configuration the client was constructed from.
	options providerClientOptions
	// client is the backend-specific implementation calls are delegated to.
	client C
}
92
93func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
94 for _, msg := range messages {
95 // The message has no content
96 if len(msg.Parts) == 0 {
97 continue
98 }
99 cleaned = append(cleaned, msg)
100 }
101 return cleaned
102}
103
104func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
105 messages = p.cleanMessages(messages)
106 return p.client.send(ctx, messages, tools)
107}
108
109func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
110 messages = p.cleanMessages(messages)
111 return p.client.stream(ctx, messages, tools)
112}
113
114func (p *baseProvider[C]) Model() catwalk.Model {
115 return p.client.Model()
116}
117
118func WithModel(model config.SelectedModelType) ProviderClientOption {
119 return func(options *providerClientOptions) {
120 options.modelType = model
121 }
122}
123
124func WithDisableCache(disableCache bool) ProviderClientOption {
125 return func(options *providerClientOptions) {
126 options.disableCache = disableCache
127 }
128}
129
130func WithSystemMessage(systemMessage string) ProviderClientOption {
131 return func(options *providerClientOptions) {
132 options.systemMessage = systemMessage
133 }
134}
135
136func WithMaxTokens(maxTokens int64) ProviderClientOption {
137 return func(options *providerClientOptions) {
138 options.maxTokens = maxTokens
139 }
140}
141
142func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
143 restore := config.PushPopCrushEnv()
144 defer restore()
145 resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
146 if err != nil {
147 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
148 }
149
150 // Resolve extra headers
151 resolvedExtraHeaders := make(map[string]string)
152 for key, value := range cfg.ExtraHeaders {
153 resolvedValue, err := config.Get().Resolve(value)
154 if err != nil {
155 return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
156 }
157 resolvedExtraHeaders[key] = resolvedValue
158 }
159
160 clientOptions := providerClientOptions{
161 baseURL: cfg.BaseURL,
162 config: cfg,
163 apiKey: resolvedAPIKey,
164 extraHeaders: resolvedExtraHeaders,
165 extraBody: cfg.ExtraBody,
166 extraParams: cfg.ExtraParams,
167 systemPromptPrefix: cfg.SystemPromptPrefix,
168 model: func(tp config.SelectedModelType) catwalk.Model {
169 return *config.Get().GetModelByType(tp)
170 },
171 }
172 for _, o := range opts {
173 o(&clientOptions)
174 }
175 switch cfg.Type {
176 case catwalk.TypeAnthropic:
177 return &baseProvider[AnthropicClient]{
178 options: clientOptions,
179 client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
180 }, nil
181 case catwalk.TypeOpenAI:
182 return &baseProvider[OpenAIClient]{
183 options: clientOptions,
184 client: newOpenAIClient(clientOptions),
185 }, nil
186 case catwalk.TypeGemini:
187 return &baseProvider[GeminiClient]{
188 options: clientOptions,
189 client: newGeminiClient(clientOptions),
190 }, nil
191 case catwalk.TypeBedrock:
192 return &baseProvider[BedrockClient]{
193 options: clientOptions,
194 client: newBedrockClient(clientOptions),
195 }, nil
196 case catwalk.TypeAzure:
197 return &baseProvider[AzureClient]{
198 options: clientOptions,
199 client: newAzureClient(clientOptions),
200 }, nil
201 case catwalk.TypeVertexAI:
202 return &baseProvider[VertexAIClient]{
203 options: clientOptions,
204 client: newVertexAIClient(clientOptions),
205 }, nil
206 }
207 return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
208}