openai.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/fur/provider"
 14	"github.com/charmbracelet/crush/internal/llm/tools"
 15	"github.com/charmbracelet/crush/internal/message"
 16	"github.com/openai/openai-go"
 17	"github.com/openai/openai-go/option"
 18	"github.com/openai/openai-go/shared"
 19)
 20
 21type openaiClient struct {
 22	providerOptions providerClientOptions
 23	client          openai.Client
 24}
 25
 26type OpenAIClient ProviderClient
 27
 28func newOpenAIClient(opts providerClientOptions) OpenAIClient {
 29	return &openaiClient{
 30		providerOptions: opts,
 31		client:          createOpenAIClient(opts),
 32	}
 33}
 34
 35func createOpenAIClient(opts providerClientOptions) openai.Client {
 36	openaiClientOptions := []option.RequestOption{}
 37	if opts.apiKey != "" {
 38		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
 39	}
 40	if opts.baseURL != "" {
 41		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
 42		if err == nil {
 43			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
 44		}
 45	}
 46
 47	if opts.extraHeaders != nil {
 48		for key, value := range opts.extraHeaders {
 49			openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 50		}
 51	}
 52
 53	return openai.NewClient(openaiClientOptions...)
 54}
 55
 56func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
 57	// Add system message first
 58	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
 59
 60	for _, msg := range messages {
 61		switch msg.Role {
 62		case message.User:
 63			var content []openai.ChatCompletionContentPartUnionParam
 64			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
 65			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
 66			for _, binaryContent := range msg.BinaryContent() {
 67				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
 68				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 69
 70				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 71			}
 72
 73			openaiMessages = append(openaiMessages, openai.UserMessage(content))
 74
 75		case message.Assistant:
 76			assistantMsg := openai.ChatCompletionAssistantMessageParam{
 77				Role: "assistant",
 78			}
 79
 80			if msg.Content().String() != "" {
 81				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
 82					OfString: openai.String(msg.Content().String()),
 83				}
 84			}
 85
 86			if len(msg.ToolCalls()) > 0 {
 87				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
 88				for i, call := range msg.ToolCalls() {
 89					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
 90						ID:   call.ID,
 91						Type: "function",
 92						Function: openai.ChatCompletionMessageToolCallFunctionParam{
 93							Name:      call.Name,
 94							Arguments: call.Input,
 95						},
 96					}
 97				}
 98			}
 99
