openai.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

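// openaiClient implements ProviderClient against the OpenAI chat completions
// API. It also serves OpenAI-compatible endpoints when a custom base URL is
// configured.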
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

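// newOpenAIClient builds an OpenAIClient from the given provider options. A
// minimal usage sketch (assuming ctx, opts, and msgs are built by the caller):
//
//	client := newOpenAIClient(opts)
//	resp, err := client.send(ctx, msgs, nil)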
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

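// createOpenAIClient translates provider options into SDK request options:
// API key, a config-resolved base URL, and any extra headers.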
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		// If the base URL cannot be resolved, fall back to the SDK default.
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	return openai.NewClient(openaiClientOptions...)
}

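// convertMessages flattens the internal conversation into OpenAI chat
// messages, prepending the configured system prompt. Tool results become
// individual tool messages keyed by their tool call IDs.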
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add the system message first.
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Skipping assistant message with neither content nor tool calls; this should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

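// convertTools maps internal tool definitions onto OpenAI function-calling
// tools, wrapping each tool's parameters in a JSON Schema object.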
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

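// finishReason maps OpenAI finish reasons onto the internal equivalents;
// anything unrecognized becomes FinishReasonUnknown.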
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

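// preparedParams assembles the request parameters for a chat completion,
// picking max tokens from the model defaults, the model config, or the
// provider options (in increasing order of precedence), and setting the
// reasoning effort for models that support it.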
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options.
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

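// send issues a single blocking chat completion request, retrying transient
// failures via shouldRetry before converting the first choice into a
// ProviderResponse.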
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

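// stream starts a streaming chat completion in a goroutine and returns a
// channel of provider events: content deltas as they arrive, then either a
// completion event or an error. Tool call deltas are accumulated manually
// because some OpenAI-compatible providers split them across chunks.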
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				// Accumulate tool calls by hand: some providers split a call's
				// arguments across chunks in ways the accumulator mishandles.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						if currentToolCallID == "" {
							// Detect the start of a tool call.
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else if toolCall.ID == "" {
							// Argument delta for the current tool call.
							currentToolCall.Function.Arguments += toolCall.Function.Arguments
						} else if toolCall.ID != currentToolCallID {
							// A new tool call started; flush the previous one.
							msgToolCalls = append(msgToolCalls, currentToolCall)
							currentToolCallID = toolCall.ID
							currentToolCall = openai.ChatCompletionMessageToolCall{
								ID:   toolCall.ID,
								Type: "function",
								Function: openai.ChatCompletionMessageToolCallFunction{
									Name:      toolCall.Function.Name,
									Arguments: toolCall.Function.Arguments,
								},
							}
						}
					}
					if choice.FinishReason == "tool_calls" {
						if currentToolCallID != "" {
							msgToolCalls = append(msgToolCalls, currentToolCall)
						}
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// Some providers (OpenRouter, for one) omit the finish
					// reason; treat that as a normal stop.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
				select {
				case <-ctx.Done():
					// The context was cancelled; ctx.Err() is always non-nil here.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

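// shouldRetry decides whether a failed request is retryable and how long to
// wait (in milliseconds) before the next attempt. A 401 triggers an API key
// re-resolve and client rebuild; 429 and 500 honor Retry-After when present
// and otherwise back off exponentially.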
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized): re-resolve the API key
	// and rebuild the client before retrying immediately.
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	// Exponential backoff plus a 20% margin, overridden by a Retry-After
	// header (assumed to be in seconds) when the server provides one.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs
	if retryAfterValues := apiErr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

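// toolCalls converts the tool calls on the first choice of a completion into
// internal ToolCall values, marking them finished.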
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

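// usage maps SDK token accounting onto TokenUsage. Cached prompt tokens are
// reported separately, so they are subtracted from the input token count.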
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

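// Model returns the provider model selected for this client's model type.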
func (o *openaiClient) Model() provider.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}