1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType identifies the kind of a ProviderEvent.
type EventType string

// maxRetries is a retry limit for provider requests.
// NOTE(review): not referenced in this file; presumably consumed by the
// client implementations' retry loops — confirm there.
const maxRetries = 8

// Event types carried in ProviderEvent.Type.
//
// Which of these a given client emits, and in what order, is decided by the
// client implementations (not shown here). The names suggest start/delta/stop
// phases for text content and tool calls, plus terminal complete/error events
// and a non-fatal warning. Treat that reading as an assumption to confirm
// against the clients.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
30
// TokenUsage holds the token counts associated with a provider response
// (see ProviderResponse.Usage).
//
// NOTE(review): the field meanings are taken from their names (prompt input,
// generated output, prompt-cache writes, prompt-cache reads). How each backend
// maps its own usage report onto these fields is up to the client
// implementations — confirm there.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
37
// ProviderResponse is the result returned by Provider.SendMessages. It is
// also the payload type of ProviderEvent.Response.
//
// It carries the response text, any tool calls the model requested, token
// usage, and the reason generation finished.
// NOTE(review): presumably the streaming path attaches a final
// ProviderResponse to its completion event — confirm in the clients.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
44
// ProviderEvent is one item delivered on the channel returned by
// Provider.StreamResponse.
//
// Type says what kind of event this is. The remaining fields are payload;
// which of them are populated presumably depends on Type (e.g. Content for
// content deltas, ToolCall for tool-use events, Response on completion, Error
// for EventError). Confirm against the client implementations.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the public interface to an LLM backend. Instances are built
// by NewProvider.
type Provider interface {
	// SendMessages sends the conversation (and the tools the model may call)
	// and returns the complete response, or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the conversation (and the tools the model may
	// call) and returns a receive-only channel of ProviderEvent values.
	// NOTE(review): presumably the implementation closes the channel when the
	// stream ends — confirm in the client implementations.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model reports the model associated with this provider.
	Model() provider.Model
}
62
// providerClientOptions collects the settings passed to every backend client
// constructor (newAnthropicClient, newOpenAIClient, ...).
//
// NewProvider fills baseURL, config, apiKey, extraHeaders and model from the
// provider config. It then applies the ProviderClientOption values, which set
// modelType, disableCache, systemMessage and maxTokens.
type providerClientOptions struct {
	// baseURL is the API endpoint, initialized from the provider config's
	// BaseURL; NewProvider may override it for some provider types.
	baseURL string
	// config is the full provider configuration the client was built from.
	config config.ProviderConfig
	// apiKey is the API key after resolution via config.Get().Resolve.
	apiKey string
	// modelType selects which configured model to use; set by WithModel.
	modelType config.SelectedModelType
	// model looks up the provider.Model for a given model type.
	model func(config.SelectedModelType) provider.Model
	// disableCache is set by WithDisableCache.
	disableCache bool
	// systemMessage is set by WithSystemMessage.
	systemMessage string
	// maxTokens is set by WithMaxTokens.
	maxTokens int64
	// extraHeaders is copied from the provider config's ExtraHeaders.
	extraHeaders map[string]string
	// extraParams is not populated anywhere in this file.
	// NOTE(review): confirm whether it is set elsewhere or unused.
	extraParams map[string]string
}
75
// ProviderClientOption mutates a providerClientOptions. NewProvider applies
// the options in the order given, after filling in values from the provider
// config, so a later option overrides an earlier one.
type ProviderClientOption func(*providerClientOptions)
77
// ProviderClient is the internal interface each backend client implements.
// baseProvider wraps a ProviderClient to satisfy Provider. It forwards
// SendMessages/StreamResponse to send/stream after dropping messages that
// have no parts.
type ProviderClient interface {
	// send performs a non-streaming request and returns the full response.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a streaming request and returns its event channel.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model reports the model the client is using.
	Model() provider.Model
}
84
// baseProvider adapts a concrete ProviderClient (type parameter C, chosen by
// NewProvider according to the provider type) to the Provider interface.
type baseProvider[C ProviderClient] struct {
	// options is the configuration the client was constructed from.
	// NOTE(review): not read by the baseProvider methods in this file.
	options providerClientOptions
	// client performs the actual backend requests.
	client C
}
89
90func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
91 for _, msg := range messages {
92 // The message has no content
93 if len(msg.Parts) == 0 {
94 continue
95 }
96 cleaned = append(cleaned, msg)
97 }
98 return
99}
100
101func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
102 messages = p.cleanMessages(messages)
103 return p.client.send(ctx, messages, tools)
104}
105
106func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
107 messages = p.cleanMessages(messages)
108 return p.client.stream(ctx, messages, tools)
109}
110
111func (p *baseProvider[C]) Model() provider.Model {
112 return p.client.Model()
113}
114
115func WithModel(model config.SelectedModelType) ProviderClientOption {
116 return func(options *providerClientOptions) {
117 options.modelType = model
118 }
119}
120
121func WithDisableCache(disableCache bool) ProviderClientOption {
122 return func(options *providerClientOptions) {
123 options.disableCache = disableCache
124 }
125}
126
127func WithSystemMessage(systemMessage string) ProviderClientOption {
128 return func(options *providerClientOptions) {
129 options.systemMessage = systemMessage
130 }
131}
132
133func WithMaxTokens(maxTokens int64) ProviderClientOption {
134 return func(options *providerClientOptions) {
135 options.maxTokens = maxTokens
136 }
137}
138
139func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
140 resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
141 if err != nil {
142 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
143 }
144
145 clientOptions := providerClientOptions{
146 baseURL: cfg.BaseURL,
147 config: cfg,
148 apiKey: resolvedAPIKey,
149 extraHeaders: cfg.ExtraHeaders,
150 model: func(tp config.SelectedModelType) provider.Model {
151 return *config.Get().GetModelByType(tp)
152 },
153 }
154 for _, o := range opts {
155 o(&clientOptions)
156 }
157 switch cfg.Type {
158 case provider.TypeAnthropic:
159 return &baseProvider[AnthropicClient]{
160 options: clientOptions,
161 client: newAnthropicClient(clientOptions, false),
162 }, nil
163 case provider.TypeOpenAI:
164 return &baseProvider[OpenAIClient]{
165 options: clientOptions,
166 client: newOpenAIClient(clientOptions),
167 }, nil
168 case provider.TypeGemini:
169 return &baseProvider[GeminiClient]{
170 options: clientOptions,
171 client: newGeminiClient(clientOptions),
172 }, nil
173 case provider.TypeBedrock:
174 return &baseProvider[BedrockClient]{
175 options: clientOptions,
176 client: newBedrockClient(clientOptions),
177 }, nil
178 case provider.TypeAzure:
179 return &baseProvider[AzureClient]{
180 options: clientOptions,
181 client: newAzureClient(clientOptions),
182 }, nil
183 case provider.TypeVertexAI:
184 return &baseProvider[VertexAIClient]{
185 options: clientOptions,
186 client: newVertexAIClient(clientOptions),
187 }, nil
188 case provider.TypeXAI:
189 clientOptions.baseURL = "https://api.x.ai/v1"
190 return &baseProvider[OpenAIClient]{
191 options: clientOptions,
192 client: newOpenAIClient(clientOptions),
193 }, nil
194 }
195 return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
196}