100			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
101				OfAssistant: &assistantMsg,
102			})
103
104		case message.Tool:
105			for _, result := range msg.ToolResults() {
106				openaiMessages = append(openaiMessages,
107					openai.ToolMessage(result.Content, result.ToolCallID),
108				)
109			}
110		}
111	}
112
113	return
114}
115
116func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
117	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
118
119	for i, tool := range tools {
120		info := tool.Info()
121		openaiTools[i] = openai.ChatCompletionToolParam{
122			Function: openai.FunctionDefinitionParam{
123				Name:        info.Name,
124				Description: openai.String(info.Description),
125				Parameters: openai.FunctionParameters{
126					"type":       "object",
127					"properties": info.Parameters,
128					"required":   info.Required,
129				},
130			},
131		}
132	}
133
134	return openaiTools
135}
136
137func (o *openaiClient) finishReason(reason string) message.FinishReason {
138	switch reason {
139	case "stop":
140		return message.FinishReasonEndTurn
141	case "length":
142		return message.FinishReasonMaxTokens
143	case "tool_calls":
144		return message.FinishReasonToolUse
145	default:
146		return message.FinishReasonUnknown
147	}
148}
149
150func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
151	model := o.providerOptions.model(o.providerOptions.modelType)
152	cfg := config.Get()
153
154	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
155	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
156		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
157	}
158
159	reasoningEffort := modelConfig.ReasoningEffort
160
161	params := openai.ChatCompletionNewParams{
162		Model:    openai.ChatModel(model.ID),
163		Messages: messages,
164		Tools:    tools,
165	}
166
167	maxTokens := model.DefaultMaxTokens
168	if modelConfig.MaxTokens > 0 {
169		maxTokens = modelConfig.MaxTokens
170	}
171
172	// Override max tokens if set in provider options
173	if o.providerOptions.maxTokens > 0 {
174		maxTokens = o.providerOptions.maxTokens
175	}
176	if model.CanReason {
177		params.MaxCompletionTokens = openai.Int(maxTokens)
178		switch reasoningEffort {
179		case "low":
180			params.ReasoningEffort = shared.ReasoningEffortLow
181		case "medium":
182			params.ReasoningEffort = shared.ReasoningEffortMedium
183		case "high":
184			params.ReasoningEffort = shared.ReasoningEffortHigh
185		default:
186			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
187		}
188	} else {
189		params.MaxTokens = openai.Int(maxTokens)
190	}
191
192	return params
193}
194
195func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
196	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
197	cfg := config.Get()
198	if cfg.Options.Debug {
199		jsonData, _ := json.Marshal(params)
200		slog.Debug("Prepared messages", "messages", string(jsonData))
201	}
202	attempts := 0
203	for {
204		attempts++
205		openaiResponse, err := o.client.Chat.Completions.New(
206			ctx,
207			params,
208		)
209		// If there is an error we are going to see if we can retry the call
210		if err != nil {
211			retry, after, retryErr := o.shouldRetry(attempts, err)
212			if retryErr != nil {
213				return nil, retryErr
214			}
215			if retry {
216				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
217				select {
218				case <-ctx.Done():
219					return nil, ctx.Err()
220				case <-time.After(time.Duration(after) * time.Millisecond):
221					continue
222				}
223			}
224			return nil, retryErr
225		}
226
227		if len(openaiResponse.Choices) == 0 {
228			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
229		}
230
231		content := ""
232		if openaiResponse.Choices[0].Message.Content != "" {
233			content = openaiResponse.Choices[0].Message.Content
234		}
235
236		toolCalls := o.toolCalls(*openaiResponse)
237		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))
238
239		if len(toolCalls) > 0 {
240			finishReason = message.FinishReasonToolUse
241		}
242
243		return &ProviderResponse{
244			Content:      content,
245			ToolCalls:    toolCalls,
246			Usage:        o.usage(*openaiResponse),
247			FinishReason: finishReason,
248		}, nil
249	}
250}
251
252func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
253	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
254	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
255		IncludeUsage: openai.Bool(true),
256	}
257
258	cfg := config.Get()
259	if cfg.Options.Debug {
260		jsonData, _ := json.Marshal(params)
261		slog.Debug("Prepared messages", "messages", string(jsonData))
262	}
263
264	attempts := 0
265	eventChan := make(chan ProviderEvent)
266
267	go func() {
268		for {
269			attempts++
270			openaiStream := o.client.Chat.Completions.NewStreaming(
271				ctx,
272				params,
273			)
274
275			acc := openai.ChatCompletionAccumulator{}
276			currentContent := ""
277			toolCalls := make([]message.ToolCall, 0)
278
279			var currentToolCallID string
280			var currentToolCall openai.ChatCompletionMessageToolCall
281			var msgToolCalls []openai.ChatCompletionMessageToolCall
282			for openaiStream.Next() {
283				chunk := openaiStream.Current()
284				acc.AddChunk(chunk)
285				// This fixes multiple tool calls for some providers
286				for _, choice := range chunk.Choices {
287					if choice.Delta.Content != "" {
288						eventChan <- ProviderEvent{
289							Type:    EventContentDelta,
290							Content: choice.Delta.Content,
291						}
292						currentContent += choice.Delta.Content
293					} else if len(choice.Delta.ToolCalls) > 0 {
294						toolCall := choice.Delta.ToolCalls[0]
295						// Detect tool use start
296						if currentToolCallID == "" {
297							if toolCall.ID != "" {
298								currentToolCallID = toolCall.ID
299								currentToolCall = openai.ChatCompletionMessageToolCall{
300									ID:   toolCall.ID,
301									Type: "function",
302									Function: openai.ChatCompletionMessageToolCallFunction{
303										Name:      toolCall.Function.Name,
304										Arguments: toolCall.Function.Arguments,
305									},
306								}
307							}
308						} else {
309							// Delta tool use
310							if toolCall.ID == "" {
311								currentToolCall.Function.Arguments += toolCall.Function.Arguments
312							} else {
313								// Detect new tool use
314								if toolCall.ID != currentToolCallID {
315									msgToolCalls = append(msgToolCalls, currentToolCall)
316									currentToolCallID = toolCall.ID
317									currentToolCall = openai.ChatCompletionMessageToolCall{
318										ID:   toolCall.ID,
319										Type: "function",
320										Function: openai.ChatCompletionMessageToolCallFunction{
321											Name:      toolCall.Function.Name,
322											Arguments: toolCall.Function.Arguments,
323										},
324									}
325								}
326							}
327						}
328					}
329					if choice.FinishReason == "tool_calls" {
330						msgToolCalls = append(msgToolCalls, currentToolCall)
331						if len(acc.Choices) > 0 {
332							acc.Choices[0].Message.ToolCalls = msgToolCalls
333						}
334					}
335				}
336			}
337
338			err := openaiStream.Err()
339			if err == nil || errors.Is(err, io.EOF) {
340				if cfg.Options.Debug {
341					jsonData, _ := json.Marshal(acc.ChatCompletion)
342					slog.Debug("Response", "messages", string(jsonData))
343				}
344
345				if len(acc.Choices) == 0 {
346					eventChan <- ProviderEvent{
347						Type:  EventError,
348						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
349					}
350					return
351				}
352
353				resultFinishReason := acc.Choices[0].FinishReason
354				if resultFinishReason == "" {
355					// If the finish reason is empty, we assume it was a successful completion
356					// INFO: this is happening for openrouter for some reason
357					resultFinishReason = "stop"
358				}
359				// Stream completed successfully
360				finishReason := o.finishReason(resultFinishReason)
361				if len(acc.Choices[0].Message.ToolCalls) > 0 {
362					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
363				}
364				if len(toolCalls) > 0 {
365					finishReason = message.FinishReasonToolUse
366				}
367
368				eventChan <- ProviderEvent{
369					Type: EventComplete,
370					Response: &ProviderResponse{
371						Content:      currentContent,
372						ToolCalls:    toolCalls,
373						Usage:        o.usage(acc.ChatCompletion),
374						FinishReason: finishReason,
375					},
376				}
377				close(eventChan)
378				return
379			}
380
381			// If there is an error we are going to see if we can retry the call
382			retry, after, retryErr := o.shouldRetry(attempts, err)
383			if retryErr != nil {
384				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
385				close(eventChan)
386				return
387			}
388			if retry {
389				slog.Warn(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries))
390				select {
391				case <-ctx.Done():
392					// context cancelled
393					if ctx.Err() == nil {
394						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
395					}
396					close(eventChan)
397					return
398				case <-time.After(time.Duration(after) * time.Millisecond):
399					continue
400				}
401			}
402			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
403			close(eventChan)
404			return
405		}
406	}()
407
408	return eventChan
409}
410
411func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
412	var apiErr *openai.Error
413	if !errors.As(err, &apiErr) {
414		return false, 0, err
415	}
416
417	if attempts > maxRetries {
418		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
419	}
420
421	// Check for token expiration (401 Unauthorized)
422	if apiErr.StatusCode == 401 {
423		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
424		if err != nil {
425			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
426		}
427		o.client = createOpenAIClient(o.providerOptions)
428		return true, 0, nil
429	}
430
431	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
432		return false, 0, err
433	}
434
435	retryMs := 0
436	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
437
438	backoffMs := 2000 * (1 << (attempts - 1))
439	jitterMs := int(float64(backoffMs) * 0.2)
440	retryMs = backoffMs + jitterMs
441	if len(retryAfterValues) > 0 {
442		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
443			retryMs = retryMs * 1000
444		}
445	}
446	return true, int64(retryMs), nil
447}
448
449func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
450	var toolCalls []message.ToolCall
451
452	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
453		for _, call := range completion.Choices[0].Message.ToolCalls {
454			toolCall := message.ToolCall{
455				ID:       call.ID,
456				Name:     call.Function.Name,
457				Input:    call.Function.Arguments,
458				Type:     "function",
459				Finished: true,
460			}
461			toolCalls = append(toolCalls, toolCall)
462		}
463	}
464
465	return toolCalls
466}
467
468func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
469	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
470	inputTokens := completion.Usage.PromptTokens - cachedTokens
471
472	return TokenUsage{
473		InputTokens:         inputTokens,
474		OutputTokens:        completion.Usage.CompletionTokens,
475		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
476		CacheReadTokens:     cachedTokens,
477	}
478}
479
480func (o *openaiClient) Model() provider.Model {
481	return o.providerOptions.model(o.providerOptions.modelType)
482}