package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/fur/provider"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
)

type openaiOptions struct {
	reasoningEffort string
}

type openaiClient struct {
	providerOptions providerClientOptions
	options         openaiOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

// newOpenAIClient builds an OpenAI-backed ProviderClient, applying the API
// key, base URL, and any extra headers from the provider options.
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	openaiOpts := openaiOptions{
		reasoningEffort: "medium",
	}

	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if opts.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(opts.baseURL))
	}

	// Ranging over a nil map is a no-op, so no nil check is needed here.
	for key, value := range opts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	client := openai.NewClient(openaiClientOptions...)
	return &openaiClient{
		providerOptions: opts,
		options:         openaiOpts,
		client:          client,
	}
}
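
// Usage sketch (illustrative; the option values are hypothetical, but the
// field names come from this package's providerClientOptions):
//
//	client := newOpenAIClient(providerClientOptions{
//		apiKey:  os.Getenv("OPENAI_API_KEY"),
//		baseURL: "", // empty keeps the SDK's default endpoint
//	})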

// convertMessages converts internal messages into OpenAI chat message params,
// prepending the configured system message.
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			var content []openai.ChatCompletionContentPartUnionParam
			textBlock := openai.ChatCompletionContentPartTextParam{Text: msg.Content().String()}
			content = append(content, openai.ChatCompletionContentPartUnionParam{OfText: &textBlock})
			for _, binaryContent := range msg.BinaryContent() {
				imageURL := openai.ChatCompletionContentPartImageImageURLParam{URL: binaryContent.String(provider.InferenceProviderOpenAI)}
				imageBlock := openai.ChatCompletionContentPartImageParam{ImageURL: imageURL}
				content = append(content, openai.ChatCompletionContentPartUnionParam{OfImageURL: &imageBlock})
			}

			openaiMessages = append(openaiMessages, openai.UserMessage(content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}
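
// For reference, a converted turn with one image attachment followed by a tool
// round-trip serializes to roughly this JSON (abridged, illustrative):
//
//	[{"role":"system","content":"..."},
//	 {"role":"user","content":[{"type":"text","text":"..."},
//	   {"type":"image_url","image_url":{"url":"..."}}]},
//	 {"role":"assistant","tool_calls":[{"id":"call_1","type":"function",
//	   "function":{"name":"...","arguments":"{...}"}}]},
//	 {"role":"tool","tool_call_id":"call_1","content":"..."}]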

// convertTools maps the agent's tool definitions onto OpenAI function
// declarations.
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}
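
// A tool whose Info() reports Name "view", a single string parameter
// "file_path", and Required ["file_path"] (hypothetical values) would be
// emitted as:
//
//	{"type":"function","function":{"name":"view","description":"...",
//	  "parameters":{"type":"object",
//	    "properties":{"file_path":{"type":"string"}},
//	    "required":["file_path"]}}}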

// finishReason translates OpenAI finish reasons into the provider-agnostic
// message.FinishReason values.
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

// preparedParams assembles the request parameters shared by send and stream.
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	model := o.providerOptions.model(o.providerOptions.modelType)
	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(model.ID),
		Messages: messages,
		Tools:    tools,
	}
	if model.CanReason {
		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
		switch o.options.reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		}
	} else {
		params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
	}

	return params
}
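
// The branch above matters because OpenAI's reasoning models require
// max_completion_tokens and reject the legacy max_tokens field. Abridged
// request bodies for the two branches (model names are examples):
//
//	{"model":"o3-mini","max_completion_tokens":4096,"reasoning_effort":"medium",...}
//	{"model":"gpt-4o","max_tokens":4096,...}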

// send performs a blocking chat completion request, retrying retryable
// failures with backoff.
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: return the original error rather than the nil
			// retryErr, which would have signaled success with a nil response.
			return nil, err
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

// stream performs a streaming chat completion request, emitting content deltas
// and a final EventComplete (or EventError) on the returned channel.
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Options.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)

				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				// Stream completed successfully
				finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
				if len(acc.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// Context cancelled: surface the cancellation and stop.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			// Not retryable: report the original error, not the nil retryErr.
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}
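
// A consumer draining the channel typically observes a sequence like
// (illustrative):
//
//	EventContentDelta{"Hel"}, EventContentDelta{"lo"}, ..., EventComplete{...}
//
// or a single EventError once retries are exhausted.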

// shouldRetry reports whether err is retryable (HTTP 429 or 500), along with
// the delay in milliseconds before the next attempt. A Retry-After header, if
// present and parseable, overrides the computed backoff.
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apierr *openai.Error
	if !errors.As(err, &apierr) {
		return false, 0, err
	}

	if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Exponential backoff with a fixed 20% bump (a deterministic stand-in for
	// jitter).
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	// Retry-After is expressed in seconds; convert to milliseconds.
	retryAfterValues := apierr.Response.Header.Values("Retry-After")
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}
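
// With the 2000ms base, per-attempt doubling, and the fixed 20% bump, the
// fallback delays (no Retry-After header) work out to:
//
//	attempt 1: 2000 + 400  = 2400ms
//	attempt 2: 4000 + 800  = 4800ms
//	attempt 3: 8000 + 1600 = 9600ms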

// toolCalls extracts tool calls from a completed response.
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

// usage converts the API usage block into TokenUsage, counting cached prompt
// tokens separately from fresh input tokens.
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}
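
// Example: a completion reporting 1200 prompt tokens, 800 of them served from
// the prompt cache, yields InputTokens=400 and CacheReadTokens=800.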

// Model returns the concrete model configuration for this client.
func (o *openaiClient) Model() config.Model {
	return o.providerOptions.model(o.providerOptions.modelType)
}