1package provider
2
3import (
4 "context"
5 "fmt"
6 "maps"
7 "os"
8
9 "github.com/charmbracelet/crush/internal/llm/models"
10 "github.com/charmbracelet/crush/internal/llm/tools"
11 "github.com/charmbracelet/crush/internal/message"
12)
13
// EventType identifies the kind of a ProviderEvent delivered on the channel
// returned by StreamResponse.
type EventType string
15
// maxRetries is a package-wide retry limit. It is not referenced in this file;
// presumably the provider clients use it to bound retries of failed requests —
// confirm at the call sites.
const maxRetries = 8
17
// Values for ProviderEvent.Type. The names suggest the lifecycle of a streamed
// response (content, thinking and tool-use deltas bracketed by start/stop
// events, ending in complete or error); the exact emission order is defined by
// each client's stream implementation, which is not in this file.
const (
	EventContentStart  EventType = "content_start"
	EventToolUseStart  EventType = "tool_use_start"
	EventToolUseDelta  EventType = "tool_use_delta"
	EventToolUseStop   EventType = "tool_use_stop"
	EventContentDelta  EventType = "content_delta"
	EventThinkingDelta EventType = "thinking_delta"
	EventContentStop   EventType = "content_stop"
	EventComplete      EventType = "complete"
	EventError         EventType = "error"
	EventWarning       EventType = "warning"
)
30
// TokenUsage holds the token counts reported for a single provider response.
type TokenUsage struct {
	InputTokens  int64
	OutputTokens int64
	// NOTE(review): the two cache counters presumably map to the backend's
	// prompt-cache write/read counts; not every provider may report them —
	// confirm in each client.
	CacheCreationTokens int64
	CacheReadTokens     int64
}
37
// ProviderResponse is the result of a model call: the text content, any tool
// calls the model requested, token usage, and why generation stopped.
// SendMessages returns it directly.
type ProviderResponse struct {
	Content      string
	ToolCalls    []message.ToolCall
	Usage        TokenUsage
	FinishReason message.FinishReason
}
44
// ProviderEvent is one item on the channel returned by StreamResponse.
//
// Type says what happened. Which payload fields are populated depends on Type;
// presumably Content for content deltas, Thinking for thinking deltas,
// ToolCall for tool-use events, Response for EventComplete and Error for
// EventError — TODO confirm against the clients' stream implementations.
type ProviderEvent struct {
	Type EventType

	Content  string
	Thinking string
	Response *ProviderResponse
	ToolCall *message.ToolCall
	Error    error
}
// Provider is the public interface for talking to an LLM backend.
// NewProvider returns implementations of it.
type Provider interface {
	// SendMessages sends the conversation and the available tools and returns
	// the response in a single, non-streaming call.
	SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)

	// StreamResponse sends the conversation and the available tools and
	// returns a channel of incremental ProviderEvents.
	StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent

	// Model returns the model this provider was configured with.
	Model() models.Model
}
61
// providerClientOptions collects the settings applied by ProviderClientOption
// functions. NewProvider hands the final value to the backend client
// constructor and also keeps a copy in baseProvider.options.
type providerClientOptions struct {
	baseURL       string
	apiKey        string
	model         models.Model
	disableCache  bool
	maxTokens     int64
	systemMessage string
	// extraHeaders is merged into by WithExtraHeaders; presumably sent as
	// additional HTTP request headers by the client — confirm per client.
	extraHeaders map[string]string
}
71
// ProviderClientOption mutates a providerClientOptions value. Options are
// applied in the order they are passed to NewProvider, so later options
// overwrite earlier ones for the same setting.
type ProviderClientOption func(*providerClientOptions)
73
// ProviderClient is the backend-specific transport wrapped by baseProvider.
// Because its methods are unexported, only types in this package can
// implement it.
type ProviderClient interface {
	// send performs a non-streaming request.
	send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error)
	// stream performs a streaming request and returns a channel of events.
	stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent
}
78
// baseProvider implements Provider on top of a backend-specific client C.
// It keeps the options it was built with (for Model) and filters out empty
// messages before delegating to the client.
type baseProvider[C ProviderClient] struct {
	options providerClientOptions
	client  C
}
83
84func NewProvider(providerName models.InferenceProvider, opts ...ProviderClientOption) (Provider, error) {
85 clientOptions := providerClientOptions{}
86 for _, o := range opts {
87 o(&clientOptions)
88 }
89 switch providerName {
90 case models.ProviderAnthropic:
91 return &baseProvider[AnthropicClient]{
92 options: clientOptions,
93 client: newAnthropicClient(clientOptions, false),
94 }, nil
95 case models.ProviderOpenAI:
96 return &baseProvider[OpenAIClient]{
97 options: clientOptions,
98 client: newOpenAIClient(clientOptions),
99 }, nil
100 case models.ProviderGemini:
101 return &baseProvider[GeminiClient]{
102 options: clientOptions,
103 client: newGeminiClient(clientOptions),
104 }, nil
105 case models.ProviderBedrock:
106 return &baseProvider[BedrockClient]{
107 options: clientOptions,
108 client: newBedrockClient(clientOptions),
109 }, nil
110 case models.ProviderGROQ:
111 clientOptions.baseURL = "https://api.groq.com/openai/v1"
112 return &baseProvider[OpenAIClient]{
113 options: clientOptions,
114 client: newOpenAIClient(clientOptions),
115 }, nil
116 case models.ProviderAzure:
117 return &baseProvider[AzureClient]{
118 options: clientOptions,
119 client: newAzureClient(clientOptions),
120 }, nil
121 case models.ProviderVertexAI:
122 return &baseProvider[VertexAIClient]{
123 options: clientOptions,
124 client: newVertexAIClient(clientOptions),
125 }, nil
126 case models.ProviderOpenRouter:
127 clientOptions.baseURL = "https://openrouter.ai/api/v1"
128 clientOptions.extraHeaders = map[string]string{
129 "HTTP-Referer": "crush.charm.land",
130 "X-Title": "Crush",
131 }
132 return &baseProvider[OpenAIClient]{
133 options: clientOptions,
134 client: newOpenAIClient(clientOptions),
135 }, nil
136 case models.ProviderXAI:
137 clientOptions.baseURL = "https://api.x.ai/v1"
138 return &baseProvider[OpenAIClient]{
139 options: clientOptions,
140 client: newOpenAIClient(clientOptions),
141 }, nil
142 case models.ProviderLocal:
143 clientOptions.baseURL = os.Getenv("LOCAL_ENDPOINT")
144 return &baseProvider[OpenAIClient]{
145 options: clientOptions,
146 client: newOpenAIClient(clientOptions),
147 }, nil
148 case models.ProviderMock:
149 // TODO: implement mock client for test
150 panic("not implemented")
151 }
152 return nil, fmt.Errorf("provider not supported: %s", providerName)
153}
154
155func (p *baseProvider[C]) cleanMessages(messages []message.Message) (cleaned []message.Message) {
156 for _, msg := range messages {
157 // The message has no content
158 if len(msg.Parts) == 0 {
159 continue
160 }
161 cleaned = append(cleaned, msg)
162 }
163 return
164}
165
166func (p *baseProvider[C]) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
167 messages = p.cleanMessages(messages)
168 return p.client.send(ctx, messages, tools)
169}
170
171func (p *baseProvider[C]) Model() models.Model {
172 return p.options.model
173}
174
175func (p *baseProvider[C]) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
176 messages = p.cleanMessages(messages)
177 return p.client.stream(ctx, messages, tools)
178}
179
180func WithBaseURL(baseURL string) ProviderClientOption {
181 return func(options *providerClientOptions) {
182 options.baseURL = baseURL
183 }
184}
185
186func WithAPIKey(apiKey string) ProviderClientOption {
187 return func(options *providerClientOptions) {
188 options.apiKey = apiKey
189 }
190}
191
192func WithModel(model models.Model) ProviderClientOption {
193 return func(options *providerClientOptions) {
194 options.model = model
195 }
196}
197
198func WithDisableCache(disableCache bool) ProviderClientOption {
199 return func(options *providerClientOptions) {
200 options.disableCache = disableCache
201 }
202}
203
204func WithExtraHeaders(extraHeaders map[string]string) ProviderClientOption {
205 return func(options *providerClientOptions) {
206 if options.extraHeaders == nil {
207 options.extraHeaders = make(map[string]string)
208 }
209 maps.Copy(options.extraHeaders, extraHeaders)
210 }
211}
212
213func WithMaxTokens(maxTokens int64) ProviderClientOption {
214 return func(options *providerClientOptions) {
215 options.maxTokens = maxTokens
216 }
217}
218
219func WithSystemMessage(systemMessage string) ProviderClientOption {
220 return func(options *providerClientOptions) {
221 options.systemMessage = systemMessage
222 }
223}