package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/models"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

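// openaiOptions holds OpenAI-specific client configuration, populated through
// OpenAIOption functions before the client is constructed.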
type openaiOptions struct {
	baseURL         string
	disableCache    bool
	reasoningEffort string
	extraHeaders    map[string]string
}

type OpenAIOption func(*openaiOptions)

type openaiClient struct {
	providerOptions providerClientOptions
	options         openaiOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

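// newOpenAIClient builds an OpenAI provider client. It starts from a default
// "medium" reasoning effort, applies any OpenAIOption overrides, and wires the
// API key, base URL, and extra headers into the underlying SDK client.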
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	openaiOpts := openaiOptions{
		reasoningEffort: "medium",
	}
	for _, o := range opts.openaiOptions {
		o(&openaiOpts)
	}

	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if openaiOpts.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
	}
	for key, value := range openaiOpts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	client := openai.NewClient(openaiClientOptions...)
	return &openaiClient{
		providerOptions: opts,
		options:         openaiOpts,
		client:          client,
	}
}

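// convertMessages translates internal messages into the OpenAI chat format.
// The configured system message is always sent first; user messages carry text
// plus any image attachments, assistant messages carry text plus tool calls,
// and each tool result becomes its own tool message.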
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add the system message first.
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(models.ProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

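// convertTools maps internal tool definitions onto OpenAI function-calling
// params, wrapping each tool's parameters in a JSON Schema object.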
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

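// finishReason normalizes the finish reason reported by the API into the
// internal message.FinishReason values.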
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

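// preparedParams assembles a chat completion request. Reasoning-capable models
// use MaxCompletionTokens together with a reasoning effort; all other models
// use MaxTokens.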
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(o.providerOptions.model.APIModel),
		Messages: messages,
		Tools:    tools,
	}

	if o.providerOptions.model.CanReason {
		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
		switch o.options.reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		}
	} else {
		params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
	}

	return params
}

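// send performs a single blocking chat completion request, retrying with
// backoff on retryable API errors.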
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

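// stream starts a streaming chat completion and returns a channel of provider
// events: content deltas while the stream is open, then a single completion or
// error event before the channel is closed. A typical consumer loop (sketch):
//
//	for event := range client.stream(ctx, msgs, tools) {
//		switch event.Type {
//		case EventContentDelta:
//			// append event.Content to the running output
//		case EventComplete:
//			// use event.Response
//		case EventError:
//			// handle event.Error
//		}
//	}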
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)

				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				// Stream completed successfully.
				finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// Context cancelled: report the cancellation before closing.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

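// shouldRetry reports whether a failed request should be retried and how long
// to wait, in milliseconds. Only 429 and 500 responses are retried, up to
// maxRetries attempts, and a server-provided Retry-After header takes
// precedence over the computed backoff (2s, 4s, 8s, ... plus 20%).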
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apierr *openai.Error
	if !errors.As(err, &apierr) {
		return false, 0, err
	}

	if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Exponential backoff starting at 2s, padded by a fixed 20%.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs
	// Honor the server's Retry-After header (in seconds) when present.
	if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

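// toolCalls extracts the tool calls from the first choice of a completion.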
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

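// usage converts the API token accounting into the internal TokenUsage shape,
// splitting cached prompt tokens out of the billed input tokens.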
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

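// WithOpenAIBaseURL overrides the API base URL, e.g. for OpenAI-compatible
// endpoints.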
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(options *openaiOptions) {
		options.baseURL = baseURL
	}
}

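// WithOpenAIExtraHeaders attaches extra HTTP headers to every request.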
func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
	return func(options *openaiOptions) {
		options.extraHeaders = headers
	}
}

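// WithOpenAIDisableCache sets the disableCache flag so the provider skips
// prompt caching.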
func WithOpenAIDisableCache() OpenAIOption {
	return func(options *openaiOptions) {
		options.disableCache = true
	}
}

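// WithReasoningEffort sets the reasoning effort for reasoning-capable models.
// Invalid values fall back to "medium" with a warning.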
func WithReasoningEffort(effort string) OpenAIOption {
	return func(options *openaiOptions) {
		switch effort {
		case "low", "medium", "high":
			options.reasoningEffort = effort
		default:
			logging.Warn("Invalid reasoning effort, using default: medium")
			options.reasoningEffort = "medium"
		}
	}
}