package provider

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"
)

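// openaiClient is the ProviderClient implementation for the OpenAI API and
// for OpenAI-compatible endpoints such as OpenRouter.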
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

// OpenAIClient is the provider-client interface this package exposes for OpenAI.
type OpenAIClient ProviderClient

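// newOpenAIClient constructs an OpenAI provider client from the given options.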
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

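// createOpenAIClient assembles the underlying SDK client, wiring in the API
// key, the resolved base URL, an optional debug-logging HTTP client, and any
// extra headers or JSON body fields from the provider options.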
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	if config.Get().Options.Debug {
		httpClient := log.NewHTTPClient()
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(httpClient))
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

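// convertMessages maps internal messages onto the OpenAI chat-completion
// format. For Anthropic models served via OpenRouter it also attaches
// ephemeral cache_control markers (Anthropic prompt caching) to the system
// prompt and the most recent messages.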
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/")
	// Add system message first
	systemMessage := o.providerOptions.systemMessage
	if o.providerOptions.systemPromptPrefix != "" {
		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}

	system := openai.SystemMessage(systemMessage)
	if isAnthropicModel && !o.providerOptions.disableCache {
		systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
		systemTextBlock.SetExtraFields(
			map[string]any{
				"cache_control": map[string]string{
					"type": "ephemeral",
				},
			},
		)
		var content []openai.ChatCompletionContentPartTextParam
		content = append(content, systemTextBlock)
		system = openai.SystemMessage(content)
	}
	openaiMessages = append(openaiMessages, system)

	for i, msg := range messages {
		// Only the last two messages are cache candidates.
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam

			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			hasBinaryContent := false
			for _, binaryContent := range msg.BinaryContent() {
				hasBinaryContent = true
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			if cache && !o.providerOptions.disableCache && isAnthropicModel {
				// content holds a pointer to textBlock, so this mutation is visible there.
				textBlock.SetExtraFields(map[string]any{
					"cache_control": map[string]string{
						"type": "ephemeral",
					},
				})
			}
			if hasBinaryContent || (isAnthropicModel && !o.providerOptions.disableCache) {
				openaiMessages = append(openaiMessages, openai.UserMessage(content))
			} else {
				openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
			}

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
				if cache && !o.providerOptions.disableCache && isAnthropicModel {
					textBlock.SetExtraFields(map[string]any{
						"cache_control": map[string]string{
							"type": "ephemeral",
						},
					})
				}
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
						{
							OfText: &textBlock,
						},
					},
				}
				if !isAnthropicModel {
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(msg.Content().String()),
					}
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Skipping assistant message with neither content nor tool calls; this should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

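// convertTools maps internal tool definitions to OpenAI "function" tools,
// exposing each tool's parameters as a JSON Schema object. The resulting wire
// shape is roughly (illustrative only, with a hypothetical tool name):
//
//	{
//	  "type": "function",
//	  "function": {
//	    "name": "get_weather",
//	    "description": "...",
//	    "parameters": {"type": "object", "properties": {...}, "required": [...]}
//	  }
//	}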
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

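// finishReason translates an OpenAI finish reason into the internal one;
// anything unrecognized maps to FinishReasonUnknown.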
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

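// preparedParams builds the request parameters shared by send and stream.
// Reasoning-capable models get MaxCompletionTokens and a reasoning effort;
// other models get a plain MaxTokens limit.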
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

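// send issues a blocking chat-completion request. Failed calls are retried
// according to shouldRetry, and any tool calls in the response force the
// finish reason to FinishReasonToolUse.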
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := ""
		if openaiResponse.Choices[0].Message.Content != "" {
			content = openaiResponse.Choices[0].Message.Content
		}

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

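// stream issues a streaming chat-completion request and returns a channel of
// provider events: content deltas, tool-call starts, and finally either an
// EventComplete or an EventError before the channel is closed.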
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			// Kujtim: fixes an issue with Anthropic models on OpenRouter.
			if len(params.Tools) == 0 {
				params.Tools = nil
			}
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			currentToolIndex := 0
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				// Kujtim: works around OpenRouter Qwen sending -1 as the tool call index.
				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
					currentToolIndex++
				}
				acc.AddChunk(chunk)
				// Handle providers that split multiple tool calls across chunks.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect the start of a tool call
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// Argument delta for the current tool call
							if toolCall.ID == "" || toolCall.ID == currentToolCallID {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// A new tool call started
								if toolCall.ID != currentToolCallID {
									msgToolCalls = append(msgToolCalls, currentToolCall)
									currentToolCallID = toolCall.ID
									eventChan <- ProviderEvent{
										Type: EventToolUseStart,
										ToolCall: &message.ToolCall{
											ID:       toolCall.ID,
											Name:     toolCall.Function.Name,
											Finished: false,
										},
									}
									currentToolCall = openai.ChatCompletionMessageToolCall{
										ID:   toolCall.ID,
										Type: "function",
										Function: openai.ChatCompletionMessageToolCallFunction{
											Name:      toolCall.Function.Name,
											Arguments: toolCall.Function.Arguments,
										},
									}
								}
							}
						}
					}
					// Kujtim: some models send a "stop" finish reason even for tool calls.
					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// If the finish reason is empty, assume the completion succeeded.
					// INFO: this happens with OpenRouter for some reason.
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// context cancelled
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

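// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. A 401 re-resolves the API key and retries
// immediately; 429 and 500 responses back off exponentially from 2s with a
// fixed 20% jitter margin (attempt 1 waits 2400ms, attempt 2 waits 4800ms,
// and so on), unless a Retry-After header (in seconds) overrides the delay.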
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false, 0, err
	}
	var apiErr *openai.Error
	retryMs := 0
	retryAfterValues := []string{}
	if errors.As(err, &apiErr) {
		// Check for token expiration (401 Unauthorized)
		if apiErr.StatusCode == 401 {
			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
			if err != nil {
				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
			}
			o.client = createOpenAIClient(o.providerOptions)
			return true, 0, nil
		}

		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
			return false, 0, err
		}

		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
	}

	if apiErr != nil {
		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
		if len(retryAfterValues) > 0 {
			slog.Warn("Retry-After header", "values", retryAfterValues)
		}
	} else {
		slog.Error("OpenAI API error", "error", err.Error(), "attempt", attempts, "max_retries", maxRetries)
	}

	// Exponential backoff starting at 2s, plus a fixed 20% jitter margin.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		// A Retry-After header (in seconds) takes precedence over the backoff.
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

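// toolCalls extracts the tool calls from the first choice of a completion.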
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

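// usage converts the completion's token counts into a TokenUsage. Cached
// prompt tokens are reported separately, so they are subtracted from the
// input count to avoid double-counting.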
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

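// Model returns the currently selected catwalk model for this provider.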
func (o *openaiClient) Model() catwalk.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}