1package provider
2
3import (
4 "context"
5 "fmt"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/charmbracelet/crush/internal/fur/provider"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// maxRetries is a package-wide retry limit. It is not referenced in this file;
// presumably it caps how many times a provider client retries a failed
// request — confirm at its call sites.
const maxRetries = 8

// EventType identifies the kind of a ProviderEvent delivered on the channel
// returned by StreamResponse.
type EventType string

// Event kinds carried in ProviderEvent.Type.
const (
	// Plain-text content lifecycle.
	EventContentStart EventType = "content_start"
	EventContentDelta EventType = "content_delta"
	EventContentStop  EventType = "content_stop"

	// Tool-call lifecycle.
	EventToolUseStart EventType = "tool_use_start"
	EventToolUseDelta EventType = "tool_use_delta"
	EventToolUseStop  EventType = "tool_use_stop"

	// Reasoning ("thinking") text fragments.
	EventThinkingDelta EventType = "thinking_delta"

	// Completion, error and warning notifications.
	EventComplete EventType = "complete"
	EventError    EventType = "error"
	EventWarning  EventType = "warning"
)
29
// TokenUsage holds the token counts attached to a ProviderResponse.
type TokenUsage struct {
	// Counts for the request input and the generated output.
	InputTokens, OutputTokens int64
	// Cache-related counts; presumably zero for providers that do not report
	// caching — confirm per client.
	CacheCreationTokens, CacheReadTokens int64
}
36
// ProviderResponse is the complete result returned by Provider.SendMessages.
// A pointer to one is also carried in ProviderEvent.Response (presumably on
// EventComplete — confirm against the client implementations).
type ProviderResponse struct {
	Content      string               // response text
	ToolCalls    []message.ToolCall   // tool calls contained in the response
	Usage        TokenUsage           // token counts for the request
	FinishReason message.FinishReason // reason the generation finished
}
43
44type ProviderEvent struct {
45 Type EventType
46
47 Content string
48 Thinking string
49 Response *ProviderResponse
50 ToolCall *message.ToolCall
51 Error error
52}
53type Provider interface {
54 SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
55
56 StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
57
58 Model() config.Model
59}
60
// providerClientOptions carries the settings handed to every provider client
// that NewProvider constructs. NewProvider seeds it from the provider config
// and then applies each ProviderClientOption in order.
//
// Fields:
//   - baseURL: endpoint, seeded from the config's BaseURL (the xAI branch of
//     NewProvider may change it).
//   - config: the full provider configuration passed to NewProvider.
//   - apiKey: the API key after resolution by config.ResolveAPIKey.
//   - modelType: the model type to use; set by WithModel.
//   - model: maps a model type to its config.Model; NewProvider defaults it
//     to a wrapper around config.GetModel.
//   - disableCache: set by WithDisableCache.
//   - systemMessage: set by WithSystemMessage.
//   - maxTokens: set by WithMaxTokens.
//   - extraHeaders: taken from the config's ExtraHeaders (the map is shared
//     with the config, not copied).
//   - extraParams: not populated by any code visible in this file.
type providerClientOptions struct {
	baseURL       string
	config        config.ProviderConfig
	apiKey        string
	modelType     config.ModelType
	model         func(config.ModelType) config.Model
	disableCache  bool
	systemMessage string
	maxTokens     int64
	extraHeaders  map[string]string
	extraParams   map[string]string
}
73
// ProviderClientOption mutates a providerClientOptions. NewProvider applies
// the options it receives in order, after seeding defaults from the config,
// so later options win.
type ProviderClientOption func(*providerClientOptions)
75
76type ProviderClient interface {
77 send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
78 stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
79
80 Model() config.Model
81}
82
// baseProvider adapts a concrete ProviderClient to the Provider interface.
// Its SendMessages and StreamResponse methods drop messages that have no
// parts before delegating to client.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions // the same options NewProvider used to build client
	client  C                     // backend client that requests are delegated to
}
87
88func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
89 for _, msg := range messages {
90 // The message has no content
91 if len(msg.Parts) == 0 {
92 continue
93 }
94 cleaned = append(cleaned, msg)
95 }
96 return
97}
98
99func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
100 messages = p.cleanMessages(messages)
101 return p.client.send(ctx, messages, tools)
102}
103
104func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
105 messages = p.cleanMessages(messages)
106 return p.client.stream(ctx, messages, tools)
107}
108
// Model returns whatever the wrapped client's Model method reports.
func (p *baseProvider[C]) Model() config.Model {
	return p.client.Model()
}
112
113func WithModel(model config.ModelType) ProviderClientOption {
114 return func(options *providerClientOptions) {
115 options.modelType = model
116 }
117}
118
119func WithDisableCache(disableCache bool) ProviderClientOption {
120 return func(options *providerClientOptions) {
121 options.disableCache = disableCache
122 }
123}
124
125func WithSystemMessage(systemMessage string) ProviderClientOption {
126 return func(options *providerClientOptions) {
127 options.systemMessage = systemMessage
128 }
129}
130
131func WithMaxTokens(maxTokens int64) ProviderClientOption {
132 return func(options *providerClientOptions) {
133 options.maxTokens = maxTokens
134 }
135}
136
137func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
138 resolvedAPIKey, err := config.ResolveAPIKey(cfg.APIKey)
139 if err != nil {
140 return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
141 }
142
143 clientOptions := providerClientOptions{
144 baseURL: cfg.BaseURL,
145 config: cfg,
146 apiKey: resolvedAPIKey,
147 extraHeaders: cfg.ExtraHeaders,
148 model: func(tp config.ModelType) config.Model {
149 return config.GetModel(tp)
150 },
151 }
152 for _, o := range opts {
153 o(&clientOptions)
154 }
155 switch cfg.ProviderType {
156 case provider.TypeAnthropic:
157 return &baseProvider[AnthropicClient]{
158 options: clientOptions,
159 client: newAnthropicClient(clientOptions, false),
160 }, nil
161 case provider.TypeOpenAI:
162 return &baseProvider[OpenAIClient]{
163 options: clientOptions,
164 client: newOpenAIClient(clientOptions),
165 }, nil
166 case provider.TypeGemini:
167 return &baseProvider[GeminiClient]{
168 options: clientOptions,
169 client: newGeminiClient(clientOptions),
170 }, nil
171 case provider.TypeBedrock:
172 return &baseProvider[BedrockClient]{
173 options: clientOptions,
174 client: newBedrockClient(clientOptions),
175 }, nil
176 case provider.TypeAzure:
177 return &baseProvider[AzureClient]{
178 options: clientOptions,
179 client: newAzureClient(clientOptions),
180 }, nil
181 case provider.TypeVertexAI:
182 return &baseProvider[VertexAIClient]{
183 options: clientOptions,
184 client: newVertexAIClient(clientOptions),
185 }, nil
186 case provider.TypeXAI:
187 clientOptions.baseURL = "https://api.x.ai/v1"
188 return &baseProvider[OpenAIClient]{
189 options: clientOptions,
190 client: newOpenAIClient(clientOptions),
191 }, nil
192 case provider.TypeLlama:
193 return &baseProvider[LlamaClient]{
194 options: clientOptions,
195 client: newLlamaClient(clientOptions),
196 }, nil
197 }
198 return nil, fmt.Errorf("provider not supported: %s", cfg.ProviderType)
199}