package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

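// openaiClient implements ProviderClient on top of the official OpenAI Go
// SDK. It also serves OpenAI-compatible endpoints configured through a
// custom base URL.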
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

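// OpenAIClient is the ProviderClient type exposed for OpenAI-backed
// providers.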
type OpenAIClient ProviderClient

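// newOpenAIClient builds an OpenAI-backed ProviderClient from the given
// provider options.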
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

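// createOpenAIClient constructs the underlying SDK client, applying the API
// key, the resolved base URL, and any extra headers from the options.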
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		// If the base URL cannot be resolved, fall back to the SDK default
		// rather than failing client construction.
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	if opts.extraHeaders != nil {
		for key, value := range opts.extraHeaders {
			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
		}
	}

	return openai.NewClient(openaiClientOptions...)
}

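// convertMessages maps internal messages to the SDK's message param unions.
// The configured system message always goes first; user attachments become
// image-URL parts, assistant tool calls are carried over, and each tool
// result becomes its own tool message.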
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

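// convertTools maps internal tool definitions to OpenAI function
// declarations, wrapping each tool's parameters in a JSON Schema object.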
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

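// finishReason translates an OpenAI finish reason string into the internal
// message.FinishReason, defaulting to unknown.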
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

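// preparedParams assembles the chat completion request: model, messages, and
// tools, plus max tokens (provider options override the model config) and,
// for reasoning-capable models, the configured reasoning effort.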
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

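// send performs a single non-streaming chat completion, retrying retryable
// failures with backoff via shouldRetry.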
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, retryErr
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := ""
		if openaiResponse.Choices[0].Message.Content != "" {
			content = openaiResponse.Choices[0].Message.Content
		}

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

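// stream starts a streaming chat completion and returns a channel of
// provider events; the channel is closed after a terminal event. A minimal
// consumption sketch (variable names are illustrative):
//
//	for event := range client.stream(ctx, msgs, tools) {
//		switch event.Type {
//		case EventContentDelta:
//			fmt.Print(event.Content)
//		case EventError:
//			log.Println(event.Error)
//		}
//	}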
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				// Accumulate tool-call deltas by hand; some providers split multiple tool calls across chunks.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect tool use start
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// Delta tool use
							if toolCall.ID == "" {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// Detect new tool use
								if toolCall.ID != currentToolCallID {
									msgToolCalls = append(msgToolCalls, currentToolCall)
									currentToolCallID = toolCall.ID
									currentToolCall = openai.ChatCompletionMessageToolCall{
										ID:   toolCall.ID,
										Type: "function",
										Function: openai.ChatCompletionMessageToolCallFunction{
											Name:      toolCall.Function.Name,
											Arguments: toolCall.Function.Arguments,
										},
									}
								}
							}
						}
					}
					if choice.FinishReason == "tool_calls" {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					// Close so consumers ranging over the channel do not block.
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// An empty finish reason is treated as a successful completion.
					// INFO: OpenRouter sometimes omits it.
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

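// shouldRetry reports whether the failed call should be retried and after
// how many milliseconds. A 401 refreshes the API key and retries
// immediately; 429 and 500 retry with exponential backoff, honoring a
// Retry-After header when present. Other errors are returned as-is.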
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized)
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	// Exponential backoff in milliseconds with a flat 20% jitter.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	// A Retry-After header takes precedence; its value is given in seconds.
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

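// toolCalls extracts the tool calls from the first choice of a completed
// response.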
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

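// usage converts the SDK's token accounting into TokenUsage, counting cached
// prompt tokens as cache reads rather than fresh input.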
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

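// Model returns the provider model selected for this client.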
func (o *openaiClient) Model() provider.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}