1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/catwalk/pkg/catwalk"
8 "github.com/charmbracelet/crush/internal/config"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType identifies the kind of a ProviderEvent; it is the type of
// ProviderEvent.Type and of the Event* constants declared in this file.
type EventType string
14
// maxRetries is a package-wide retry limit. NOTE(review): it is not referenced
// in this file; presumably the individual provider clients use it as their
// upper bound on retry attempts — confirm at the use sites.
const maxRetries = 8
16
// Values for ProviderEvent.Type emitted on the channel returned by
// StreamResponse.
//
// NOTE(review): from the naming, the *Delta events presumably carry an
// incremental chunk in the matching ProviderEvent field (Content, Thinking,
// Signature, ToolCall), EventComplete presumably carries the final
// Response and EventError the Error — confirm against the client
// implementations.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
30
// TokenUsage holds the token counts attached to a ProviderResponse (see its
// Usage field). The cache counters presumably come from providers that report
// prompt-cache activity and stay zero otherwise — confirm per client.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
37
// ProviderResponse is the result of a model call. It is returned directly by
// Provider.SendMessages and is also carried in ProviderEvent.Response during
// streaming.
type ProviderResponse struct {
	// Content is the text produced by the model.
	Content string
	// ToolCalls lists the tool invocations requested by the model, if any.
	ToolCalls []message.ToolCall
	// Usage reports the token counts for the call.
	Usage TokenUsage
	// FinishReason records why the model stopped generating.
	FinishReason message.FinishReason
}
44
// ProviderEvent is a single item delivered on the channel returned by
// Provider.StreamResponse. Type says what happened; which of the remaining
// fields is populated presumably depends on Type — confirm against the
// client implementations.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the backend-agnostic interface returned by NewProvider for
// talking to a language model.
type Provider interface {
	// SendMessages sends the conversation plus the available tools and
	// returns the complete response in a single call.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the conversation plus the available tools and
	// delivers the reply as a sequence of ProviderEvents on the returned
	// channel. NOTE(review): the channel is presumably closed by the
	// implementation when the stream ends — confirm in the clients.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the catwalk model this provider is using.
	Model() catwalk.Model
}
62
// providerClientOptions collects everything a concrete client needs in order
// to be built. NewProvider fills it from the ProviderConfig and then applies
// any ProviderClientOption values on top.
type providerClientOptions struct {
	// baseURL is copied from ProviderConfig.BaseURL.
	baseURL string
	// config is the full provider configuration this client was built from.
	config config.ProviderConfig
	// apiKey is ProviderConfig.APIKey after passing through the config resolver.
	apiKey string
	// modelType is set by WithModel and selects which model to look up via model.
	modelType config.SelectedModelType
	// model returns the catwalk model configured for a given model type.
	model func(config.SelectedModelType) catwalk.Model
	// disableCache is set by WithDisableCache.
	disableCache bool
	// systemMessage is set by WithSystemMessage.
	systemMessage string
	// systemPromptPrefix is copied from ProviderConfig.SystemPromptPrefix.
	systemPromptPrefix string
	// maxTokens is set by WithMaxTokens.
	maxTokens int64
	// extraHeaders holds ProviderConfig.ExtraHeaders with every value resolved.
	extraHeaders map[string]string
	// extraBody is ProviderConfig.ExtraBody, passed through as-is (the map is shared, not copied).
	extraBody map[string]any
	// extraParams is ProviderConfig.ExtraParams, passed through as-is (the map is shared, not copied).
	extraParams map[string]string
}
77
// ProviderClientOption is a functional option that adjusts the settings
// NewProvider passes to the concrete client. Options are applied in the order
// they are given, after the defaults derived from the ProviderConfig.
type ProviderClientOption func(*providerClientOptions)
79
// ProviderClient is the per-backend implementation wrapped by baseProvider.
// Because send and stream are unexported, only types in this package can
// implement it. baseProvider passes messages through cleanMessages before
// calling send or stream.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	Model() catwalk.Model
}
86
// baseProvider adapts a backend-specific ProviderClient to the public Provider
// interface, filtering out empty messages before delegating to the client.
type baseProvider[C ProviderClient] struct {
	// options are the settings the client was constructed with.
	options providerClientOptions
	// client performs the actual backend calls.
	client C
}
91
92func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
93 for _, msg := range messages {
94 // The message has no content
95 if len(msg.Parts) == 0 {
96 continue
97 }
98 cleaned = append(cleaned, msg)
99 }
100 return
101}
102
103func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
104 messages = p.cleanMessages(messages)
105 return p.client.send(ctx, messages, tools)
106}
107
108func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
109 messages = p.cleanMessages(messages)
110 return p.client.stream(ctx, messages, tools)
111}
112
113func (p *baseProvider[C]) Model() catwalk.Model {
114 return p.client.Model()
115}
116
117func WithModel(model config.SelectedModelType) ProviderClientOption {
118 return func(options *providerClientOptions) {
119 options.modelType = model
120 }
121}
122
123func WithDisableCache(disableCache bool) ProviderClientOption {
124 return func(options *providerClientOptions) {
125 options.disableCache = disableCache
126 }
127}
128
129func WithSystemMessage(systemMessage string) ProviderClientOption {
130 return func(options *providerClientOptions) {
131 options.systemMessage = systemMessage
132 }
133}
134
135func WithMaxTokens(maxTokens int64) ProviderClientOption {
136 return func(options *providerClientOptions) {
137 options.maxTokens = maxTokens
138 }
139}
140
141func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
142 resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
143 if err != nil {
144 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
145 }
146
147 // Resolve extra headers
148 resolvedExtraHeaders := make(map[string]string)
149 for key, value := range cfg.ExtraHeaders {
150 resolvedValue, err := config.Get().Resolve(value)
151 if err != nil {
152 return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
153 }
154 resolvedExtraHeaders[key] = resolvedValue
155 }
156
157 clientOptions := providerClientOptions{
158 baseURL: cfg.BaseURL,
159 config: cfg,
160 apiKey: resolvedAPIKey,
161 extraHeaders: resolvedExtraHeaders,
162 extraBody: cfg.ExtraBody,
163 extraParams: cfg.ExtraParams,
164 systemPromptPrefix: cfg.SystemPromptPrefix,
165 model: func(tp config.SelectedModelType) catwalk.Model {
166 return *config.Get().GetModelByType(tp)
167 },
168 }
169 for _, o := range opts {
170 o(&clientOptions)
171 }
172 switch cfg.Type {
173 case catwalk.TypeAnthropic:
174 return &baseProvider[AnthropicClient]{
175 options: clientOptions,
176 client: newAnthropicClient(clientOptions, false),
177 }, nil
178 case catwalk.TypeOpenAI:
179 return &baseProvider[OpenAIClient]{
180 options: clientOptions,
181 client: newOpenAIClient(clientOptions),
182 }, nil
183 case catwalk.TypeGemini:
184 return &baseProvider[GeminiClient]{
185 options: clientOptions,
186 client: newGeminiClient(clientOptions),
187 }, nil
188 case catwalk.TypeBedrock:
189 return &baseProvider[BedrockClient]{
190 options: clientOptions,
191 client: newBedrockClient(clientOptions),
192 }, nil
193 case catwalk.TypeAzure:
194 return &baseProvider[AzureClient]{
195 options: clientOptions,
196 client: newAzureClient(clientOptions),
197 }, nil
198 case catwalk.TypeVertexAI:
199 return &baseProvider[VertexAIClient]{
200 options: clientOptions,
201 client: newVertexAIClient(clientOptions),
202 }, nil
203 }
204 return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
205}