openai.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/packages/param"
	"github.com/openai/openai-go/shared"
)

type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

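// createOpenAIClient builds the underlying OpenAI SDK client from the
// provider options: API key, resolved base URL, extra headers, and extra
// JSON body fields. Note that a base URL that fails to resolve is silently
// skipped, falling back to the SDK default.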
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

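// convertMessages converts internal messages to the OpenAI chat format,
// prepending the system prompt. For Anthropic models served through
// OpenRouter it attaches Anthropic-style prompt-cache markers, which the
// OpenAI SDK does not model natively and are therefore injected as extra
// JSON fields on the content part, roughly (illustrative wire shape):
//
//	{"type": "text", "text": "...", "cache_control": {"type": "ephemeral"}}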
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/")
	// Add system message first
	systemMessage := o.providerOptions.systemMessage
	if o.providerOptions.systemPromptPrefix != "" {
		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}

	systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
	if isAnthropicModel && !o.providerOptions.disableCache {
		systemTextBlock.SetExtraFields(
			map[string]any{
				"cache_control": map[string]string{
					"type": "ephemeral",
				},
			},
		)
	}
	var content []openai.ChatCompletionContentPartTextParam
	content = append(content, systemTextBlock)
	system := openai.SystemMessage(content)
	openaiMessages = append(openaiMessages, system)

	for i, msg := range messages {
		// Only the last two messages are marked as cacheable.
		cache := false
		if i > len(messages)-3 {
			cache = true
		}
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			if cache && !o.providerOptions.disableCache && isAnthropicModel {
				// content holds a pointer to textBlock, so the extra fields
				// set here are reflected in the part appended above.
				textBlock.SetExtraFields(map[string]any{
					"cache_control": map[string]string{
						"type": "ephemeral",
					},
				})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
				if cache && !o.providerOptions.disableCache && isAnthropicModel {
					textBlock.SetExtraFields(map[string]any{
						"cache_control": map[string]string{
							"type": "ephemeral",
						},
					})
				}
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
						{
							OfText: &textBlock,
						},
					},
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				// When tool calls are present, the content is sent as a plain
				// string, replacing the content-part form set above.
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: param.NewOpt(msg.Content().String()),
				}
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Assistant message has neither content nor tool calls; skipping. This should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

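// convertTools converts internal tool definitions to the OpenAI function
// tool format. Each tool's parameters become a JSON Schema object, roughly
// (illustrative shape; the actual fields depend on the tool):
//
//	{"type": "object", "properties": {...}, "required": [...]}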
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

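// finishReason maps an OpenAI finish reason string to the internal
// representation; anything unrecognized maps to FinishReasonUnknown.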
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

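// preparedParams assembles the request parameters shared by send and
// stream: model, messages, tools, token limits, and, for reasoning-capable
// models, the reasoning effort. Reasoning models take the token limit via
// max_completion_tokens, all others via max_tokens.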
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			// Pass any other configured effort value through unchanged.
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

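// send makes a single blocking chat completion request, retrying on
// retryable errors (see shouldRetry). A minimal usage sketch, assuming a
// client, messages, and tools built elsewhere (names are illustrative):
//
//	resp, err := client.send(ctx, msgs, availableTools)
//	if err != nil {
//		return err
//	}
//	fmt.Println(resp.Content, resp.FinishReason)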
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// If there is an error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: surface the original error.
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

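// stream makes a streaming chat completion request and returns a channel of
// provider events; the channel is closed after EventComplete or EventError.
// A minimal consumption sketch (names are illustrative):
//
//	for event := range client.stream(ctx, msgs, availableTools) {
//		switch event.Type {
//		case EventContentDelta:
//			fmt.Print(event.Content)
//		case EventError:
//			return event.Error
//		case EventComplete:
//			return handle(event.Response)
//		}
//	}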
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			// Kujtim: fixes an issue with Anthropic models on OpenRouter.
			if len(params.Tools) == 0 {
				params.Tools = nil
			}
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			currentToolIndex := 0
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				// Kujtim: works around OpenRouter Qwen sending -1 as the tool index.
				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
					currentToolIndex++
				}
				acc.AddChunk(chunk)
				// This fixes multiple tool calls for some providers.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect tool use start
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// Delta tool use
							if toolCall.ID == "" || toolCall.ID == currentToolCallID {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// Detect new tool use
								msgToolCalls = append(msgToolCalls, currentToolCall)
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						}
					}
					// Kujtim: some models send finish reason "stop" even for tool calls.
					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// If the finish reason is empty, assume a successful completion.
					// INFO: this happens with OpenRouter for some reason.
					resultFinishReason = "stop"
				}
				// Stream completed successfully
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// If there is an error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// Context cancelled: report the cancellation error.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: surface the original error.
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

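// shouldRetry reports whether a failed request should be retried and, if
// so, after how many milliseconds. On 401 it re-resolves the API key and
// rebuilds the client; 429 and 500 are retried with exponential backoff
// (2000ms * 2^(attempts-1) plus a flat 20% jitter term, so roughly 2.4s,
// 4.8s, 9.6s, ...), honoring a Retry-After header given in seconds.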
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return false, 0, err
	}
	var apiErr *openai.Error
	retryMs := 0
	retryAfterValues := []string{}
	if errors.As(err, &apiErr) {
		// Check for token expiration (401 Unauthorized)
		if apiErr.StatusCode == 401 {
			o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
			if err != nil {
				return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
			}
			o.client = createOpenAIClient(o.providerOptions)
			return true, 0, nil
		}

		if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
			return false, 0, err
		}

		retryAfterValues = apiErr.Response.Header.Values("Retry-After")
	}

	if apiErr != nil {
		slog.Warn("OpenAI API error", "status_code", apiErr.StatusCode, "message", apiErr.Message, "type", apiErr.Type)
		if len(retryAfterValues) > 0 {
			slog.Warn("Retry-After header", "values", retryAfterValues)
		}
	} else {
		slog.Warn("OpenAI API error", "error", err.Error())
	}

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		// Only the delay-seconds form of Retry-After is handled here; the
		// HTTP-date form is ignored and the computed backoff is used instead.
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

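// toolCalls extracts the tool calls from the first choice of a completed
// chat completion and marks them as finished.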
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

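// usage converts OpenAI token accounting to the internal TokenUsage. The
// prompt token count includes cached tokens, so cached tokens are
// subtracted to get fresh input tokens: e.g. 1000 prompt tokens with 600
// cached yields 400 input tokens and 600 cache-read tokens.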
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

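// Model returns the catwalk model currently selected for this client.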
func (o *openaiClient) Model() catwalk.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}