1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/llm/tools"
15 "github.com/charmbracelet/crush/internal/message"
16 "github.com/openai/openai-go"
17 "github.com/openai/openai-go/option"
18 "github.com/openai/openai-go/shared"
19)
20
21type openaiClient struct {
22 providerOptions providerClientOptions
23 client openai.Client
24}
25
26type OpenAIClient ProviderClient
27
28func newOpenAIClient(opts providerClientOptions) OpenAIClient {
29 return &openaiClient{
30 providerOptions: opts,
31 client: createOpenAIClient(opts),
32 }
33}
34
35func createOpenAIClient(opts providerClientOptions) openai.Client {
36 openaiClientOptions := []option.RequestOption{}
37 if opts.apiKey != "" {
38 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
39 }
40 if opts.baseURL != "" {
41 resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
42 if err == nil {
43 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
44 }
45 }
46
47 for key, value := range opts.extraHeaders {
48 openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
49 }
50
51 for extraKey, extraValue := range opts.extraBody {
52 openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
53 }
54
55 return openai.NewClient(openaiClientOptions...)
56}
57
58func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
59 // Add system message first
60 systemMessage := o.providerOptions.systemMessage
61 if o.providerOptions.systemPromptPrefix != "" {
62 systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
63 }
64 openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
65
66 for _, msg := range messages {
67 switch msg.Role {
68 case message.User:
69 var content []openai.ChatCompletionContentPartUnionParam
70 textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
71 content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
72 for _, binaryContent := range msg.BinaryContent() {
73 imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
74 imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
75
76 content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
77 }
78
79 openaiMessages = append(openaiMessages, openai.UserMessage(content))
80
81 case message.Assistant:
82 assistantMsg := openai.ChatCompletionAssistantMessageParam{
83 Role: "assistant",
84 }
85
86 hasContent := false
87 if msg.Content().String() != "" {
88 hasContent = true
89 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
90 OfString: openai.String(msg.Content().String()),
91 }
92 }
93
94 if len(msg.ToolCalls()) > 0 {
95 hasContent = true
96 assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
97 for i, call := range msg.ToolCalls() {
98 assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
99 ID: call.ID,
100 Type: "function",
101 Function: openai.ChatCompletionMessageToolCallFunctionParam{
102 Name: call.Name,
103 Arguments: call.Input,
104 },
105 }
106 }
107 }
108 if !hasContent {
109 slog.Warn("There is a message without content, investigate, this should not happen")
110 continue
111 }
112
113 openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
114 OfAssistant: &assistantMsg,
115 })
116
117 case message.Tool:
118 for _, result := range msg.ToolResults() {
119 openaiMessages = append(openaiMessages,
120 openai.ToolMessage(result.Content, result.ToolCallID),
121 )
122 }
123 }
124 }
125
126 return
127}
128
129func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
130 openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
131
132 for i, tool := range tools {
133 info := tool.Info()
134 openaiTools[i] = openai.ChatCompletionToolParam{
135 Function: openai.FunctionDefinitionParam{
136 Name: info.Name,
137 Description: openai.String(info.Description),
138 Parameters: openai.FunctionParameters{
139 "type": "object",
140 "properties": info.Parameters,
141 "required": info.Required,
142 },
143 },
144 }
145 }
146
147 return openaiTools
148}
149
150func (o *openaiClient) finishReason(reason string) message.FinishReason {
151 switch reason {
152 case "stop":
153 return message.FinishReasonEndTurn
154 case "length":
155 return message.FinishReasonMaxTokens
156 case "tool_calls":
157 return message.FinishReasonToolUse
158 default:
159 return message.FinishReasonUnknown
160 }
161}
162
163func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
164 model := o.providerOptions.model(o.providerOptions.modelType)
165 cfg := config.Get()
166
167 modelConfig := cfg.Models[config.SelectedModelTypeLarge]
168 if o.providerOptions.modelType == config.SelectedModelTypeSmall {
169 modelConfig = cfg.Models[config.SelectedModelTypeSmall]
170 }
171
172 reasoningEffort := modelConfig.ReasoningEffort
173
174 params := openai.ChatCompletionNewParams{
175 Model: openai.ChatModel(model.ID),
176 Messages: messages,
177 Tools: tools,
178 }
179
180 maxTokens := model.DefaultMaxTokens
181 if modelConfig.MaxTokens > 0 {
182 maxTokens = modelConfig.MaxTokens
183 }
184
185 // Override max tokens if set in provider options
186 if o.providerOptions.maxTokens > 0 {
187 maxTokens = o.providerOptions.maxTokens
188 }
189 if model.CanReason {
190 params.MaxCompletionTokens = openai.Int(maxTokens)
191 switch reasoningEffort {
192 case "low":
193 params.ReasoningEffort = shared.ReasoningEffortLow
194 case "medium":
195 params.ReasoningEffort = shared.ReasoningEffortMedium
196 case "high":
197 params.ReasoningEffort = shared.ReasoningEffortHigh
198 default:
199 params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
200 }
201 } else {
202 params.MaxTokens = openai.Int(maxTokens)
203 }
204
205 return params
206}
207
208func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
209 params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
210 cfg := config.Get()
211 if cfg.Options.Debug {
212 jsonData, _ := json.Marshal(params)
213 slog.Debug("Prepared messages", "messages", string(jsonData))
214 }
215 attempts := 0
216 for {
217 attempts++
218 openaiResponse, err := o.client.Chat.Completions.New(
219 ctx,
220 params,
221 )
222 // If there is an error we are going to see if we can retry the call
223 if err != nil {
224 retry, after, retryErr := o.shouldRetry(attempts, err)
225 if retryErr != nil {
226 return nil, retryErr
227 }
228 if retry {
229 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
230 select {
231 case <-ctx.Done():
232 return nil, ctx.Err()
233 case <-time.After(time.Duration(after) * time.Millisecond):
234 continue
235 }
236 }
237 return nil, retryErr
238 }
239
240 if len(openaiResponse.Choices) == 0 {
241 return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
242 }
243
244 content := ""
245 if openaiResponse.Choices[0].Message.Content != "" {
246 content = openaiResponse.Choices[0].Message.Content
247 }
248
249 toolCalls := o.toolCalls(*openaiResponse)
250 finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))
251
252 if len(toolCalls) > 0 {
253 finishReason = message.FinishReasonToolUse
254 }
255
256 return &ProviderResponse{
257 Content: content,
258 ToolCalls: toolCalls,
259 Usage: o.usage(*openaiResponse),
260 FinishReason: finishReason,
261 }, nil
262 }
263}
264
265func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
266 params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
267 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
268 IncludeUsage: openai.Bool(true),
269 }
270
271 cfg := config.Get()
272 if cfg.Options.Debug {
273 jsonData, _ := json.Marshal(params)
274 slog.Debug("Prepared messages", "messages", string(jsonData))
275 }
276
277 attempts := 0
278 eventChan := make(chan ProviderEvent)
279
280 go func() {
281 for {
282 attempts++
283 openaiStream := o.client.Chat.Completions.NewStreaming(
284 ctx,
285 params,
286 )
287
288 acc := openai.ChatCompletionAccumulator{}
289 currentContent := ""
290 toolCalls := make([]message.ToolCall, 0)
291
292 var currentToolCallID string
293 var currentToolCall openai.ChatCompletionMessageToolCall
294 var msgToolCalls []openai.ChatCompletionMessageToolCall
295 for openaiStream.Next() {
296 chunk := openaiStream.Current()
297 acc.AddChunk(chunk)
298 // This fixes multiple tool calls for some providers
299 for _, choice := range chunk.Choices {
300 if choice.Delta.Content != "" {
301 eventChan <- ProviderEvent{
302 Type: EventContentDelta,
303 Content: choice.Delta.Content,
304 }
305 currentContent += choice.Delta.Content
306 } else if len(choice.Delta.ToolCalls) > 0 {
307 toolCall := choice.Delta.ToolCalls[0]
308 // Detect tool use start
309 if currentToolCallID == "" {
310 if toolCall.ID != "" {
311 currentToolCallID = toolCall.ID
312 currentToolCall = openai.ChatCompletionMessageToolCall{
313 ID: toolCall.ID,
314 Type: "function",
315 Function: openai.ChatCompletionMessageToolCallFunction{
316 Name: toolCall.Function.Name,
317 Arguments: toolCall.Function.Arguments,
318 },
319 }
320 }
321 } else {
322 // Delta tool use
323 if toolCall.ID == "" {
324 currentToolCall.Function.Arguments += toolCall.Function.Arguments
325 } else {
326 // Detect new tool use
327 if toolCall.ID != currentToolCallID {
328 msgToolCalls = append(msgToolCalls, currentToolCall)
329 currentToolCallID = toolCall.ID
330 currentToolCall = openai.ChatCompletionMessageToolCall{
331 ID: toolCall.ID,
332 Type: "function",
333 Function: openai.ChatCompletionMessageToolCallFunction{
334 Name: toolCall.Function.Name,
335 Arguments: toolCall.Function.Arguments,
336 },
337 }
338 }
339 }
340 }
341 }
342 if choice.FinishReason == "tool_calls" {
343 msgToolCalls = append(msgToolCalls, currentToolCall)
344 if len(acc.Choices) > 0 {
345 acc.Choices[0].Message.ToolCalls = msgToolCalls
346 }
347 }
348 }
349 }
350
351 err := openaiStream.Err()
352 if err == nil || errors.Is(err, io.EOF) {
353 if cfg.Options.Debug {
354 jsonData, _ := json.Marshal(acc.ChatCompletion)
355 slog.Debug("Response", "messages", string(jsonData))
356 }
357
358 if len(acc.Choices) == 0 {
359 eventChan <- ProviderEvent{
360 Type: EventError,
361 Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
362 }
363 return
364 }
365
366 resultFinishReason := acc.Choices[0].FinishReason
367 if resultFinishReason == "" {
368 // If the finish reason is empty, we assume it was a successful completion
369 // INFO: this is happening for openrouter for some reason
370 resultFinishReason = "stop"
371 }
372 // Stream completed successfully
373 finishReason := o.finishReason(resultFinishReason)
374 if len(acc.Choices[0].Message.ToolCalls) > 0 {
375 toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
376 }
377 if len(toolCalls) > 0 {
378 finishReason = message.FinishReasonToolUse
379 }
380
381 eventChan <- ProviderEvent{
382 Type: EventComplete,
383 Response: &ProviderResponse{
384 Content: currentContent,
385 ToolCalls: toolCalls,
386 Usage: o.usage(acc.ChatCompletion),
387 FinishReason: finishReason,
388 },
389 }
390 close(eventChan)
391 return
392 }
393
394 // If there is an error we are going to see if we can retry the call
395 retry, after, retryErr := o.shouldRetry(attempts, err)
396 if retryErr != nil {
397 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
398 close(eventChan)
399 return
400 }
401 if retry {
402 slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
403 select {
404 case <-ctx.Done():
405 // context cancelled
406 if ctx.Err() == nil {
407 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
408 }
409 close(eventChan)
410 return
411 case <-time.After(time.Duration(after) * time.Millisecond):
412 continue
413 }
414 }
415 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
416 close(eventChan)
417 return
418 }
419 }()
420
421 return eventChan
422}
423
424func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
425 var apiErr *openai.Error
426 if !errors.As(err, &apiErr) {
427 return false, 0, err
428 }
429
430 if attempts > maxRetries {
431 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
432 }
433
434 // Check for token expiration (401 Unauthorized)
435 if apiErr.StatusCode == 401 {
436 o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
437 if err != nil {
438 return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
439 }
440 o.client = createOpenAIClient(o.providerOptions)
441 return true, 0, nil
442 }
443
444 if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
445 return false, 0, err
446 }
447
448 retryMs := 0
449 retryAfterValues := apiErr.Response.Header.Values("Retry-After")
450
451 backoffMs := 2000 * (1 << (attempts - 1))
452 jitterMs := int(float64(backoffMs) * 0.2)
453 retryMs = backoffMs + jitterMs
454 if len(retryAfterValues) > 0 {
455 if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
456 retryMs = retryMs * 1000
457 }
458 }
459 return true, int64(retryMs), nil
460}
461
462func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
463 var toolCalls []message.ToolCall
464
465 if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
466 for _, call := range completion.Choices[0].Message.ToolCalls {
467 toolCall := message.ToolCall{
468 ID: call.ID,
469 Name: call.Function.Name,
470 Input: call.Function.Arguments,
471 Type: "function",
472 Finished: true,
473 }
474 toolCalls = append(toolCalls, toolCall)
475 }
476 }
477
478 return toolCalls
479}
480
481func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
482 cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
483 inputTokens := completion.Usage.PromptTokens - cachedTokens
484
485 return TokenUsage{
486 InputTokens: inputTokens,
487 OutputTokens: completion.Usage.CompletionTokens,
488 CacheCreationTokens: 0, // OpenAI doesn't provide this directly
489 CacheReadTokens: cachedTokens,
490 }
491}
492
493func (o *openaiClient) Model() catwalk.Model {
494 return o.providerOptions.model(o.providerOptions.modelType)
495}