package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

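// openaiClient talks to the OpenAI Chat Completions API. Because the base
// URL and extra headers are configurable, it also serves OpenAI-compatible
// providers (the OpenRouter note in stream is one example).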
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

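// OpenAIClient is the ProviderClient type returned for OpenAI-backed models.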
type OpenAIClient ProviderClient

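// newOpenAIClient constructs an OpenAIClient from the given provider options.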
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

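// createOpenAIClient translates providerClientOptions into openai-go request
// options: the API key, the resolved base URL, and any extra headers.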
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		// If the base URL cannot be resolved, fall back to the SDK default.
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	return openai.NewClient(openaiClientOptions...)
}

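// convertMessages maps internal messages onto the OpenAI chat format. The
// configured system message always comes first; user messages carry text and
// image parts, assistant messages carry text and tool calls, and every tool
// result becomes its own tool message keyed by its tool call ID.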
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if calls := msg.ToolCalls(); len(calls) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(calls))
				for i, call := range calls {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

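// convertTools maps internal tool definitions onto OpenAI function tools,
// wrapping each tool's parameters and required fields in a JSON Schema
// object.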
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

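// finishReason translates OpenAI finish reasons into the internal
// representation; unrecognized values map to FinishReasonUnknown.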
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

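// preparedParams assembles the ChatCompletionNewParams shared by send and
// stream. The token limit comes from the model default, overridden first by
// the model config and then by the provider options; reasoning models get
// MaxCompletionTokens plus a reasoning effort, all others get MaxTokens.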
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

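// send issues a single blocking completion request, retrying per shouldRetry
// on transient failures.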
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, errors.New("received a completion with no choices")
		}
		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

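// stream issues a streaming completion request and returns a channel of
// ProviderEvents: content deltas as they arrive, followed by exactly one
// EventComplete or EventError, after which the channel is closed. A minimal
// consumer sketch (illustrative only):
//
//	for event := range client.stream(ctx, msgs, tools) {
//		switch event.Type {
//		case EventContentDelta:
//			// render event.Content incrementally
//		case EventComplete:
//			// event.Response holds content, tool calls, and usage
//		case EventError:
//			// event.Error is terminal; the channel closes after this
//		}
//	}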
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				// Accumulate tool calls manually: some providers split a single
				// call across deltas or interleave several calls, which the
				// accumulator alone does not handle reliably.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						switch {
						case currentToolCallID == "":
							// Start of the first tool call.
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						case toolCall.ID == "":
							// Argument delta for the current tool call.
							currentToolCall.Function.Arguments += toolCall.Function.Arguments
						case toolCall.ID != currentToolCallID:
							// A new tool call started; flush the previous one.
							msgToolCalls = append(msgToolCalls, currentToolCall)
							currentToolCallID = toolCall.ID
							currentToolCall = openai.ChatCompletionMessageToolCall{
								ID:   toolCall.ID,
								Type: "function",
								Function: openai.ChatCompletionMessageToolCallFunction{
									Name:      toolCall.Function.Name,
									Arguments: toolCall.Function.Arguments,
								},
							}
						}
					}
					if choice.FinishReason == "tool_calls" {
						if currentToolCallID != "" {
							msgToolCalls = append(msgToolCalls, currentToolCall)
						}
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				resultFinishReason := ""
				if len(acc.ChatCompletion.Choices) > 0 {
					resultFinishReason = string(acc.ChatCompletion.Choices[0].FinishReason)
				}
				if resultFinishReason == "" {
					// Some providers (OpenRouter, for one) omit the finish
					// reason on success, so treat an empty value as "stop".
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices) > 0 && len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the stream can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					// Context was cancelled; surface the cancellation error.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

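// shouldRetry reports whether a failed request may be retried and how long
// to wait (in milliseconds) before doing so. A 401 re-resolves the API key
// and rebuilds the client; 429 and 500 responses back off exponentially,
// honoring the server's Retry-After header when present. Anything else is
// returned to the caller as a hard failure.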
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries)
	}

	// A 401 usually means the credentials expired: re-resolve the API key
	// and rebuild the client before retrying.
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	// Default to exponential backoff with a 20% cushion, but prefer the
	// server-provided Retry-After header (in seconds) when present.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs
	if retryAfterValues := apiErr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs *= 1000
		}
	}
	return true, int64(retryMs), nil
}

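// toolCalls extracts the tool calls from the first choice of a completion,
// marking each one finished since the response is complete rather than a
// streaming delta.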
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

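// usage converts OpenAI token counts into TokenUsage, reporting cached
// prompt tokens as cache reads and subtracting them from the input count so
// they are not double counted.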
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

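// Model returns the provider model this client was configured with.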
func (o *openaiClient) Model() provider.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}