openai.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

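// openaiClient implements a chat completion client for the OpenAI API and for
// OpenAI-compatible providers (such as OpenRouter) configured with a custom
// base URL.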
type openaiClient struct {
	providerOptions providerClientOptions
	client          openai.Client
}

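// OpenAIClient is the ProviderClient implementation backed by the OpenAI API.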
type OpenAIClient ProviderClient

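// newOpenAIClient constructs an OpenAIClient with its underlying SDK client
// configured from the given provider options.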
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	return &openaiClient{
		providerOptions: opts,
		client:          createOpenAIClient(opts),
	}
}

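// createOpenAIClient assembles the underlying SDK client, applying the API
// key, the resolved base URL, and any extra headers and extra JSON body
// fields from the provider options.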
func createOpenAIClient(opts providerClientOptions) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
		if err == nil {
			openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
		} else {
			slog.Warn("Failed to resolve base URL; using the default", "baseURL", opts.baseURL, "error", err)
		}
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

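// convertMessages maps internal messages onto the OpenAI chat completion
// message union. The system prompt is always sent first; when talking to
// Anthropic models through OpenRouter, cache_control markers are attached to
// the system prompt and the most recent messages to enable prompt caching.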
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	isAnthropicModel := o.providerOptions.config.ID == string(catwalk.InferenceProviderOpenRouter) && strings.HasPrefix(o.Model().ID, "anthropic/")
	// Add system message first
	systemMessage := o.providerOptions.systemMessage
	if o.providerOptions.systemPromptPrefix != "" {
		systemMessage = o.providerOptions.systemPromptPrefix + "\n" + systemMessage
	}

	systemTextBlock := openai.ChatCompletionContentPartTextParam{Text: systemMessage}
	if isAnthropicModel && !o.providerOptions.disableCache {
		systemTextBlock.SetExtraFields(
			map[string]any{
				"cache_control": map[string]string{
					"type": "ephemeral",
				},
			},
		)
	}
	var content []openai.ChatCompletionContentPartTextParam
	content = append(content, systemTextBlock)
	system := openai.SystemMessage(content)
	openaiMessages = append(openaiMessages, system)

	for i, msg := range messages {
		// Only the last two messages are marked as cacheable.
		cache := i > len(messages)-3
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}
			if cache && !o.providerOptions.disableCache && isAnthropicModel {
				textBlock.SetExtraFields(map[string]any{
					"cache_control": map[string]string{
						"type": "ephemeral",
					},
				})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
				if cache && !o.providerOptions.disableCache && isAnthropicModel {
					textBlock.SetExtraFields(map[string]any{
						"cache_control": map[string]string{
							"type": "ephemeral",
						},
					})
				}
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfArrayOfContentParts: []openai.ChatCompletionAssistantMessageParamContentArrayOfContentPartUnion{
						{
							OfText: &textBlock,
						},
					},
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Assistant message has neither content nor tool calls; skipping")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

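// convertTools maps the internal tool definitions onto OpenAI function tool
// parameters, exposing each tool's name, description, and JSON schema.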
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

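// finishReason translates an OpenAI finish reason string into the internal
// message.FinishReason; unrecognized values map to FinishReasonUnknown.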
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

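// preparedParams builds the ChatCompletionNewParams shared by send and
// stream, applying the configured model, token limits, and, for reasoning
// models, the reasoning effort.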
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	cfg := config.Get()

	modelConfig := cfg.Models[config.SelectedModelTypeLarge]
	if o.providerOptions.modelType == config.SelectedModelTypeSmall {
		modelConfig = cfg.Models[config.SelectedModelTypeSmall]
	}

	reasoningEffort := modelConfig.ReasoningEffort

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if modelConfig.MaxTokens > 0 {
		maxTokens = modelConfig.MaxTokens
	}

	// Override max tokens if set in provider options.
	if o.providerOptions.maxTokens > 0 {
		maxTokens = o.providerOptions.maxTokens
	}
	if model.CanReason {
		// Reasoning models take max_completion_tokens instead of max_tokens.
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

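// send performs a blocking chat completion request, retrying retryable
// failures via shouldRetry, and returns the content, tool calls, usage, and
// finish reason of the first choice.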
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying OpenAI request", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

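// stream issues a streaming chat completion request and returns a channel of
// ProviderEvents: content deltas and tool-call events as they arrive, then a
// single EventComplete (or EventError) before the channel is closed. A
// typical consumer (names illustrative) simply ranges over the channel:
//
//	for event := range client.stream(ctx, messages, tools) {
//		switch event.Type {
//		case EventContentDelta: // append event.Content
//		case EventComplete:     // read event.Response
//		}
//	}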
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			currentToolIndex := 0
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				// Workaround: some providers (e.g. Qwen via OpenRouter) send -1
				// for the tool call index; re-index those calls sequentially.
				if len(chunk.Choices) > 0 && len(chunk.Choices[0].Delta.ToolCalls) > 0 && chunk.Choices[0].Delta.ToolCalls[0].Index == -1 {
					chunk.Choices[0].Delta.ToolCalls[0].Index = int64(currentToolIndex)
					currentToolIndex++
				}
				acc.AddChunk(chunk)
				// Track tool-call deltas manually; some providers split
				// multiple tool calls across chunks in ways the accumulator
				// does not handle.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect tool use start
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								eventChan <- ProviderEvent{
									Type: EventToolUseStart,
									ToolCall: &message.ToolCall{
										ID:       toolCall.ID,
										Name:     toolCall.Function.Name,
										Finished: false,
									},
								}
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else if toolCall.ID == "" || toolCall.ID == currentToolCallID {
							// Argument delta for the current tool call.
							currentToolCall.Function.Arguments += toolCall.Function.Arguments
						} else {
							// A new, non-empty tool call ID: flush the current
							// call and start tracking the new one.
							msgToolCalls = append(msgToolCalls, currentToolCall)
							currentToolCallID = toolCall.ID
							eventChan <- ProviderEvent{
								Type: EventToolUseStart,
								ToolCall: &message.ToolCall{
									ID:       toolCall.ID,
									Name:     toolCall.Function.Name,
									Finished: false,
								},
							}
							currentToolCall = openai.ChatCompletionMessageToolCall{
								ID:   toolCall.ID,
								Type: "function",
								Function: openai.ChatCompletionMessageToolCallFunction{
									Name:      toolCall.Function.Name,
									Arguments: toolCall.Function.Arguments,
								},
							}
						}
					}
					// Some models report finish reason "stop" even when the
					// response ended with tool calls; treat both as tool-call
					// completion and flush the accumulated calls.
					if choice.FinishReason == "tool_calls" || (choice.FinishReason == "stop" && currentToolCallID != "") {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if cfg.Options.Debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// Treat an empty finish reason as a successful completion;
					// OpenRouter in particular sometimes omits it.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the stream can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying OpenAI request", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// The context was cancelled; surface its error before closing.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

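// shouldRetry reports whether a failed request should be retried and, if so,
// after how many milliseconds. A 401 triggers an API key refresh and an
// immediate retry; 429 and 500 are retried with exponential backoff (e.g.
// attempt 1 waits 2000ms + 20% = 2400ms, attempt 2 waits 4800ms) unless the
// server supplies a Retry-After header. All other errors are returned as-is.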
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized).
	if apiErr.StatusCode == 401 {
		o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.providerOptions)
		return true, 0, nil
	}

	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	// Exponential backoff, doubling on each attempt, with 20% padding.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs
	// Prefer the server-provided Retry-After header (in seconds) when present.
	if retryAfterValues := apiErr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		var retryAfterSec int
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryAfterSec); err == nil {
			retryMs = retryAfterSec * 1000
		}
	}
	return true, int64(retryMs), nil
}
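// toolCalls extracts the tool calls from the first choice of a completion
// into the internal message.ToolCall representation.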
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

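// usage converts the completion's token accounting into a TokenUsage. Cached
// prompt tokens are reported as cache reads and subtracted from input tokens.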
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

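// Model returns the catwalk model configured for this client's model type.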
func (o *openaiClient) Model() catwalk.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}