package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

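// openaiClient is the OpenAI-backed implementation of ProviderClient, built
// on the official openai-go SDK. With a custom base URL it can also talk to
// OpenAI-compatible endpoints (e.g. OpenRouter).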
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

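// OpenAIClient is the public face of this provider: the generic
// ProviderClient contract under a provider-specific name.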
type OpenAIClient ProviderClient

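// newOpenAIClient constructs an OpenAIClient from the given options. A
// minimal construction sketch (field values are illustrative, not real
// credentials):
//
//	client := newOpenAIClient(providerClientOptions{
//		apiKey:  "sk-...",
//		baseURL: "https://api.openai.com/v1",
//	})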
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

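// createOpenAIClient translates providerClientOptions into openai-go request
// options: API key, resolved base URL, extra headers, and extra JSON body
// fields. Note that a base URL that fails to resolve is silently skipped and
// the SDK default is used instead.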
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

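// convertMessages maps internal message.Message values onto the OpenAI chat
// union types: the system prompt is prepended, user messages carry text plus
// any image attachments, assistant messages carry text and/or tool calls,
// and each tool result becomes its own tool message.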
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Assistant message has neither content nor tool calls; skipping")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

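// convertTools wraps each tool's metadata in an OpenAI function definition
// whose parameters are expressed as a JSON Schema object.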
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

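// finishReason normalizes OpenAI finish-reason strings into the provider's
// message.FinishReason values; anything unrecognized maps to
// FinishReasonUnknown.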
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

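// preparedParams assembles the ChatCompletionNewParams for a request. The max
// token budget starts at the model default, is overridden by the model
// config, and finally by provider options. Reasoning-capable models use
// MaxCompletionTokens plus a ReasoningEffort level; all others use the plain
// MaxTokens field.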
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

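// send issues a single (non-streaming) chat completion request, retrying with
// backoff via shouldRetry on transient failures, and converts the first
// choice into a ProviderResponse.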
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error we are going to see if we can retry the call
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := ""
		if openaiResponse.Choices[0].Message.Content != "" {
			content = openaiResponse.Choices[0].Message.Content
		}

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

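// stream issues a streaming chat completion request and returns a channel of
// ProviderEvents: EventContentDelta for each text chunk, then a single
// EventComplete (or EventError), after which the channel is closed. A minimal
// consumer sketch:
//
//	for event := range client.stream(ctx, messages, tools) {
//		switch event.Type {
//		case EventContentDelta:
//			fmt.Print(event.Content)
//		case EventComplete:
//			// event.Response holds the full content, tool calls, and usage.
//		case EventError:
//			// handle event.Error
//		}
//	}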
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				// Accumulate tool calls manually: some providers split a
				// single call across chunks or interleave several calls, so
				// the deltas are stitched together by ID below.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect tool use start
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// Delta tool use
							if toolCall.ID == "" {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// Detect new tool use
								if toolCall.ID != currentToolCallID {
									msgToolCalls = append(msgToolCalls, currentToolCall)
									currentToolCallID = toolCall.ID
									currentToolCall = openai.ChatCompletionMessageToolCall{
										ID:   toolCall.ID,
										Type: "function",
										Function: openai.ChatCompletionMessageToolCallFunction{
											Name:      toolCall.Function.Name,
											Arguments: toolCall.Function.Arguments,
										},
									}
								}
							}
						}
					}
					if choice.FinishReason == "tool_calls" {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// An empty finish reason is treated as a successful
					// completion; some providers (e.g. OpenRouter) omit it.
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error we are going to see if we can retry the call
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// Context was cancelled; surface the error before closing.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

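// shouldRetry reports whether a failed call should be retried and, if so,
// after how many milliseconds. A 401 triggers an API key refresh and an
// immediate retry; 429 and 500 retry with exponential backoff, honoring a
// Retry-After header when present; anything else (or exceeding maxRetries)
// returns a terminal error.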
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized)
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

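// toolCalls extracts the tool calls from the first choice of a completed
// (non-streaming or accumulated) response.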
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

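// usage converts OpenAI usage numbers into TokenUsage. Cached prompt tokens
// are reported separately, so they are subtracted from the input count to
// avoid double-counting.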
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

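// Model returns the provider.Model currently selected for this client.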
func (o *openaiClient) Model() provider.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}