1package provider
2
3import (
4 "context"
5 "fmt"
6 "os"
7
8 "github.com/charmbracelet/crush/internal/llm/models"
9 "github.com/charmbracelet/crush/internal/llm/tools"
10 "github.com/charmbracelet/crush/internal/message"
11)
12
// EventType identifies the kind of a ProviderEvent. Values of this type are
// carried in ProviderEvent.Type on the channel returned by
// Provider.StreamResponse.
type EventType string

// maxRetries is a package-level retry limit.
// NOTE(review): not referenced in this file; presumably the upper bound on
// retry attempts used by the backend clients elsewhere in the package —
// confirm against its call sites.
const maxRetries = 8

// Event types that may appear in ProviderEvent.Type.
// NOTE(review): which ProviderEvent fields are populated for each event type
// is decided by the backend clients, which are not shown here.
const (
	EventContentStart  EventType = "content_start"
	EventToolUseStart  EventType = "tool_use_start"
	EventToolUseDelta  EventType = "tool_use_delta"
	EventToolUseStop   EventType = "tool_use_stop"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
	EventWarning       EventType = "warning"
)
29
// TokenUsage holds the token counts associated with one provider response.
// It is embedded by value in ProviderResponse.Usage.
// NOTE(review): how each backend maps its own usage report onto these four
// counters (in particular the cache fields) is decided by the backend
// clients — confirm per backend.
type TokenUsage struct {
	InputTokens         int64
	OutputTokens        int64
	CacheCreationTokens int64
	CacheReadTokens     int64
}

// ProviderResponse is the result of a provider call: the text Content, any
// ToolCalls requested by the model, token Usage, and the FinishReason. It is
// returned by Provider.SendMessages and carried in ProviderEvent.Response.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}

// ProviderEvent is one item delivered on the channel returned by
// Provider.StreamResponse. Type says what kind of event it is; the remaining
// fields carry the event's payload.
// NOTE(review): presumably Content/Thinking hold deltas for the *_delta
// types, ToolCall is set for the tool_use_* types, Response for
// EventComplete and Error for EventError — confirm against the backend
// clients, which construct these events.
type ProviderEvent struct {
	Type EventType

	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}
// Provider is the backend-agnostic interface for talking to an LLM.
// Instances are created by NewProvider.
type Provider interface {
	// SendMessages sends messages (and the tools the model may call) and
	// returns the complete response, or an error.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends messages (and the tools the model may call) and
	// returns a channel on which the response is delivered as a sequence of
	// ProviderEvent values.
	// NOTE(review): closing of the channel and error delivery are handled by
	// the backend clients — confirm their contract before relying on it.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model this provider was configured with.
	Model() models.Model
}
60
// providerClientOptions collects the settings applied by ProviderClientOption
// functions. NewProvider builds one, passes it to the backend client
// constructor, and stores it on the resulting baseProvider.
type providerClientOptions struct {
	apiKey        string
	model         models.Model
	maxTokens     int64
	systemMessage string

	// Backend-specific options, consumed only by the matching client
	// constructor. openaiOptions is also used by every OpenAI-compatible
	// backend (Groq, OpenRouter, xAI, local) in NewProvider.
	anthropicOptions []AnthropicOption
	openaiOptions    []OpenAIOption
	geminiOptions    []GeminiOption
	bedrockOptions   []BedrockOption
}

// ProviderClientOption configures a providerClientOptions; pass these to
// NewProvider.
type ProviderClientOption func(*providerClientOptions)

// ProviderClient is the backend-specific implementation that baseProvider
// delegates to. send backs Provider.SendMessages and stream backs
// Provider.StreamResponse.
type ProviderClient interface {
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}

