package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

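// openaiClient implements ProviderClient on top of the OpenAI chat
// completions API. Because the base URL, headers, and request body are all
// configurable, it also serves OpenAI-compatible endpoints.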
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

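// OpenAIClient is the ProviderClient used for OpenAI-backed providers.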
type OpenAIClient ProviderClient

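// newOpenAIClient constructs an OpenAIClient from the given provider options.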
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

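// createOpenAIClient assembles the underlying openai.Client, applying the
// API key, the resolved base URL, and any extra headers or extra JSON body
// fields from the provider options.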
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		// A base URL that fails to resolve is skipped, leaving the SDK default.
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

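// convertMessages translates internal messages into OpenAI chat messages.
// The configured system message always goes first; user messages may carry
// image attachments, assistant messages may carry tool calls, and tool
// results become tool messages keyed by their call ID.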
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add the system message first.
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Skipping assistant message with no content and no tool calls; this should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

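// convertTools maps internal tool definitions onto OpenAI function
// declarations, wrapping each tool's parameters in a JSON Schema object.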
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

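// finishReason maps an OpenAI finish reason onto the internal enum;
// anything unrecognized becomes FinishReasonUnknown.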
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

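// preparedParams builds the request parameters shared by send and stream:
// model ID, messages, tools, the token limit, and, for reasoning models,
// the configured reasoning effort.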
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options.
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		// Reasoning models take max_completion_tokens plus a reasoning effort;
		// an unrecognized effort string is passed through unchanged.
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

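// send issues a single non-streaming completion request, retrying with
// backoff on retryable errors, and converts the first choice into a
// ProviderResponse.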
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

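// stream issues a streaming completion request and returns a channel of
// ProviderEvents: one EventContentDelta per content chunk, then a final
// EventComplete (or EventError) when the stream ends. The channel is
// closed once a terminal event has been sent.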
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				// Stitch tool-call deltas together by ID; some providers split
				// multiple tool calls across chunks.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect the start of a tool call.
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// A delta for the current tool call.
							if toolCall.ID == "" {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// A new ID means a new tool call.
								if toolCall.ID != currentToolCallID {
									msgToolCalls = append(msgToolCalls, currentToolCall)
									currentToolCallID = toolCall.ID
									currentToolCall = openai.ChatCompletionMessageToolCall{
										ID:   toolCall.ID,
										Type: "function",
										Function: openai.ChatCompletionMessageToolCallFunction{
											Name:      toolCall.Function.Name,
											Arguments: toolCall.Function.Arguments,
										},
									}
								}
							}
						}
					}
					if choice.FinishReason == "tool_calls" {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// Some providers (OpenRouter, for one) omit the finish reason;
					// assume a normal stop.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					// The context was cancelled.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

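// shouldRetry reports whether a failed request should be retried and how
// long to wait (in milliseconds) before doing so. A 401 re-resolves the
// API key and rebuilds the client; 429 and 500 honor the Retry-After
// header or fall back to exponential backoff; anything else is fatal.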
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized): re-resolve the API key
	// and rebuild the client before retrying immediately.
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	// Exponential backoff starting at 2s, plus a fixed 20% margin:
	// attempt 1 -> 2400ms, attempt 2 -> 4800ms, attempt 3 -> 9600ms, ...
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		// Retry-After is in seconds; when present it takes precedence.
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

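// toolCalls extracts the tool calls from the first choice of a completion,
// marking each one as finished.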
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

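// usage converts OpenAI token accounting into a TokenUsage, splitting
// cached prompt tokens out of the input count.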
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

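// Model returns the model currently selected for this client's model type.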
func (o *openaiClient) Model() provider.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}