1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "errors"
7 "fmt"
8 "io"
9 "time"
10
11 "github.com/kujtimiihoxha/opencode/internal/config"
12 "github.com/kujtimiihoxha/opencode/internal/llm/tools"
13 "github.com/kujtimiihoxha/opencode/internal/logging"
14 "github.com/kujtimiihoxha/opencode/internal/message"
15 "github.com/openai/openai-go"
16 "github.com/openai/openai-go/option"
17)
18
// openaiOptions holds the OpenAI-specific settings resolved from the
// functional options passed to the provider.
type openaiOptions struct {
	baseURL      string // custom API endpoint; empty means the SDK default
	disableCache bool   // set by WithOpenAIDisableCache; consumed outside this block — TODO confirm usage
}

// OpenAIOption mutates openaiOptions; see WithOpenAIBaseURL and
// WithOpenAIDisableCache for the available options.
type OpenAIOption func(*openaiOptions)
25
// openaiClient implements the provider client against the OpenAI chat
// completions API (or any OpenAI-compatible endpoint via baseURL).
type openaiClient struct {
	providerOptions providerClientOptions // generic provider settings (model, max tokens, system message, ...)
	options         openaiOptions         // resolved OpenAI-specific settings
	client          openai.Client         // underlying OpenAI SDK client
}

