openai.go

package provider

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"
)

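// openaiClient implements ProviderClient on top of the official openai-go
// SDK. It is also used for OpenAI-compatible endpoints (such as OpenRouter)
// by pointing the client at a custom base URL.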
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

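// OpenAIClient is a ProviderClient backed by the OpenAI chat completions API.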
type OpenAIClient ProviderClient

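// newOpenAIClient builds an OpenAIClient from the given provider options.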
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

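// createOpenAIClient translates providerClientOptions into openai-go request
// options: API key, resolved base URL, an optional debug HTTP client, and any
// extra headers or JSON body fields.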
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	if config.Get().Options.Debug {
		httpClient := log.NewHTTPClient()
		openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(httpClient))
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

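// convertMessages maps the internal message history onto OpenAI chat message
// unions, prepending the system prompt. For Anthropic models served through
// OpenRouter it attaches ephemeral cache_control markers to the system prompt
// and the most recent messages so that prompt caching can kick in.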
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/")
	// Add system message first
	systemMessage := o.providerOptions.systemMessage
	if o.providerOptions.systemPromptPrefix != "" {
		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}

	system := openai.SystemMessage(systemMessage)
	if isAnthropicModel && !o.providerOptions.disableCache {
		systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
		systemTextBlock.SetExtraFields(
			map[string]any{
				"cache_control": map[string]string{
					"type": "ephemeral",
				},
			},
		)
		var content []openai.ChatCompletionContentPartTextParam
		content = append(content, systemTextBlock)
		system = openai.SystemMessage(content)
	}
	openaiMessages = append(openaiMessages, system)

	for i, msg := range messages {
		// Only the last two messages are candidates for prompt caching.
		cache := i > len(messages)-3
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam

			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			hasBinaryContent := false
			for _, binaryContent := range msg.BinaryContent() {
				hasBinaryContent = true
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			if cache && !o.providerOptions.disableCache && isAnthropicModel {
				textBlock.SetExtraFields(map[string]any{
					"cache_control": map[string]string{
						"type": "ephemeral",
					},
				})
			}
			if hasBinaryContent || (isAnthropicModel && !o.providerOptions.disableCache) {
				openaiMessages = append(openaiMessages, openai.UserMessage(content))
			} else {
				openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
			}

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
				if cache && !o.providerOptions.disableCache && isAnthropicModel {
					textBlock.SetExtraFields(map[string]any{
						"cache_control": map[string]string{
							"type": "ephemeral",
						},
					})
				}
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
						{
							OfText: &textBlock,
						},
					},
				}
				if !isAnthropicModel {
					assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
						OfString: param.NewOpt(msg.Content().String()),
					}
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Assistant message has no content and no tool calls; skipping")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

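// convertTools converts internal tool definitions into OpenAI function tool
// parameters, wrapping each tool's parameters in a JSON Schema object.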
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

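// finishReason maps an OpenAI finish reason string to the internal
// message.FinishReason.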
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

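// preparedParams assembles the chat completion request parameters: the
// selected model, its token budget, and, for reasoning models, the configured
// reasoning effort.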
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

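// send performs a single blocking chat completion request, retrying with
// backoff when shouldRetry says the error is transient.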
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

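// stream performs a streaming chat completion request and emits provider
// events (content deltas, tool-call starts, completion, errors) on the
// returned channel. Chunks are accumulated so the final EventComplete carries
// the full content, tool calls, and token usage.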
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			// Kujtim: fixes an issue with Anthropic models on OpenRouter
			if len(params.Tools) == 0 {
				params.Tools = nil
			}
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			currentToolIndex := 0
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				// Kujtim: this is an issue with OpenRouter qwen, it's sending -1 for the tool index
				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
					currentToolIndex++
				}
				acc.AddChunk(chunk)
				// This fixes multiple tool calls for some providers
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect tool use start
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else if toolCall.ID == "" || toolCall.ID == currentToolCallID {
							// Argument delta for the tool call in progress
							currentToolCall.Function.Arguments += toolCall.Function.Arguments
						} else {
							// New tool call detected: flush the one in progress and start the next
							msgToolCalls = append(msgToolCalls, currentToolCall)
							currentToolCallID = toolCall.ID
							eventChan <- ProviderEvent{
								Type: EventToolUseStart,
								ToolCall: &message.ToolCall{
									ID:       toolCall.ID,
									Name:     toolCall.Function.Name,
									Finished: false,
								},
							}
							currentToolCall = openai.ChatCompletionMessageToolCall{
								ID:   toolCall.ID,
								Type: "function",
								Function: openai.ChatCompletionMessageToolCallFunction{
									Name:      toolCall.Function.Name,
									Arguments: toolCall.Function.Arguments,
								},
							}
						}
					}
					// Kujtim: some models send a "stop" finish reason even for tool calls
					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// If the finish reason is empty, we assume it was a successful completion
					// INFO: this is happening for openrouter for some reason
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the stream can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// Context cancelled: surface the cancellation and stop.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

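// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. A 401 refreshes the API key and retries
// immediately; 429 and 500 responses are retried with exponential backoff,
// honoring a Retry-After header when present.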
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false, 0, err
	}
	var apiErr *openai.Error
	retryMs := 0
	retryAfterValues := []string{}
	if errors.As(err, &apiErr) {
		// Check for token expiration (401 Unauthorized)
		if apiErr.StatusCode == 401 {
			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
			if err != nil {
				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
			}
			o.client = createOpenAIClient(o.providerOptions)
			return true, 0, nil
		}

		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
			return false, 0, err
		}

		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
	}

	if apiErr != nil {
		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
		if len(retryAfterValues) > 0 {
			slog.Warn("Retry-After header", "values", retryAfterValues)
		}
	} else {
		slog.Error("OpenAI API error", "error", err.Error(), "attempt", attempts, "max_retries", maxRetries)
	}

	// Exponential backoff starting at 2s, plus a flat 20% margin.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	// A Retry-After header (given in seconds) overrides the computed backoff.
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

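// toolCalls extracts the tool calls from the first choice of a completion.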
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

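// usage converts the completion's token counts into TokenUsage, separating
// cached prompt tokens from fresh input tokens.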
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

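// Model returns the catwalk model selected in the provider options.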
func (o *openaiClient) Model() catwalk.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}