openai.go

package provider

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"
)

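// openaiClient implements ProviderClient on top of the OpenAI chat
// completions API. It also serves OpenAI-compatible providers such as
// OpenRouter.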
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

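// OpenAIClient is the exported ProviderClient type for OpenAI-backed providers.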
type OpenAIClient ProviderClient

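// newOpenAIClient constructs an openaiClient from the given provider options.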
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

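// createOpenAIClient builds the underlying openai.Client, applying the API
// key, resolved base URL, debug HTTP client, extra headers, and extra body
// fields from the options.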
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		// If the base URL fails to resolve, skip it and use the SDK default.
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	if config.Get().Options.Debug {
		httpClient := log.NewHTTPClient()
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(httpClient))
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

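// convertMessages converts internal messages to OpenAI chat messages,
// prepending the system prompt. For Anthropic models routed through
// OpenRouter it attaches ephemeral cache_control markers unless caching
// is disabled.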
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/")
	// Add the system message first.
	systemMessage := o.providerOptions.systemMessage
	if o.providerOptions.systemPromptPrefix != "" {
		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}

	systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
	if isAnthropicModel && !o.providerOptions.disableCache {
		systemTextBlock.SetExtraFields(
			map[string]any{
				"cache_control": map[string]string{
					"type": "ephemeral",
				},
			},
		)
	}
	var content []openai.ChatCompletionContentPartTextParam
	content = append(content, systemTextBlock)
	system := openai.SystemMessage(content)
	openaiMessages = append(openaiMessages, system)

	for i, msg := range messages {
		// Only the last two messages are marked as cacheable.
		cache := i > len(messages)-3
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			if cache && !o.providerOptions.disableCache && isAnthropicModel {
				textBlock.SetExtraFields(map[string]any{
					"cache_control": map[string]string{
						"type": "ephemeral",
					},
				})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
				if cache && !o.providerOptions.disableCache && isAnthropicModel {
					textBlock.SetExtraFields(map[string]any{
						"cache_control": map[string]string{
							"type": "ephemeral",
						},
					})
				}
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
						{
							OfText: &textBlock,
						},
					},
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: param.NewOpt(msg.Content().String()),
				}
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Skipping assistant message with neither content nor tool calls; this should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

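// convertTools converts internal tool definitions into OpenAI function tool
// parameters with JSON Schema object parameters.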
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

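// finishReason maps an OpenAI finish reason to the internal FinishReason.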
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

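// preparedParams assembles the chat completion parameters, resolving the
// model, token limits, and reasoning effort from configuration.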
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options.
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

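// send issues a non-streaming chat completion request, retrying when
// shouldRetry allows it.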
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := ""
		if openaiResponse.Choices[0].Message.Content != "" {
			content = openaiResponse.Choices[0].Message.Content
		}

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

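// stream issues a streaming chat completion request and returns a channel of
// ProviderEvents, accumulating content deltas and tool calls until the
// stream completes.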
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			// Kujtim: fixes an issue with anthropic models on OpenRouter.
			if len(params.Tools) == 0 {
				params.Tools = nil
			}
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			currentToolIndex := 0
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				// Kujtim: this is an issue with OpenRouter Qwen, which sends -1 for the tool index.
				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
					currentToolIndex++
				}
				acc.AddChunk(chunk)
				// This fixes multiple tool calls for some providers.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						if currentToolCallID == "" {
							// Detect the start of a tool use.
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else if toolCall.ID == "" || toolCall.ID == currentToolCallID {
							// Argument delta for the current tool call.
							currentToolCall.Function.Arguments += toolCall.Function.Arguments
						} else {
							// A new tool call started: flush the current one and begin tracking the new call.
							msgToolCalls = append(msgToolCalls, currentToolCall)
							currentToolCallID = toolCall.ID
							eventChan <- ProviderEvent{
								Type: EventToolUseStart,
								ToolCall: &message.ToolCall{
									ID:       toolCall.ID,
									Name:     toolCall.Function.Name,
									Finished: false,
								},
							}
							currentToolCall = openai.ChatCompletionMessageToolCall{
								ID:   toolCall.ID,
								Type: "function",
								Function: openai.ChatCompletionMessageToolCallFunction{
									Name:      toolCall.Function.Name,
									Arguments: toolCall.Function.Arguments,
								},
							}
						}
					}
					// Kujtim: some models report finish reason "stop" even for tool calls.
					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// If the finish reason is empty, assume the completion was successful.
					// INFO: this happens with OpenRouter for some reason.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// Context cancelled.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

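// shouldRetry decides whether a failed request should be retried and returns
// the delay in milliseconds. 401 responses refresh the API key and retry
// immediately, 429 and 500 are retried with exponential backoff, and a
// Retry-After header overrides the computed delay.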
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false, 0, err
	}
	var apiErr *openai.Error
	retryMs := 0
	retryAfterValues := []string{}
	if errors.As(err, &apiErr) {
		// Check for token expiration (401 Unauthorized).
		if apiErr.StatusCode == 401 {
			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
			if err != nil {
				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
			}
			o.client = createOpenAIClient(o.providerOptions)
			return true, 0, nil
		}

		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
			return false, 0, err
		}

		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
	}

	if apiErr != nil {
		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
		if len(retryAfterValues) > 0 {
			slog.Warn("Retry-After header", "values", retryAfterValues)
		}
	} else {
		slog.Error("OpenAI API error", "error", err.Error(), "attempt", attempts, "max_retries", maxRetries)
	}

	// Exponential backoff with a 20% margin, overridden by Retry-After when present.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		// Retry-After is given in seconds; convert to milliseconds.
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

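// toolCalls extracts finished tool calls from the first choice of a completion.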
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

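// usage converts the completion's token counts into TokenUsage, reporting
// cached prompt tokens separately from fresh input tokens.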
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

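// Model returns the catwalk model selected by the provider options.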
func (o *openaiClient) Model() catwalk.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}