// OpenAIClient is the provider-facing name for this client's interface.
type OpenAIClient ProviderClient
33
34func newOpenAIClient(opts providerClientOptions) OpenAIClient {
35 openaiOpts := openaiOptions{}
36 for _, o := range opts.openaiOptions {
37 o(&openaiOpts)
38 }
39
40 openaiClientOptions := []option.RequestOption{}
41 if opts.apiKey != "" {
42 openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
43 }
44 if openaiOpts.baseURL != "" {
45 openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
46 }
47
48 client := openai.NewClient(openaiClientOptions...)
49 return &openaiClient{
50 providerOptions: opts,
51 options: openaiOpts,
52 client: client,
53 }
54}
55
56func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
57 // Add system message first
58 openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
59
60 for _, msg := range messages {
61 switch msg.Role {
62 case message.User:
63 openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
64
65 case message.Assistant:
66 assistantMsg := openai.ChatCompletionAssistantMessageParam{
67 Role: "assistant",
68 }
69
70 if msg.Content().String() != "" {
71 assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
72 OfString: openai.String(msg.Content().String()),
73 }
74 }
75
76 if len(msg.ToolCalls()) > 0 {
77 assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
78 for i, call := range msg.ToolCalls() {
79 assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
80 ID: call.ID,
81 Type: "function",
82 Function: openai.ChatCompletionMessageToolCallFunctionParam{
83 Name: call.Name,
84 Arguments: call.Input,
85 },
86 }
87 }
88 }
89
90 openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
91 OfAssistant: &assistantMsg,
92 })
93
94 case message.Tool:
95 for _, result := range msg.ToolResults() {
96 openaiMessages = append(openaiMessages,
97 openai.ToolMessage(result.Content, result.ToolCallID),
98 )
99 }
100 }
101 }
102
103 return
104}
105
106func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
107 openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
108
109 for i, tool := range tools {
110 info := tool.Info()
111 openaiTools[i] = openai.ChatCompletionToolParam{
112 Function: openai.FunctionDefinitionParam{
113 Name: info.Name,
114 Description: openai.String(info.Description),
115 Parameters: openai.FunctionParameters{
116 "type": "object",
117 "properties": info.Parameters,
118 "required": info.Required,
119 },
120 },
121 }
122 }
123
124 return openaiTools
125}
126
127func (o *openaiClient) finishReason(reason string) message.FinishReason {
128 switch reason {
129 case "stop":
130 return message.FinishReasonEndTurn
131 case "length":
132 return message.FinishReasonMaxTokens
133 case "tool_calls":
134 return message.FinishReasonToolUse
135 default:
136 return message.FinishReasonUnknown
137 }
138}
139
140func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
141 return openai.ChatCompletionNewParams{
142 Model: openai.ChatModel(o.providerOptions.model.APIModel),
143 Messages: messages,
144 MaxTokens: openai.Int(o.providerOptions.maxTokens),
145 Tools: tools,
146 }
147}
148
149func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
150 params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
151 cfg := config.Get()
152 if cfg.Debug {
153 jsonData, _ := json.Marshal(params)
154 logging.Debug("Prepared messages", "messages", string(jsonData))
155 }
156 attempts := 0
157 for {
158 attempts++
159 openaiResponse, err := o.client.Chat.Completions.New(
160 ctx,
161 params,
162 )
163 // If there is an error we are going to see if we can retry the call
164 if err != nil {
165 retry, after, retryErr := o.shouldRetry(attempts, err)
166 if retryErr != nil {
167 return nil, retryErr
168 }
169 if retry {
170 logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
171 select {
172 case <-ctx.Done():
173 return nil, ctx.Err()
174 case <-time.After(time.Duration(after) * time.Millisecond):
175 continue
176 }
177 }
178 return nil, retryErr
179 }
180
181 content := ""
182 if openaiResponse.Choices[0].Message.Content != "" {
183 content = openaiResponse.Choices[0].Message.Content
184 }
185
186 return &ProviderResponse{
187 Content: content,
188 ToolCalls: o.toolCalls(*openaiResponse),
189 Usage: o.usage(*openaiResponse),
190 FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)),
191 }, nil
192 }
193}
194
195func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
196 params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
197 params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
198 IncludeUsage: openai.Bool(true),
199 }
200
201 cfg := config.Get()
202 if cfg.Debug {
203 jsonData, _ := json.Marshal(params)
204 logging.Debug("Prepared messages", "messages", string(jsonData))
205 }
206
207 attempts := 0
208 eventChan := make(chan ProviderEvent)
209
210 go func() {
211 for {
212 attempts++
213 openaiStream := o.client.Chat.Completions.NewStreaming(
214 ctx,
215 params,
216 )
217
218 acc := openai.ChatCompletionAccumulator{}
219 currentContent := ""
220 toolCalls := make([]message.ToolCall, 0)
221
222 for openaiStream.Next() {
223 chunk := openaiStream.Current()
224 acc.AddChunk(chunk)
225
226 if tool, ok := acc.JustFinishedToolCall(); ok {
227 toolCalls = append(toolCalls, message.ToolCall{
228 ID: tool.Id,
229 Name: tool.Name,
230 Input: tool.Arguments,
231 Type: "function",
232 })
233 }
234
235 for _, choice := range chunk.Choices {
236 if choice.Delta.Content != "" {
237 eventChan <- ProviderEvent{
238 Type: EventContentDelta,
239 Content: choice.Delta.Content,
240 }
241 currentContent += choice.Delta.Content
242 }
243 }
244 }
245
246 err := openaiStream.Err()
247 if err == nil || errors.Is(err, io.EOF) {
248 // Stream completed successfully
249 eventChan <- ProviderEvent{
250 Type: EventComplete,
251 Response: &ProviderResponse{
252 Content: currentContent,
253 ToolCalls: toolCalls,
254 Usage: o.usage(acc.ChatCompletion),
255 FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)),
256 },
257 }
258 close(eventChan)
259 return
260 }
261
262 // If there is an error we are going to see if we can retry the call
263 retry, after, retryErr := o.shouldRetry(attempts, err)
264 if retryErr != nil {
265 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
266 close(eventChan)
267 return
268 }
269 if retry {
270 logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
271 select {
272 case <-ctx.Done():
273 // context cancelled
274 if ctx.Err() == nil {
275 eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
276 }
277 close(eventChan)
278 return
279 case <-time.After(time.Duration(after) * time.Millisecond):
280 continue
281 }
282 }
283 eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
284 close(eventChan)
285 return
286 }
287 }()
288
289 return eventChan
290}
291
292func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
293 var apierr *openai.Error
294 if !errors.As(err, &apierr) {
295 return false, 0, err
296 }
297
298 if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
299 return false, 0, err
300 }
301
302 if attempts > maxRetries {
303 return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
304 }
305
306 retryMs := 0
307 retryAfterValues := apierr.Response.Header.Values("Retry-After")
308
309 backoffMs := 2000 * (1 << (attempts - 1))
310 jitterMs := int(float64(backoffMs) * 0.2)
311 retryMs = backoffMs + jitterMs
312 if len(retryAfterValues) > 0 {
313 if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
314 retryMs = retryMs * 1000
315 }
316 }
317 return true, int64(retryMs), nil
318}
319
320func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
321 var toolCalls []message.ToolCall
322
323 if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
324 for _, call := range completion.Choices[0].Message.ToolCalls {
325 toolCall := message.ToolCall{
326 ID: call.ID,
327 Name: call.Function.Name,
328 Input: call.Function.Arguments,
329 Type: "function",
330 }
331 toolCalls = append(toolCalls, toolCall)
332 }
333 }
334
335 return toolCalls
336}
337
338func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
339 cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
340 inputTokens := completion.Usage.PromptTokens - cachedTokens
341
342 return TokenUsage{
343 InputTokens: inputTokens,
344 OutputTokens: completion.Usage.CompletionTokens,
345 CacheCreationTokens: 0, // OpenAI doesn't provide this directly
346 CacheReadTokens: cachedTokens,
347 }
348}
349
350func WithOpenAIBaseURL(baseURL string) OpenAIOption {
351 return func(options *openaiOptions) {
352 options.baseURL = baseURL
353 }
354}
355
356func WithOpenAIDisableCache() OpenAIOption {
357 return func(options *openaiOptions) {
358 options.disableCache = true
359 }
360}
361