openai.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
	"github.com/opencode-ai/opencode/internal/config"
	"github.com/opencode-ai/opencode/internal/llm/tools"
	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/opencode-ai/opencode/internal/message"
)

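// openaiOptions holds OpenAI-specific client configuration, applied via
// OpenAIOption functions.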
type openaiOptions struct {
	baseURL         string
	disableCache    bool
	reasoningEffort string
	extraHeaders    map[string]string
}

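// OpenAIOption mutates openaiOptions and is passed to newOpenAIClient
// through providerClientOptions.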
type OpenAIOption func(*openaiOptions)

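// openaiClient implements ProviderClient against the OpenAI chat
// completions API.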
type openaiClient struct {
	providerOptions providerClientOptions
	options         openaiOptions
	client          openai.Client
}

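// OpenAIClient is the ProviderClient specialization returned by
// newOpenAIClient.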
type OpenAIClient ProviderClient

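// newOpenAIClient builds an OpenAI-backed client, defaulting reasoning
// effort to "medium" and wiring the API key, base URL, and any extra
// headers into the underlying SDK client. A minimal usage sketch (how the
// caller sources its key is assumed, not shown here):
//
//	client := newOpenAIClient(providerClientOptions{
//		apiKey:        apiKey, // hypothetical: supplied by the caller
//		openaiOptions: []OpenAIOption{WithReasoningEffort("high")},
//	})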
func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	openaiOpts := openaiOptions{
		reasoningEffort: "medium",
	}
	for _, o := range opts.openaiOptions {
		o(&openaiOpts)
	}

	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if openaiOpts.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
	}

	// Ranging over a nil map is a no-op, so no nil check is needed.
	for key, value := range openaiOpts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	client := openai.NewClient(openaiClientOptions...)
	return &openaiClient{
		providerOptions: opts,
		options:         openaiOpts,
		client:          client,
	}
}

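// convertMessages maps the internal message history onto the OpenAI message
// union types, prepending the configured system prompt and flattening tool
// results into individual tool messages.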
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

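// convertTools translates tool definitions into OpenAI function-calling
// tool parameters, embedding each tool's JSON schema (type, properties,
// required) in the function definition.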
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

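// finishReason maps OpenAI finish reasons onto internal
// message.FinishReason values.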
func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

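// preparedParams assembles the request parameters shared by send and
// stream. Reasoning models use MaxCompletionTokens plus a reasoning effort;
// other models use the plain MaxTokens limit.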
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(o.providerOptions.model.APIModel),
		Messages: messages,
		Tools:    tools,
	}

	if o.providerOptions.model.CanReason {
		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
		switch o.options.reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		}
	} else {
		params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
	}

	return params
}

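// send issues a single (non-streaming) chat completion, retrying retryable
// API errors with the backoff computed by shouldRetry.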
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, err
		}

		// Guard against an empty choice list before indexing into it.
		if len(openaiResponse.Choices) == 0 {
			return nil, fmt.Errorf("received a response with no choices")
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

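// stream issues a streaming chat completion, emitting EventContentDelta
// events as deltas arrive and a final EventComplete (or EventError), after
// which the channel is closed. A minimal consumption sketch (hypothetical
// caller, names assumed):
//
//	for event := range client.stream(ctx, msgs, nil) {
//		switch event.Type {
//		case EventContentDelta:
//			fmt.Print(event.Content)
//		case EventError:
//			log.Fatal(event.Error)
//		}
//	}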
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)

				if tool, ok := acc.JustFinishedToolCall(); ok {
					toolCalls = append(toolCalls, message.ToolCall{
						ID:    tool.Id,
						Name:  tool.Name,
						Input: tool.Arguments,
						Type:  "function",
					})
				}

				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				// Stream completed successfully; guard against an empty
				// choice list before reading the finish reason.
				finishReason := message.FinishReasonUnknown
				if len(acc.ChatCompletion.Choices) > 0 {
					finishReason = o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
				}

				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// Context cancelled: ctx.Err() is non-nil here, so
					// report it before closing the channel.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: err}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

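// shouldRetry reports whether an API error is retryable (HTTP 429 or 500)
// and, if so, how many milliseconds to wait: a Retry-After header when the
// server sends one, otherwise exponential backoff with a jitter term.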
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apierr *openai.Error
	if !errors.As(err, &apierr) {
		return false, 0, err
	}

	if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Default to exponential backoff (2s, 4s, 8s, ...) plus a fixed 20% jitter term.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	// Prefer the server's Retry-After header, which is given in seconds.
	retryAfterValues := apierr.Response.Header.Values("Retry-After")
	if len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

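// toolCalls extracts completed tool calls from the first choice of a
// non-streaming completion.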
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

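// usage converts OpenAI token accounting into TokenUsage, counting cached
// prompt tokens separately from fresh input tokens.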
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

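// WithOpenAIBaseURL points the client at a custom API endpoint, e.g. a
// proxy or an OpenAI-compatible server.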
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(options *openaiOptions) {
		options.baseURL = baseURL
	}
}

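// WithOpenAIExtraHeaders attaches additional HTTP headers to every request.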
func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
	return func(options *openaiOptions) {
		options.extraHeaders = headers
	}
}

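// WithOpenAIDisableCache sets the disableCache flag on the client options.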
func WithOpenAIDisableCache() OpenAIOption {
	return func(options *openaiOptions) {
		options.disableCache = true
	}
}

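// WithReasoningEffort selects the reasoning effort ("low", "medium", or
// "high") for reasoning-capable models; invalid values fall back to
// "medium" with a warning.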
func WithReasoningEffort(effort string) OpenAIOption {
	return func(options *openaiOptions) {
		switch effort {
		case "low", "medium", "high":
			options.reasoningEffort = effort
		default:
			logging.Warn("Invalid reasoning effort, using default: medium")
			options.reasoningEffort = "medium"
		}
	}
}