1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/catwalk/pkg/catwalk"
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType labels a ProviderEvent (see ProviderEvent.Type) received from the
// channel returned by StreamResponse.
type EventType string

// Event kinds a streaming provider can emit. The string values are the
// identifiers stored in ProviderEvent.Type.
const (
	// Content events.
	EventContentStart EventType = "content_start"
	EventContentDelta EventType = "content_delta"
	EventContentStop  EventType = "content_stop"

	// Tool-use events.
	EventToolUseStart EventType = "tool_use_start"
	EventToolUseDelta EventType = "tool_use_delta"
	EventToolUseStop  EventType = "tool_use_stop"

	// Thinking/signature deltas; presumably paired with ProviderEvent.Thinking
	// and ProviderEvent.Signature respectively — confirm in the clients.
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"

	// Completion, error and warning events.
	EventComplete EventType = "complete"
	EventError    EventType = "error"
	EventWarning  EventType = "warning"

	// Retry notifications. NOTE(review): the difference between these two is
	// not visible in this file — confirm at the emitting call sites.
	EventRetry    EventType = "retry"
	EventRetrying EventType = "retrying"
)

// maxRetries is not referenced in this file; presumably it bounds the retry
// attempts made by the provider clients — confirm at the call sites.
const maxRetries = 8
32
// TokenUsage holds the token counters attached to a ProviderResponse (see
// ProviderResponse.Usage). How each provider client fills these in is not
// visible in this file.
type TokenUsage struct {
	// InputTokens: presumably tokens consumed by the request/prompt — confirm per client.
	InputTokens int64
	// OutputTokens: presumably tokens generated in the response — confirm per client.
	OutputTokens int64
	// CacheCreationTokens: presumably tokens written to a provider-side prompt cache — confirm.
	CacheCreationTokens int64
	// CacheReadTokens: presumably tokens served from a provider-side prompt cache — confirm.
	CacheReadTokens int64
}
39
// ProviderResponse is the result of a provider request. SendMessages returns
// one directly; during streaming a ProviderResponse can be carried in
// ProviderEvent.Response.
type ProviderResponse struct {
	// Content is the response text (presumably the model's textual answer).
	Content string
	// ToolCalls lists the tool calls contained in the response, if any.
	ToolCalls []message.ToolCall
	// Usage reports the token counters for this response.
	Usage TokenUsage
	// FinishReason is the finish reason reported by the client.
	FinishReason message.FinishReason
}
46
// ProviderEvent is a single item received from the channel returned by
// StreamResponse. Type says what kind of event it is; which of the remaining
// fields are populated for each Type is decided by the provider clients and
// is not visible in this file.
type ProviderEvent struct {
	Type EventType

	// Content: presumably the text chunk for content events.
	Content string
	// Thinking: presumably the text chunk for EventThinkingDelta.
	Thinking string
	// Signature: presumably the chunk for EventSignatureDelta.
	Signature string
	// Response: presumably set on EventComplete with the final response.
	Response *ProviderResponse
	// ToolCall: presumably set on the tool-use events.
	ToolCall *message.ToolCall
	// Retry: NOTE(review) presumably the retry attempt number for the retry
	// events — confirm at the emitting call sites.
	Retry int64
	// Error: presumably set on EventError (and possibly EventWarning).
	Error error
}
58type Provider interface {
59 SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
60
61 StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
62
63 Model() catwalk.Model
64}
65
// providerClientOptions collects the settings a ProviderClient is built with.
// NewProvider fills it from the ProviderConfig and then applies any
// ProviderClientOption values on top.
type providerClientOptions struct {
	// baseURL is copied from ProviderConfig.BaseURL by NewProvider.
	baseURL string
	// config is the ProviderConfig the provider was created from.
	config config.ProviderConfig
	// apiKey is ProviderConfig.APIKey after resolution through the global config.
	apiKey string
	// modelType is set via WithModel; NewProvider does not set a default.
	modelType config.SelectedModelType
	// model maps a model type to a catwalk.Model; NewProvider installs a
	// lookup against the global config that runs on every call.
	model func(config.SelectedModelType) catwalk.Model
	// disableCache is set via WithDisableCache.
	disableCache bool
	// systemMessage is set via WithSystemMessage.
	systemMessage string
	// systemPromptPrefix is copied from ProviderConfig.SystemPromptPrefix.
	systemPromptPrefix string
	// maxTokens is set via WithMaxTokens; it stays zero if that option is not used.
	maxTokens int64
	// extraHeaders holds ProviderConfig.ExtraHeaders with every value resolved.
	extraHeaders map[string]string
	// extraBody is ProviderConfig.ExtraBody as-is (not resolved, same map).
	extraBody map[string]any
	// extraParams is ProviderConfig.ExtraParams as-is (not resolved, same map).
	extraParams map[string]string
}
80
// ProviderClientOption customizes providerClientOptions. NewProvider applies
// the options in the order given, after filling the defaults derived from the
// ProviderConfig, so a later option overwrites whatever an earlier one (or a
// default) set for the same field.
type ProviderClientOption func(*providerClientOptions)
82
83type ProviderClient interface {
84 send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
85 stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
86
87 Model() catwalk.Model
88}
89
// baseProvider adapts a ProviderClient to the Provider interface: its methods
// drop messages that have no parts and then delegate to client.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions // settings client was built from (NewProvider passes the same value to both); not read by the methods shown here
	client  C                     // backend implementation the Provider methods delegate to
}
94
95func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
96 for _, msg := range messages {
97 // The message has no content
98 if len(msg.Parts) == 0 {
99 continue
100 }
101 cleaned = append(cleaned, msg)
102 }
103 return
104}
105
// SendMessages filters out messages with no parts and passes the remainder,
// along with the tool list, to the wrapped client's send method.
func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	nonEmpty := p.cleanMessages(messages)
	return p.client.send(ctx, nonEmpty, tools)
}
110
// StreamResponse filters out messages with no parts and returns the event
// channel produced by the wrapped client's stream method.
func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	filtered := p.cleanMessages(messages)
	return p.client.stream(ctx, filtered, tools)
}
115
116func (p *baseProvider[C]) Model() catwalk.Model {
117 return p.client.Model()
118}
119
120func WithModel(model config.SelectedModelType) ProviderClientOption {
121 return func(options *providerClientOptions) {
122 options.modelType = model
123 }
124}
125
126func WithDisableCache(disableCache bool) ProviderClientOption {
127 return func(options *providerClientOptions) {
128 options.disableCache = disableCache
129 }
130}
131
132func WithSystemMessage(systemMessage string) ProviderClientOption {
133 return func(options *providerClientOptions) {
134 options.systemMessage = systemMessage
135 }
136}
137
138func WithMaxTokens(maxTokens int64) ProviderClientOption {
139 return func(options *providerClientOptions) {
140 options.maxTokens = maxTokens
141 }
142}
143
144func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
145 resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
146 if err != nil {
147 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
148 }
149
150 // Resolve extra headers
151 resolvedExtraHeaders := make(map[string]string)
152 for key, value := range cfg.ExtraHeaders {
153 resolvedValue, err := config.Get().Resolve(value)
154 if err != nil {
155 return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
156 }
157 resolvedExtraHeaders[key] = resolvedValue
158 }
159
160 clientOptions := providerClientOptions{
161 baseURL: cfg.BaseURL,
162 config: cfg,
163 apiKey: resolvedAPIKey,
164 extraHeaders: resolvedExtraHeaders,
165 extraBody: cfg.ExtraBody,
166 extraParams: cfg.ExtraParams,
167 systemPromptPrefix: cfg.SystemPromptPrefix,
168 model: func(tp config.SelectedModelType) catwalk.Model {
169 return *config.Get().GetModelByType(tp)
170 },
171 }
172 for _, o := range opts {
173 o(&clientOptions)
174 }
175 switch cfg.Type {
176 case catwalk.TypeAnthropic:
177 return &baseProvider[AnthropicClient]{
178 options: clientOptions,
179 client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal),
180 }, nil
181 case catwalk.TypeOpenAI:
182 return &baseProvider[OpenAIClient]{
183 options: clientOptions,
184 client: newOpenAIClient(clientOptions),
185 }, nil
186 case catwalk.TypeGemini:
187 return &baseProvider[GeminiClient]{
188 options: clientOptions,
189 client: newGeminiClient(clientOptions),
190 }, nil
191 case catwalk.TypeBedrock:
192 return &baseProvider[BedrockClient]{
193 options: clientOptions,
194 client: newBedrockClient(clientOptions),
195 }, nil
196 case catwalk.TypeAzure:
197 return &baseProvider[AzureClient]{
198 options: clientOptions,
199 client: newAzureClient(clientOptions),
200 }, nil
201 case catwalk.TypeVertexAI:
202 return &baseProvider[VertexAIClient]{
203 options: clientOptions,
204 client: newVertexAIClient(clientOptions),
205 }, nil
206 }
207 return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
208}