// baseProvider adapts a backend client of type C to the Provider interface,
// keeping the options the client was constructed with.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
84
85func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
86 clientOptions := providerClientOptions{}
87 for _, o := range opts {
88 o(&clientOptions)
89 }
90 switch providerName {
91 case models.ProviderAnthropic:
92 return &baseProvider[AnthropicClient]{
93 options: clientOptions,
94 client: newAnthropicClient(clientOptions),
95 }, nil
96 case models.ProviderOpenAI:
97 return &baseProvider[OpenAIClient]{
98 options: clientOptions,
99 client: newOpenAIClient(clientOptions),
100 }, nil
101 case models.ProviderGemini:
102 return &baseProvider[GeminiClient]{
103 options: clientOptions,
104 client: newGeminiClient(clientOptions),
105 }, nil
106 case models.ProviderBedrock:
107 return &baseProvider[BedrockClient]{
108 options: clientOptions,
109 client: newBedrockClient(clientOptions),
110 }, nil
111 case models.ProviderGROQ:
112 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
113 WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
114 )
115 return &baseProvider[OpenAIClient]{
116 options: clientOptions,
117 client: newOpenAIClient(clientOptions),
118 }, nil
119 case models.ProviderAzure:
120 return &baseProvider[AzureClient]{
121 options: clientOptions,
122 client: newAzureClient(clientOptions),
123 }, nil
124 case models.ProviderVertexAI:
125 return &baseProvider[VertexAIClient]{
126 options: clientOptions,
127 client: newVertexAIClient(clientOptions),
128 }, nil
129 case models.ProviderOpenRouter:
130 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
131 WithOpenAIBaseURL("https://openrouter.ai/api/v1"),
132 WithOpenAIExtraHeaders(map[string]string{
133 "HTTP-Referer": "crush.charm.land",
134 "X-Title": "Crush",
135 }),
136 )
137 return &baseProvider[OpenAIClient]{
138 options: clientOptions,
139 client: newOpenAIClient(clientOptions),
140 }, nil
141 case models.ProviderXAI:
142 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
143 WithOpenAIBaseURL("https://api.x.ai/v1"),
144 )
145 return &baseProvider[OpenAIClient]{
146 options: clientOptions,
147 client: newOpenAIClient(clientOptions),
148 }, nil
149 case models.ProviderLocal:
150 clientOptions.openaiOptions = append(clientOptions.openaiOptions,
151 WithOpenAIBaseURL(os.Getenv("LOCAL_ENDPOINT")),
152 )
153 return &baseProvider[OpenAIClient]{
154 options: clientOptions,
155 client: newOpenAIClient(clientOptions),
156 }, nil
157 case models.ProviderMock:
158 // TODO: implement mock client for test
159 panic("not implemented")
160 }
161 return nil, fmt.Errorf("provider not supported: %s", providerName)
162}
163
164func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
165 for _, msg := range messages {
166 // The message has no content
167 if len(msg.Parts) == 0 {
168 continue
169 }
170 cleaned = append(cleaned, msg)
171 }
172 return
173}
174
175func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
176 messages = p.cleanMessages(messages)
177 return p.client.send(ctx, messages, tools)
178}
179
180func (p *baseProvider[C]) Model() models.Model {
181 return p.options.model
182}
183
184func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
185 messages = p.cleanMessages(messages)
186 return p.client.stream(ctx, messages, tools)
187}
188
189func WithAPIKey(apiKey string) ProviderClientOption {
190 return func(options *providerClientOptions) {
191 options.apiKey = apiKey
192 }
193}
194
195func WithModel(model models.Model) ProviderClientOption {
196 return func(options *providerClientOptions) {
197 options.model = model
198 }
199}
200
201func WithMaxTokens(maxTokens int64) ProviderClientOption {
202 return func(options *providerClientOptions) {
203 options.maxTokens = maxTokens
204 }
205}
206
207func WithSystemMessage(systemMessage string) ProviderClientOption {
208 return func(options *providerClientOptions) {
209 options.systemMessage = systemMessage
210 }
211}
212
213func WithAnthropicOptions(anthropicOptions ...AnthropicOption) ProviderClientOption {
214 return func(options *providerClientOptions) {
215 options.anthropicOptions = anthropicOptions
216 }
217}
218
219func WithOpenAIOptions(openaiOptions ...OpenAIOption) ProviderClientOption {
220 return func(options *providerClientOptions) {
221 options.openaiOptions = openaiOptions
222 }
223}
224
225func WithGeminiOptions(geminiOptions ...GeminiOption) ProviderClientOption {
226 return func(options *providerClientOptions) {
227 options.geminiOptions = geminiOptions
228 }
229}
230
231func WithBedrockOptions(bedrockOptions ...BedrockOption) ProviderClientOption {
232 return func(options *providerClientOptions) {
233 options.bedrockOptions = bedrockOptions
234 }
235}