package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

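// newOpenAIClient builds an OpenAI-backed provider client from the given
// options.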
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

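// createOpenAIClient assembles the underlying openai.Client, applying the
// API key, base URL, and any extra headers from the provider options.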
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
	}

	// Ranging over a nil map is a no-op, so no nil check is needed here.
	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	return openai.NewClient(openaiClientOptions...)
}

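// convertMessages maps internal messages to OpenAI chat messages. The
// configured system message is always prepended; user messages carry text
// and image parts, assistant messages carry text and tool calls, and tool
// results become tool messages keyed by their tool call ID.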
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add the system message first.
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			// Attach any binary content (e.g. images) as image URL parts.
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			// Each tool result becomes its own tool message, matched by call ID.
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return openaiMessages
}

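// convertTools maps internal tool definitions to OpenAI function-calling
// tools, wrapping each tool's parameters in a JSON Schema object.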
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

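// finishReason translates an OpenAI finish reason into the internal
// message.FinishReason; unrecognized values map to FinishReasonUnknown.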
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

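// preparedParams builds the chat completion request for the selected model.
// Max tokens are resolved from the model default, the model config, and the
// provider options, in increasing order of precedence.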
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Provider options take precedence over the model config.
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		// Reasoning models take max_completion_tokens instead of max_tokens.
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

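// send performs a single (non-streaming) chat completion request, retrying
// failed calls according to shouldRetry.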
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: surface the original error (retryErr is nil here).
			return nil, err
		}

		// Guard against providers that return an empty choice list.
		if len(openaiResponse.Choices) == 0 {
			return nil, errors.New("received completion with no choices")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

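// stream performs a streaming chat completion request and returns a channel
// of provider events: content deltas while streaming, followed by a single
// EventComplete (or EventError) before the channel is closed.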
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)

				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				// Guard against providers that stream no choices at all.
				if len(acc.ChatCompletion.Choices) == 0 {
					eventChan <- ProviderEvent{Type: EventError, Error: errors.New("stream completed with no choices")}
					close(eventChan)
					return
				}

				resultFinishReason := acc.ChatCompletion.Choices[0].FinishReason
				if resultFinishReason == "" {
					// An empty finish reason is treated as a successful completion.
					// INFO: this happens with OpenRouter for some reason.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					// Context cancelled: report the cancellation before closing.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: surface the original error (retryErr is nil here).
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

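// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. A 401 refreshes the API key and retries
// immediately; 429 and 500 responses back off exponentially, honoring a
// Retry-After header when present. All other errors are returned as-is.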
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// A 401 likely means the token expired: re-resolve the API key and
	// rebuild the client before retrying immediately.
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	// Default to exponential backoff with a fixed 20% margin added on top.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	// Honor the server-provided Retry-After header (in seconds) when present.
	if retryAfterValues := apiErr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

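// toolCalls extracts the tool calls from the first choice of a completion.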
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

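// usage converts OpenAI token accounting into a TokenUsage, counting cached
// prompt tokens as cache reads and excluding them from input tokens.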
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't report cache creation separately
		CacheReadTokens:     cachedTokens,
	}
}

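// Model returns the provider model currently selected for this client.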
func (o *openaiClient) Model() provider.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}