package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

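// openaiClient implements ProviderClient against the OpenAI Chat Completions
// API, including OpenAI-compatible endpoints reached through a custom base URL.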
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

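// newOpenAIClient builds an OpenAI-backed client from the given provider
// options. A minimal in-package usage sketch (assuming opts has been
// populated elsewhere, and that the ProviderClient interface exposes the
// send method implemented below):
//
//	client := newOpenAIClient(opts)
//	resp, err := client.send(ctx, msgs, nil)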
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

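// createOpenAIClient assembles the underlying SDK client, wiring in the
// API key, the resolved base URL, and any extra request headers.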
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	return openai.NewClient(openaiClientOptions...)
}

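// convertMessages maps internal history onto the SDK's message unions,
// prepending the configured system prompt and dropping assistant messages
// that carry neither text nor tool calls.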
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add the system message first.
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Skipping assistant message with neither content nor tool calls; this should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

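// convertTools translates tool metadata into OpenAI function definitions.
// For a hypothetical tool named "ls" whose Info reports a required string
// parameter "path", the Parameters map below serializes to roughly:
//
//	{"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]}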
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

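// finishReason maps the API's finish_reason strings onto internal
// message.FinishReason values, defaulting to unknown.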
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

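// preparedParams assembles the request parameters shared by send and
// stream: model, message history, tool definitions, the token limit
// (MaxCompletionTokens for reasoning models, MaxTokens otherwise), and
// the configured reasoning effort.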
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options.
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

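// send performs a blocking completion request, converting messages and
// tools, and retrying failed calls when shouldRetry allows it.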
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

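// stream issues a streaming completion request and forwards events over
// the returned channel, retrying on retryable errors. A consumption
// sketch (assuming ProviderClient exposes stream and a populated client):
//
//	for event := range client.stream(ctx, msgs, nil) {
//		switch event.Type {
//		case EventContentDelta:
//			fmt.Print(event.Content)
//		case EventError:
//			slog.Error("stream failed", "error", event.Error)
//		}
//	}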
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				// Track tool calls manually: for some providers the
				// accumulator mishandles multiple tool calls.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect the start of a tool use.
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// Delta for the current tool use.
							if toolCall.ID == "" {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// Detect a new tool use.
								if toolCall.ID != currentToolCallID {
									msgToolCalls = append(msgToolCalls, currentToolCall)
									currentToolCallID = toolCall.ID
									currentToolCall = openai.ChatCompletionMessageToolCall{
										ID:   toolCall.ID,
										Type: "function",
										Function: openai.ChatCompletionMessageToolCallFunction{
											Name:      toolCall.Function.Name,
											Arguments: toolCall.Function.Arguments,
										},
									}
								}
							}
						}
					}
					if choice.FinishReason == "tool_calls" {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// If the finish reason is empty, assume a successful completion.
					// INFO: this happens with OpenRouter for some reason.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					// Context cancelled.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

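// shouldRetry reports whether a failed request may be retried and how long
// to wait, in milliseconds. A 401 triggers an API key refresh and an
// immediate retry; 429 and 500 use exponential backoff with 20% jitter
// (2400ms, 4800ms, 9600ms, ...), unless a Retry-After header overrides it.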
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized).
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs
	if retryAfterValues := apiErr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs *= 1000
		}
	}
	return true, int64(retryMs), nil
}

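// toolCalls extracts finished tool calls from the first choice of a
// completion (or an accumulated streaming completion).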
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

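// usage converts the completion's token counts into a TokenUsage, splitting
// cached prompt tokens out of the input count: 1,000 prompt tokens of which
// 600 were cached yields InputTokens: 400, CacheReadTokens: 600.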
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

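// Model returns the provider model this client is currently configured to use.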
func (o *openaiClient) Model() provider.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}