package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

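// openaiProvider implements Provider on top of the official openai-go
// client, adding OpenAI-specific request construction, streaming
// accumulation, and retry handling around the shared baseProvider state.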
type openaiProvider struct {
	*baseProvider
	client openai.Client
}

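// NewOpenAIProvider builds a Provider backed by the OpenAI Chat Completions
// API. A minimal sketch of a caller (hypothetical; assumes an
// already-populated *baseProvider):
//
//	p := NewOpenAIProvider(base)
//	resp, err := p.Send(ctx, "gpt-4o", msgs, nil)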
func NewOpenAIProvider(base *baseProvider) Provider {
	return &openaiProvider{
		baseProvider: base,
		client:       createOpenAIClient(base),
	}
}

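// createOpenAIClient translates the base configuration into openai-go
// request options; the API key, base URL, extra headers, and extra JSON
// body fields are each applied only when set.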
func createOpenAIClient(opts *baseProvider) openai.Client {
	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
	}

	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	for extraKey, extraValue := range opts.extraBody {
		openaiClientOptions = append(openaiClientOptions, option.WithJSONSet(extraKey, extraValue))
	}

	return openai.NewClient(openaiClientOptions...)
}

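// convertMessages maps internal messages onto the OpenAI chat message
// union: the system prompt (with optional prefix) goes first, user
// messages may carry image parts, assistant messages may carry tool calls,
// and each tool result becomes its own tool message.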
func (o *openaiProvider) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// The system message always goes first, with the optional prefix prepended.
	systemMessage := o.systemMessage
	if o.systemPromptPrefix != "" {
		systemMessage = o.systemPromptPrefix + "\n" + systemMessage
	}
	openaiMessages = append(openaiMessages, openai.SystemMessage(systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(catwalk.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}

				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			hasContent := false
			if msg.Content().String() != "" {
				hasContent = true
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				hasContent = true
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}
			if !hasContent {
				slog.Warn("Skipping assistant message with neither text content nor tool calls; this should not happen")
				continue
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

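// convertTools maps each tool's metadata onto an OpenAI function
// definition, wrapping its parameters in a JSON Schema object. For a
// hypothetical "grep" tool with one required string parameter "pattern",
// the resulting schema would be:
//
//	{"type": "object", "properties": {"pattern": {...}}, "required": ["pattern"]}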
func (o *openaiProvider) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

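// finishReason maps an OpenAI finish reason onto the internal equivalent;
// anything other than "stop", "length", or "tool_calls" maps to unknown.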
func (o *openaiProvider) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

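// preparedParams assembles the chat completion request. Reasoning-capable
// models receive MaxCompletionTokens plus a reasoning effort (the
// configured value, falling back to the model default); other models
// receive the plain MaxTokens limit.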
func (o *openaiProvider) preparedParams(modelID string, messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.Model(modelID)

	reasoningEffort := o.reasoningEffort
	if reasoningEffort == "" {
		reasoningEffort = model.DefaultReasoningEffort
	}

	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}

	maxTokens := model.DefaultMaxTokens
	if o.maxTokens > 0 {
		maxTokens = o.maxTokens
	}

	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(maxTokens)
		switch reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffort(reasoningEffort)
		}
	} else {
		params.MaxTokens = openai.Int(maxTokens)
	}

	return params
}

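// Send performs a single, non-streaming chat completion request.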
func (o *openaiProvider) Send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	messages = o.cleanMessages(messages)
	return o.send(ctx, model, messages, tools)
}

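// send runs the completion request in a retry loop: retryable failures
// (rate limits, transient server errors, expired credentials) back off and
// try again via shouldRetry; anything else returns the error immediately.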
func (o *openaiProvider) send(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools))
	if o.debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Non-retryable error: return the original error.
			return nil, err
		}

		if len(openaiResponse.Choices) == 0 {
			return nil, errors.New("received empty response from OpenAI API - check endpoint configuration")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

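// Stream performs a streaming chat completion request, returning a channel
// that yields content deltas as they arrive followed by a final complete
// (or error) event. The channel is closed after the final event.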
func (o *openaiProvider) Stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	messages = o.cleanMessages(messages)
	return o.stream(ctx, model, messages, tools)
}

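// stream drives the streaming request in a goroutine, accumulating chunks
// into a full completion while forwarding deltas on the returned channel.
// A minimal sketch of a consumer (hypothetical caller code):
//
//	for ev := range p.Stream(ctx, model, msgs, tools) {
//		switch ev.Type {
//		case EventContentDelta:
//			fmt.Print(ev.Content)
//		case EventComplete, EventError:
//			// Final event; the channel is closed afterwards.
//		}
//	}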
func (o *openaiProvider) stream(ctx context.Context, model string, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(model, o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	if o.debug {
		jsonData, _ := json.Marshal(params)
		slog.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			var currentToolCallID string
			var currentToolCall openai.ChatCompletionMessageToolCall
			var msgToolCalls []openai.ChatCompletionMessageToolCall
			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)
				// Track tool calls manually: some providers split or
				// interleave tool-call deltas in ways the accumulator alone
				// does not handle correctly.
				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					} else if len(choice.Delta.ToolCalls) > 0 {
						toolCall := choice.Delta.ToolCalls[0]
						// Detect the start of a tool call.
						if currentToolCallID == "" {
							if toolCall.ID != "" {
								currentToolCallID = toolCall.ID
								currentToolCall = openai.ChatCompletionMessageToolCall{
									ID:   toolCall.ID,
									Type: "function",
									Function: openai.ChatCompletionMessageToolCallFunction{
										Name:      toolCall.Function.Name,
										Arguments: toolCall.Function.Arguments,
									},
								}
							}
						} else {
							// An empty ID is an argument delta for the current call.
							if toolCall.ID == "" {
								currentToolCall.Function.Arguments += toolCall.Function.Arguments
							} else {
								// A new ID starts a new tool call; flush the previous one.
								if toolCall.ID != currentToolCallID {
									msgToolCalls = append(msgToolCalls, currentToolCall)
									currentToolCallID = toolCall.ID
									currentToolCall = openai.ChatCompletionMessageToolCall{
										ID:   toolCall.ID,
										Type: "function",
										Function: openai.ChatCompletionMessageToolCallFunction{
											Name:      toolCall.Function.Name,
											Arguments: toolCall.Function.Arguments,
										},
									}
								}
							}
						}
					}
					if choice.FinishReason == "tool_calls" {
						msgToolCalls = append(msgToolCalls, currentToolCall)
						if len(acc.Choices) > 0 {
							acc.Choices[0].Message.ToolCalls = msgToolCalls
						}
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				if o.debug {
					jsonData, _ := json.Marshal(acc.ChatCompletion)
					slog.Debug("Response", "messages", string(jsonData))
				}

				if len(acc.Choices) == 0 {
					eventChan <- ProviderEvent{
						Type:  EventError,
						Error: errors.New("received empty streaming response from OpenAI API - check endpoint configuration"),
					}
					close(eventChan)
					return
				}

				resultFinishReason := acc.Choices[0].FinishReason
				if resultFinishReason == "" {
					// Treat an empty finish reason as a successful completion.
					// INFO: OpenRouter is known to omit it.
					resultFinishReason = "stop"
				}
				// Stream completed successfully.
				finishReason := o.finishReason(resultFinishReason)
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				slog.Warn("Retrying due to rate limit", "attempt", attempts, "max_retries", maxRetries)
				select {
				case <-ctx.Done():
					// The context was cancelled mid-retry; surface the cancellation.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Non-retryable error: report the original error.
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

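// shouldRetry reports whether a failed request should be retried and, if
// so, how long to wait in milliseconds. 401s refresh the credentials and
// retry immediately; 429s and 500s back off exponentially with 20% jitter,
// preferring the server's Retry-After header when present; everything else
// fails fast. Without a Retry-After header the waits work out to 2400ms,
// 4800ms, and 9600ms for attempts one through three.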
func (o *openaiProvider) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apiErr *openai.Error
	if !errors.As(err, &apiErr) {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Check for token expiration (401 Unauthorized): refresh the API key
	// and rebuild the client before retrying immediately.
	if apiErr.StatusCode == 401 {
		o.apiKey, err = o.resolver.ResolveValue(o.config.APIKey)
		if err != nil {
			return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
		}
		o.client = createOpenAIClient(o.baseProvider)
		return true, 0, nil
	}

	// Only rate limits (429) and server errors (500) are retryable.
	if apiErr.StatusCode != 429 && apiErr.StatusCode != 500 {
		return false, 0, err
	}

	// Exponential backoff with 20% jitter, overridden by a Retry-After
	// header (given in seconds) when the server sends one.
	retryMs := 0
	retryAfterValues := apiErr.Response.Header.Values("Retry-After")

	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs = backoffMs + jitterMs
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

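// toolCalls extracts the tool calls from the first choice of a completed
// response, marking each as finished since the response is already whole.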
func (o *openaiProvider) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

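// usage converts OpenAI token accounting into the internal TokenUsage.
// Cached prompt tokens are reported separately, so they are subtracted
// from the input total: 1000 prompt tokens with 400 cached yields
// InputTokens=600 and CacheReadTokens=400.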
func (o *openaiProvider) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}