openai.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/llm/tools"
 14	"github.com/charmbracelet/crush/internal/message"
 15	"github.com/openai/openai-go"
 16	"github.com/openai/openai-go/option"
 17	"github.com/openai/openai-go/shared"
 18)
 19
 20type openaiProvider struct {
 21	*baseProvider
 22	client openai.Client
 23}
 24
 25func NewOpenAIProvider(base *baseProvider) Provider {
 26	return &openaiProvider{
 27		baseProvider: base,
 28		client:       createOpenAIClient(base),
 29	}
 30}
 31
 32func createOpenAIClient(opts *baseProvider) openai.Client {
 33	openaiClientOptions := []option.RequestOption{}
 34	if opts.apiKey != "" {
 35		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
 36	}
 37	if opts.baseURL != "" {
 38		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
 39	}
 40
 41	for key, value := range opts.extraHeaders {
 42		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
 43	}
 44
 45	for extraKey, extraValue := range opts.extraBody {
 46		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
 47	}
 48
 49	return openai.NewClient(openaiClientOptions...)
 50}
 51
 52func (o *openaiProvider) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
 53	// Add system message first
 54	systemMessage := o.systemMessage
 55	if o.systemPromptPrefix != "" {
 56		systemMessage = o.systemPromptPrefix + "\n" + systemMessage
 57	}
 58	openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))
 59
 60	for _, msg := range messages {
 61		switch msg.Role {
 62		case message.User:
 63			var content []openai.ChatCompletionContentPartUnionParam
 64			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
 65			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
 66			for _, binaryContent := range msg.BinaryContent() {
 67				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
 68				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
 69
 70				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
 71			}
 72
 73			openaiMessages = append(openaiMessages, openai.UserMessage(content))
 74
 75		case message.Assistant:
 76			assistantMsg := openai.ChatCompletionAssistantMessageParam{
 77				Role: "assistant",
 78			}
 79
 80			hasContent := false
 81			if msg.Content().String() != "" {
 82				hasContent = true
 83				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
 84					OfString: openai.String(msg.Content().String()),
 85				}
 86			}
 87
 88			if len(msg.ToolCalls()) > 0 {
 89				hasContent = true
 90				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
 91				for i, call := range msg.ToolCalls() {
 92					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
 93						ID:   call.ID,
 94						Type: "function",
 95						Function: openai.ChatCompletionMessageToolCallFunctionParam{
 96							Name:      call.Name,
 97							Arguments: call.Input,
 98						},
 99					}
100				}
101			}
102			if !hasContent {
103				slog.Warn("There is a message without content, investigate, this should not happen")
104				continue
105			}
106
107			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
108				OfAssistant: &assistantMsg,
109			})
110
111		case message.Tool:
112			for _, result := range msg.ToolResults() {
113				openaiMessages = append(openaiMessages,
114					openai.ToolMessage(result.Content, result.ToolCallID),
115				)
116			}
117		}
118	}
119
120	return
121}
122
123func (o *openaiProvider) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
124	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
125
126	for i, tool := range tools {
127		info := tool.Info()
128		openaiTools[i] = openai.ChatCompletionToolParam{
129			Function: openai.FunctionDefinitionParam{
130				Name:        info.Name,
131				Description: openai.String(info.Description),
132				Parameters: openai.FunctionParameters{
133					"type":       "object",
134					"properties": info.Parameters,
135					"required":   info.Required,
136				},
137			},
138		}
139	}
140
141	return openaiTools
142}
143
144func (o *openaiProvider) finishReason(reason string) message.FinishReason {
145	switch reason {
146	case "stop":
147		return message.FinishReasonEndTurn
148	case "length":
149		return message.FinishReasonMaxTokens
150	case "tool_calls":
151		return message.FinishReasonToolUse
152	default:
153		return message.FinishReasonUnknown
154	}
155}
156
157func (o *openaiProvider) preparedParams(modelID string, messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
158	model := o.Model(modelID)
159
160	reasoningEffort := o.reasoningEffort
161	if reasoningEffort == "" {
162		reasoningEffort = model.DefaultReasoningEffort
163	}
164
165	params := openai.ChatCompletionNewParams{
166		Model:    openai.ChatModel(model.ID),
167		Messages: messages,
168		Tools:    tools,
169	}
170
171	maxTokens := model.DefaultMaxTokens
172	if o.maxTokens > 0 {
173		maxTokens = o.maxTokens
174	}
175
176	if model.CanReason {
177		params.MaxCompletionTokens = openai.Int(maxTokens)
178		switch reasoningEffort {
179		case "low":
180			params.ReasoningEffort = shared.ReasoningEffortLow
181		case "medium":
182			params.ReasoningEffort = shared.ReasoningEffortMedium
183		case "high":
184			params.ReasoningEffort = shared.ReasoningEffortHigh
185		default:
186			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
187		}
188	} else {
189		params.MaxTokens = openai.Int(maxTokens)
190	}
191
192	return params
193}
194
195func (o *openaiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
196	messages = o.cleanMessages(messages)
197	return o.send(ctx, model, messages, tools)
198}
199
200func (o *openaiProvider) send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
201	params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools))
202	if o.debug {
203		jsonData, _ := json.Marshal(params)
204		slog.Debug("Prepared messages", "messages", string(jsonData))
205	}
206	attempts := 0
207	for {
208		attempts++
209		openaiResponse, err := o.client.Chat.Completions.New(
210			ctx,
211			params,
212		)
213		// If there is an error we are going to see if we can retry the call
214		if err != nil {
215			retry, after, retryErr := o.shouldRetry(attempts, err)
216			if retryErr != nil {
217				return nil, retryErr
218			}
219			if retry {
220				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
221				select {
222				case <-ctx.Done():
223					return nil, ctx.Err()
224				case <-time.After(time.Duration(after) * time.Millisecond):
225					continue
226				}
227			}
228			return nil, retryErr
229		}
230
231		if len(openaiResponse.Choices) == 0 {
232			return nil, fmt.Errorf("received empty response from OpenAI API - check endpoint configuration")
233		}
234
235		content := ""
236		if openaiResponse.Choices[0].Message.Content != "" {
237			content = openaiResponse.Choices[0].Message.Content
238		}
239
240		toolCalls := o.toolCalls(*openaiResponse)
241		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))
242
243		if len(toolCalls) > 0 {
244			finishReason = message.FinishReasonToolUse
245		}
246
247		return &ProviderResponse{
248			Content:      content,
249			ToolCalls:    toolCalls,
250			Usage:        o.usage(*openaiResponse),
251			FinishReason: finishReason,
252		}, nil
253	}
254}
255
256func (o *openaiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
257	messages = o.cleanMessages(messages)
258	return o.stream(ctx, model, messages, tools)
259}
260
261func (o *openaiProvider) stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
262	params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools))
263	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
264		IncludeUsage: openai.Bool(true),
265	}
266
267	if o.debug {
268		jsonData, _ := json.Marshal(params)
269		slog.Debug("Prepared messages", "messages", string(jsonData))
270	}
271
272	attempts := 0
273	eventChan := make(chan ProviderEvent)
274
275	go func() {
276		for {
277			attempts++
278			openaiStream := o.client.Chat.Completions.NewStreaming(
279				ctx,
280				params,
281			)
282
283			acc := openai.ChatCompletionAccumulator{}
284			currentContent := ""
285			toolCalls := make([]message.ToolCall, 0)
286
287			var currentToolCallID string
288			var currentToolCall openai.ChatCompletionMessageToolCall
289			var msgToolCalls []openai.ChatCompletionMessageToolCall
290			for openaiStream.Next() {
291				chunk := openaiStream.Current()
292				acc.AddChunk(chunk)
293				// This fixes multiple tool calls for some providers
294				for _, choice := range chunk.Choices {
295					if choice.Delta.Content != "" {
296						eventChan <- ProviderEvent{
297							Type:    EventContentDelta,
298							Content: choice.Delta.Content,
299						}
300						currentContent += choice.Delta.Content
301					} else if len(choice.Delta.ToolCalls) > 0 {
302						toolCall := choice.Delta.ToolCalls[0]
303						// Detect tool use start
304						if currentToolCallID == "" {
305							if toolCall.ID != "" {
306								currentToolCallID = toolCall.ID
307								currentToolCall = openai.ChatCompletionMessageToolCall{
308									ID:   toolCall.ID,
309									Type: "function",
310									Function: openai.ChatCompletionMessageToolCallFunction{
311										Name:      toolCall.Function.Name,
312										Arguments: toolCall.Function.Arguments,
313									},
314								}
315							}
316						} else {
317							// Delta tool use
318							if toolCall.ID == "" {
319								currentToolCall.Function.Arguments += toolCall.Function.Arguments
320							} else {
321								// Detect new tool use
322								if toolCall.ID != currentToolCallID {
323									msgToolCalls = append(msgToolCalls, currentToolCall)
324									currentToolCallID = toolCall.ID
325									currentToolCall = openai.ChatCompletionMessageToolCall{
326										ID:   toolCall.ID,
327										Type: "function",
328										Function: openai.ChatCompletionMessageToolCallFunction{
329											Name:      toolCall.Function.Name,
330											Arguments: toolCall.Function.Arguments,
331										},
332									}
333								}
334							}
335						}
336					}
337					if choice.FinishReason == "tool_calls" {
338						msgToolCalls = append(msgToolCalls, currentToolCall)
339						if len(acc.Choices) > 0 {
340							acc.Choices[0].Message.ToolCalls = msgToolCalls
341						}
342					}
343				}
344			}
345
346			err := openaiStream.Err()
347			if err == nil || errors.Is(err, io.EOF) {
348				if o.debug {
349					jsonData, _ := json.Marshal(acc.ChatCompletion)
350					slog.Debug("Response", "messages", string(jsonData))
351				}
352
353				if len(acc.Choices) == 0 {
354					eventChan <- ProviderEvent{
355						Type:  EventError,
356						Error: fmt.Errorf("received empty streaming response from OpenAI API - check endpoint configuration"),
357					}
358					return
359				}
360
361				resultFinishReason := acc.Choices[0].FinishReason
362				if resultFinishReason == "" {
363					// If the finish reason is empty, we assume it was a successful completion
364					// INFO: this is happening for openrouter for some reason
365					resultFinishReason = "stop"
366				}
367				// Stream completed successfully
368				finishReason := o.finishReason(resultFinishReason)
369				if len(acc.Choices[0].Message.ToolCalls) > 0 {
370					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
371				}
372				if len(toolCalls) > 0 {
373					finishReason = message.FinishReasonToolUse
374				}
375
376				eventChan <- ProviderEvent{
377					Type: EventComplete,
378					Response: &ProviderResponse{
379						Content:      currentContent,
380						ToolCalls:    toolCalls,
381						Usage:        o.usage(acc.ChatCompletion),
382						FinishReason: finishReason,
383					},
384				}
385				close(eventChan)
386				return
387			}
388
389			// If there is an error we are going to see if we can retry the call
390			retry, after, retryErr := o.shouldRetry(attempts, err)
391			if retryErr != nil {
392				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
393				close(eventChan)
394				return
395			}
396			if retry {
397				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
398				select {
399				case <-ctx.Done():
400					// context cancelled
401					if ctx.Err() == nil {
402						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
403					}
404					close(eventChan)
405					return
406				case <-time.After(time.Duration(after) * time.Millisecond):
407					continue
408				}
409			}
410			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
411			close(eventChan)
412			return
413		}
414	}()
415
416	return eventChan
417}
418
419func (o *openaiProvider) shouldRetry(attempts int, err error) (bool, int64, error) {
420	var apiErr *openai.Error
421	if !errors.As(err, &apiErr) {
422		return false, 0, err
423	}
424
425	if attempts > maxRetries {
426		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
427	}
428
429	// Check for token expiration (401 Unauthorized)
430	if apiErr.StatusCode == 401 {
431		o.apiKey, err = o.resolver.ResolveValue(o.config.APIKey)
432		if err != nil {
433			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
434		}
435		o.client = createOpenAIClient(o.baseProvider)
436		return true, 0, nil
437	}
438
439	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
440		return false, 0, err
441	}
442
443	retryMs := 0
444	retryAfterValues := apiErr.Response.Header.Values("Retry-After")
445
446	backoffMs := 2000 * (1 << (attempts - 1))
447	jitterMs := int(float64(backoffMs) * 0.2)
448	retryMs = backoffMs + jitterMs
449	if len(retryAfterValues) > 0 {
450		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
451			retryMs = retryMs * 1000
452		}
453	}
454	return true, int64(retryMs), nil
455}
456
457func (o *openaiProvider) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
458	var toolCalls []message.ToolCall
459
460	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
461		for _, call := range completion.Choices[0].Message.ToolCalls {
462			toolCall := message.ToolCall{
463				ID:       call.ID,
464				Name:     call.Function.Name,
465				Input:    call.Function.Arguments,
466				Type:     "function",
467				Finished: true,
468			}
469			toolCalls = append(toolCalls, toolCall)
470		}
471	}
472
473	return toolCalls
474}
475
476func (o *openaiProvider) usage(completion openai.ChatCompletion) TokenUsage {
477	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
478	inputTokens := completion.Usage.PromptTokens - cachedTokens
479
480	return TokenUsage{
481		InputTokens:         inputTokens,
482		OutputTokens:        completion.Usage.CompletionTokens,
483		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
484		CacheReadTokens:     cachedTokens,
485	}
486}