package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

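// openaiClient implements ProviderClient on top of the OpenAI chat
// completions API. With a custom base URL it also serves OpenAI-compatible
// providers such as OpenRouter.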
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

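// OpenAIClient is a ProviderClient backed by the OpenAI API.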
type OpenAIClient ProviderClient

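// newOpenAIClient builds an OpenAIClient from the given provider options.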
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

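// createOpenAIClient constructs the underlying openai.Client, applying the
// API key, the resolved base URL, and any extra headers or JSON body fields
// from the options.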
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		// If the base URL cannot be resolved, fall back to the default endpoint.
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

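// convertMessages translates internal messages into OpenAI chat messages.
// The system prompt is sent first. For Anthropic models routed through
// OpenRouter, the system prompt and the most recent messages are tagged with
// ephemeral cache_control blocks unless caching is disabled.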
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/")
	// Add system message first
	systemMessage := o.providerOptions.systemMessage
	if o.providerOptions.systemPromptPrefix != "" {
		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}

	systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
	if isAnthropicModel && !o.providerOptions.disableCache {
		systemTextBlock.SetExtraFields(
			map[string]any{
				"cache_control": map[string]string{
					"type": "ephemeral",
				},
			},
		)
	}
	var content []openai.ChatCompletionContentPartTextParam
	content = append(content, systemTextBlock)
	system := openai.SystemMessage(content)
	openaiMessages = append(openaiMessages, system)

	for i, msg := range messages {
		// Mark the last two messages as cacheable.
		cache := i > len(messages)-3
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			if cache && !o.providerOptions.disableCache && isAnthropicModel {
				textBlock.SetExtraFields(map[string]any{
					"cache_control": map[string]string{
						"type": "ephemeral",
					},
				})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
				if cache && !o.providerOptions.disableCache && isAnthropicModel {
					textBlock.SetExtraFields(map[string]any{
						"cache_control": map[string]string{
							"type": "ephemeral",
						},
					})
				}
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
						{
							OfText: &textBlock,
						},
					},
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Skipping assistant message with no content and no tool calls; this should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

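// convertTools maps internal tool definitions onto OpenAI function tool
// parameters.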
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

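// finishReason maps an OpenAI finish reason onto the internal FinishReason.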
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

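// preparedParams assembles the chat completion request: model, messages,
// tools, the token limit, and reasoning effort for models that can reason.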
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

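// send performs a single, non-streaming chat completion request, retrying
// with backoff when shouldRetry deems the error transient.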
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

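// stream performs a streaming chat completion request. Events are emitted on
// the returned channel, which is closed once the stream completes or fails
// for good; transient errors restart the stream.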
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			// Kujtim: fixes an issue with anthropic models on openrouter;
			// an empty (non-nil) tools slice breaks them, so send nil instead.
			if len(params.Tools) == 0 {
				params.Tools = nil
			}
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			currentToolIndex := 0
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				// Kujtim: openrouter qwen sends -1 for the tool call index,
				// so reassign sequential indices ourselves.
				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
					currentToolIndex++
				}
				acc.AddChunk(chunk)
				// This fixes multiple tool calls for some providers
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect tool use start
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// Delta tool use
							if toolCall.ID == "" || toolCall.ID == currentToolCallID {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// Detect new tool use
								msgToolCalls = append(msgToolCalls, currentToolCall)
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						}
					}
					// Kujtim: some models send finish reason "stop" even for tool calls
					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// An empty finish reason is treated as a successful completion.
					// INFO: this happens with openrouter for some reason.
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// Context cancelled: report the cancellation before closing.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

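// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. A 401 refreshes the API key and recreates the
// client before retrying; context cancellation and API errors other than 429
// and 500 are fatal; everything else is retried with exponential backoff,
// honoring a Retry-After header when present.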
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false, 0, err
	}
	var apiErr *openai.Error
	retryMs := 0
	retryAfterValues := []string{}
	if errors.As(err, &apiErr) {
		// Check for token expiration (401 Unauthorized)
		if apiErr.StatusCode == 401 {
			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
			if err != nil {
				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
			}
			o.client = createOpenAIClient(o.providerOptions)
			return true, 0, nil
		}

		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
			return false, 0, err
		}

		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
	}

	if apiErr != nil {
		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
		if len(retryAfterValues) > 0 {
			slog.Warn("Retry-After header", "values", retryAfterValues)
		}
	} else {
		slog.Warn("OpenAI API error", "error", err.Error())
	}

	// Exponential backoff with a fixed 20% cushion; a Retry-After header
	// (in seconds) takes precedence when it parses.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

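// toolCalls extracts the tool calls from the first choice of a completion.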
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

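// usage converts OpenAI token accounting into a TokenUsage, counting cached
// prompt tokens as cache reads rather than fresh input.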
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

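// Model returns the catwalk model selected by the provider options.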
func (o *openaiClient) Model() catwalk.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}