1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType identifies the kind of a ProviderEvent (it is the type of
// ProviderEvent.Type). The defined values are the Event* constants in this
// file.
type EventType string
14
// maxRetries is a package-level limit of 8. It is not referenced in the
// visible part of this file; presumably it caps retry attempts in the
// provider clients — TODO confirm against its call sites.
const maxRetries = 8
16
// Defined EventType values, as carried in ProviderEvent.Type.
//
// NOTE(review): the names suggest the phases of a streamed response (content,
// tool use, thinking/signature, completion, error, warning). Which events are
// emitted, and in what order, is decided by the client implementations, which
// are not visible here.
const (
	EventContentStart   EventType = "content_start"
	EventToolUseStart   EventType = "tool_use_start"
	EventToolUseDelta   EventType = "tool_use_delta"
	EventToolUseStop    EventType = "tool_use_stop"
	EventContentDelta   EventType = "content_delta"
	EventThinkingDelta  EventType = "thinking_delta"
	EventSignatureDelta EventType = "signature_delta"
	EventContentStop    EventType = "content_stop"
	EventComplete       EventType = "complete"
	EventError          EventType = "error"
	EventWarning        EventType = "warning"
)
30
// TokenUsage groups the token counters attached to a ProviderResponse (see
// ProviderResponse.Usage).
//
// NOTE(review): the field names suggest input/output token counts plus
// cache-creation and cache-read counts; how each provider fills them in is
// not visible in this file.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}
37
// ProviderResponse is the result type returned by Provider.SendMessages and
// referenced from ProviderEvent.Response. It bundles the response text, any
// tool calls, token usage and the finish reason.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
44
// ProviderEvent is the element type of the channel returned by
// Provider.StreamResponse. Type says which kind of event this is.
//
// NOTE(review): which of the remaining fields is populated for a given Type
// is determined by the client implementations (not visible here); presumably
// only the fields relevant to that event kind are set — verify before relying
// on any particular field.
type ProviderEvent struct {
	Type EventType

	Content   string
	Thinking  string
	Signature string
	Response  *ProviderResponse
	ToolCall  *message.ToolCall
	Error     error
}
// Provider is the public interface returned by NewProvider. In this file it
// is implemented by baseProvider, which forwards each call to a
// ProviderClient.
type Provider interface {
	// SendMessages submits messages and tools and returns the complete
	// response, or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse submits messages and tools and returns a receive-only
	// channel of ProviderEvent values.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the provider.Model associated with this provider.
	Model() provider.Model
}
62
// providerClientOptions is the settings struct handed to the concrete client
// constructors by NewProvider.
//
// NewProvider seeds baseURL, config, apiKey, extraHeaders, extraBody and
// model from the provider configuration. modelType, disableCache,
// systemMessage and maxTokens are set through the With* options in this
// file. extraParams is not assigned anywhere in the visible code.
type providerClientOptions struct {
	baseURL       string
	config        config.ProviderConfig
	apiKey        string
	modelType     config.SelectedModelType
	model         func(config.SelectedModelType) provider.Model
	disableCache  bool
	systemMessage string
	maxTokens     int64
	extraHeaders  map[string]string
	extraBody     map[string]any
	extraParams   map[string]string
}
76
// ProviderClientOption is a functional option: a function that mutates the
// providerClientOptions it is given. NewProvider applies the options it
// receives in order.
type ProviderClientOption func(*providerClientOptions)
78
// ProviderClient is the backend contract that baseProvider delegates to.
// Because send and stream are unexported, only types declared in this package
// can implement it.
type ProviderClient interface {
	// send performs a request and returns the complete response, or an error.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a request and returns a receive-only channel of events.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the provider.Model this client is using.
	Model() provider.Model
}
85
// baseProvider adapts a concrete ProviderClient of type C to the Provider
// interface. It keeps the options the client was built with alongside the
// client itself.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
90
91func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
92 for _, msg := range messages {
93 // The message has no content
94 if len(msg.Parts) == 0 {
95 continue
96 }
97 cleaned = append(cleaned, msg)
98 }
99 return
100}
101
102func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
103 messages = p.cleanMessages(messages)
104 return p.client.send(ctx, messages, tools)
105}
106
107func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
108 messages = p.cleanMessages(messages)
109 return p.client.stream(ctx, messages, tools)
110}
111
// Model returns the model reported by the underlying client.
func (p *baseProvider[C]) Model() provider.Model {
	return p.client.Model()
}
115
116func WithModel(model config.SelectedModelType) ProviderClientOption {
117 return func(options *providerClientOptions) {
118 options.modelType = model
119 }
120}
121
122func WithDisableCache(disableCache bool) ProviderClientOption {
123 return func(options *providerClientOptions) {
124 options.disableCache = disableCache
125 }
126}
127
128func WithSystemMessage(systemMessage string) ProviderClientOption {
129 return func(options *providerClientOptions) {
130 options.systemMessage = systemMessage
131 }
132}
133
134func WithMaxTokens(maxTokens int64) ProviderClientOption {
135 return func(options *providerClientOptions) {
136 options.maxTokens = maxTokens
137 }
138}
139
140func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
141 resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
142 if err != nil {
143 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
144 }
145
146 clientOptions := providerClientOptions{
147 baseURL: cfg.BaseURL,
148 config: cfg,
149 apiKey: resolvedAPIKey,
150 extraHeaders: cfg.ExtraHeaders,
151 extraBody: cfg.ExtraBody,
152 model: func(tp config.SelectedModelType) provider.Model {
153 return *config.Get().GetModelByType(tp)
154 },
155 }
156 for _, o := range opts {
157 o(&clientOptions)
158 }
159 switch cfg.Type {
160 case provider.TypeAnthropic:
161 return &baseProvider[AnthropicClient]{
162 options: clientOptions,
163 client: newAnthropicClient(clientOptions, false),
164 }, nil
165 case provider.TypeOpenAI:
166 return &baseProvider[OpenAIClient]{
167 options: clientOptions,
168 client: newOpenAIClient(clientOptions),
169 }, nil
170 case provider.TypeGemini:
171 return &baseProvider[GeminiClient]{
172 options: clientOptions,
173 client: newGeminiClient(clientOptions),
174 }, nil
175 case provider.TypeBedrock:
176 return &baseProvider[BedrockClient]{
177 options: clientOptions,
178 client: newBedrockClient(clientOptions),
179 }, nil
180 case provider.TypeAzure:
181 return &baseProvider[AzureClient]{
182 options: clientOptions,
183 client: newAzureClient(clientOptions),
184 }, nil
185 case provider.TypeVertexAI:
186 return &baseProvider[VertexAIClient]{
187 options: clientOptions,
188 client: newVertexAIClient(clientOptions),
189 }, nil
190 case provider.TypeXAI:
191 clientOptions.baseURL = "https://api.x.ai/v1"
192 return &baseProvider[OpenAIClient]{
193 options: clientOptions,
194 client: newOpenAIClient(clientOptions),
195 }, nil
196 }
197 return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
198}