package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"
)

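// openaiClient implements ProviderClient on top of the official openai-go
// SDK. It is also used for OpenAI-compatible endpoints (e.g. OpenRouter),
// which is why some Anthropic-specific cache handling appears below.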
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

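// newOpenAIClient builds an OpenAIClient from the given provider options.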
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

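// createOpenAIClient translates providerClientOptions into openai-go request
// options: API key, resolved base URL, extra headers, and extra JSON body
// fields. As a hypothetical example, an extraHeaders entry of
// {"X-Title": "crush"} becomes option.WithHeader("X-Title", "crush").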
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		// If the base URL cannot be resolved, fall back to the SDK default.
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

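// convertMessages maps internal messages to the OpenAI chat format. For
// Anthropic models served via OpenRouter it also attaches prompt-caching
// hints as extra fields, which serialize roughly as (a sketch, not an exact
// wire dump):
//
//	"cache_control": {"type": "ephemeral"}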
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/")
	// Add the system message first.
	systemMessage := o.providerOptions.systemMessage
	if o.providerOptions.systemPromptPrefix != "" {
		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}

	systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
	if isAnthropicModel && !o.providerOptions.disableCache {
		systemTextBlock.SetExtraFields(
			map[string]any{
				"cache_control": map[string]string{
					"type": "ephemeral",
				},
			},
		)
	}
	var content []openai.ChatCompletionContentPartTextParam
	content = append(content, systemTextBlock)
	system := openai.SystemMessage(content)
	openaiMessages = append(openaiMessages, system)

	for i, msg := range messages {
		// Mark the last two messages as cacheable.
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			// content holds a pointer to textBlock, so setting extra fields
			// here still applies to the block appended above.
			if cache && !o.providerOptions.disableCache && isAnthropicModel {
				textBlock.SetExtraFields(map[string]any{
					"cache_control": map[string]string{
						"type": "ephemeral",
					},
				})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
				if cache && !o.providerOptions.disableCache && isAnthropicModel {
					textBlock.SetExtraFields(map[string]any{
						"cache_control": map[string]string{
							"type": "ephemeral",
						},
					})
				}
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
						{
							OfText: &textBlock,
						},
					},
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				// When tool calls are present, send the content as a plain
				// string; this overrides the content-part form set above.
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: param.NewOpt(msg.Content().String()),
				}
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Assistant message has no content and no tool calls; skipping it")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

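// convertTools maps tool definitions to OpenAI function declarations. Each
// tool's parameters become a JSON-schema object of the shape (values depend
// on the tool):
//
//	{"type": "object", "properties": {...}, "required": [...]}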
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

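// finishReason maps an OpenAI finish reason string to the internal enum.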
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

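// preparedParams assembles the ChatCompletionNewParams shared by send and
// stream. Reasoning models take max_completion_tokens plus a reasoning
// effort, while other models take the legacy max_tokens parameter.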
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options.
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

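// send performs a single (non-streaming) chat completion, retrying with
// backoff when shouldRetry allows it.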
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := ""
		if openaiResponse.Choices[0].Message.Content != "" {
			content = openaiResponse.Choices[0].Message.Content
		}

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

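// stream performs a streaming chat completion, emitting ProviderEvents for
// content deltas, tool-call starts, completion, and errors. Tool-call deltas
// are stitched back together manually because several OpenAI-compatible
// providers deviate from the spec (missing indexes, wrong finish reasons).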
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			// Kujtim: fixes an issue with anthropic models on OpenRouter.
			if len(params.Tools) == 0 {
				params.Tools = nil
			}
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			currentToolIndex := 0
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				// Kujtim: OpenRouter's qwen sends -1 for the tool index, so
				// assign indexes ourselves.
				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
					currentToolIndex++
				}
				acc.AddChunk(chunk)
				// Reassemble tool calls manually; this fixes multiple tool
				// calls for some providers.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect the start of a tool use.
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else if toolCall.ID == "" || toolCall.ID == currentToolCallID {
							// Argument delta for the current tool use.
							currentToolCall.Function.Arguments += toolCall.Function.Arguments
						} else {
							// A new tool use started; flush the current one.
							msgToolCalls = append(msgToolCalls, currentToolCall)
							currentToolCallID = toolCall.ID
							eventChan <- ProviderEvent{
								Type: EventToolUseStart,
								ToolCall: &message.ToolCall{
									ID:       toolCall.ID,
									Name:     toolCall.Function.Name,
									Finished: false,
								},
							}
							currentToolCall = openai.ChatCompletionMessageToolCall{
								ID:   toolCall.ID,
								Type: "function",
								Function: openai.ChatCompletionMessageToolCallFunction{
									Name:      toolCall.Function.Name,
									Arguments: toolCall.Function.Arguments,
								},
							}
						}
					}
					// Kujtim: some models send finish reason "stop" even for
					// tool calls.
					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// An empty finish reason is treated as a successful
					// completion; OpenRouter sends this for some reason.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// Context cancelled.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

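// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. A 401 triggers an API key refresh; 429 and 500
// are retried with exponential backoff (2s doubling per attempt, plus a flat
// 20% jitter), honoring a Retry-After header when present. For example,
// attempt 3 backs off 2000*2^2 = 8000ms, plus 1600ms of jitter.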
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false, 0, err
	}
	var apiErr *openai.Error
	retryMs := 0
	retryAfterValues := []string{}
	if errors.As(err, &apiErr) {
		// Check for token expiration (401 Unauthorized): refresh the API key
		// and rebuild the client before retrying.
		if apiErr.StatusCode == 401 {
			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
			if err != nil {
				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
			}
			o.client = createOpenAIClient(o.providerOptions)
			return true, 0, nil
		}

		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
			return false, 0, err
		}

		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
	}

	if apiErr != nil {
		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
		if len(retryAfterValues) > 0 {
			slog.Warn("Retry-After header", "values", retryAfterValues)
		}
	} else {
		slog.Warn("OpenAI API error", "error", err.Error())
	}

	// Exponential backoff with a flat 20% jitter.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		// Prefer the server's Retry-After value, which is given in seconds.
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

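// toolCalls extracts finished tool calls from the first completion choice.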
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

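// usage converts OpenAI usage counters to TokenUsage. Cached prompt tokens
// are reported separately, so they are subtracted from the input count.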
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly.
		CacheReadTokens:     cachedTokens,
	}
}

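// Model returns the catwalk model currently selected for this client.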
func (o *openaiClient) Model() catwalk.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}