openai.go

package provider

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"time"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
	"github.com/openai/openai-go/shared"
	"github.com/opencode-ai/opencode/internal/config"
	"github.com/opencode-ai/opencode/internal/llm/tools"
	"github.com/opencode-ai/opencode/internal/logging"
	"github.com/opencode-ai/opencode/internal/message"
)

type openaiOptions struct {
	baseURL         string
	disableCache    bool
	reasoningEffort string
	extraHeaders    map[string]string
}

type OpenAIOption func(*openaiOptions)

type openaiClient struct {
	providerOptions providerClientOptions
	options         openaiOptions
	client          openai.Client
}

type OpenAIClient ProviderClient

func newOpenAIClient(opts providerClientOptions) OpenAIClient {
	openaiOpts := openaiOptions{
		reasoningEffort: "medium",
	}
	for _, o := range opts.openaiOptions {
		o(&openaiOpts)
	}

	openaiClientOptions := []option.RequestOption{}
	if opts.apiKey != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
	}
	if openaiOpts.baseURL != "" {
		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
	}

	// Ranging over a nil map is a no-op, so no explicit nil check is needed.
	for key, value := range openaiOpts.extraHeaders {
		openaiClientOptions = append(openaiClientOptions, option.WithHeader(key, value))
	}

	client := openai.NewClient(openaiClientOptions...)
	return &openaiClient{
		providerOptions: opts,
		options:         openaiOpts,
		client:          client,
	}
}

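// convertMessages translates internal message.Message values into the OpenAI
// chat completion message union. The provider's system message is always
// prepended; user, assistant (including tool calls), and tool-result messages
// are mapped onto their respective union variants.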
func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
	// Add system message first
	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				openaiMessages = append(openaiMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return
}

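// convertTools wraps each tool's metadata in the JSON-schema function format
// that the chat completions API expects for tool definitions.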
func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

func (o *openaiClient) finishReason(reason string) message.FinishReason {
	switch reason {
	case "stop":
		return message.FinishReasonEndTurn
	case "length":
		return message.FinishReasonMaxTokens
	case "tool_calls":
		return message.FinishReasonToolUse
	default:
		return message.FinishReasonUnknown
	}
}

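// preparedParams assembles the request parameters, branching on whether the
// model can reason: reasoning models take MaxCompletionTokens plus a
// ReasoningEffort level, while other models take the plain MaxTokens cap.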
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
	params := openai.ChatCompletionNewParams{
		Model:    openai.ChatModel(o.providerOptions.model.APIModel),
		Messages: messages,
		Tools:    tools,
	}

	if o.providerOptions.model.CanReason {
		params.MaxCompletionTokens = openai.Int(o.providerOptions.maxTokens)
		switch o.options.reasoningEffort {
		case "low":
			params.ReasoningEffort = shared.ReasoningEffortLow
		case "medium":
			params.ReasoningEffort = shared.ReasoningEffortMedium
		case "high":
			params.ReasoningEffort = shared.ReasoningEffortHigh
		default:
			params.ReasoningEffort = shared.ReasoningEffortMedium
		}
	} else {
		params.MaxTokens = openai.Int(o.providerOptions.maxTokens)
	}

	return params
}

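// send issues a blocking (non-streaming) completion request, retrying
// rate-limited or failed calls as directed by shouldRetry, and converts the
// first choice into a ProviderResponse.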
func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}
	attempts := 0
	for {
		attempts++
		openaiResponse, err := o.client.Chat.Completions.New(
			ctx,
			params,
		)
		// On error, check whether the call can be retried.
		if err != nil {
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				return nil, retryErr
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					return nil, ctx.Err()
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			return nil, retryErr
		}

		content := openaiResponse.Choices[0].Message.Content

		toolCalls := o.toolCalls(*openaiResponse)
		finishReason := o.finishReason(string(openaiResponse.Choices[0].FinishReason))

		// Prefer tool-use as the finish reason whenever tool calls are present.
		if len(toolCalls) > 0 {
			finishReason = message.FinishReasonToolUse
		}

		return &ProviderResponse{
			Content:      content,
			ToolCalls:    toolCalls,
			Usage:        o.usage(*openaiResponse),
			FinishReason: finishReason,
		}, nil
	}
}

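// stream issues a streaming completion request and returns a channel of
// ProviderEvents: one EventContentDelta per content chunk, then a single
// EventComplete (or EventError) before the channel is closed. Retries are
// handled inside the goroutine using the same shouldRetry policy as send.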
func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
		IncludeUsage: openai.Bool(true),
	}

	cfg := config.Get()
	if cfg.Debug {
		jsonData, _ := json.Marshal(params)
		logging.Debug("Prepared messages", "messages", string(jsonData))
	}

	attempts := 0
	eventChan := make(chan ProviderEvent)

	go func() {
		for {
			attempts++
			openaiStream := o.client.Chat.Completions.NewStreaming(
				ctx,
				params,
			)

			acc := openai.ChatCompletionAccumulator{}
			currentContent := ""
			toolCalls := make([]message.ToolCall, 0)

			for openaiStream.Next() {
				chunk := openaiStream.Current()
				acc.AddChunk(chunk)

				for _, choice := range chunk.Choices {
					if choice.Delta.Content != "" {
						eventChan <- ProviderEvent{
							Type:    EventContentDelta,
							Content: choice.Delta.Content,
						}
						currentContent += choice.Delta.Content
					}
				}
			}

			err := openaiStream.Err()
			if err == nil || errors.Is(err, io.EOF) {
				// Stream completed successfully.
				finishReason := o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason))
				if len(acc.ChatCompletion.Choices[0].Message.ToolCalls) > 0 {
					toolCalls = append(toolCalls, o.toolCalls(acc.ChatCompletion)...)
				}
				if len(toolCalls) > 0 {
					finishReason = message.FinishReasonToolUse
				}

				eventChan <- ProviderEvent{
					Type: EventComplete,
					Response: &ProviderResponse{
						Content:      currentContent,
						ToolCalls:    toolCalls,
						Usage:        o.usage(acc.ChatCompletion),
						FinishReason: finishReason,
					},
				}
				close(eventChan)
				return
			}

			// On error, check whether the call can be retried.
			retry, after, retryErr := o.shouldRetry(attempts, err)
			if retryErr != nil {
				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
				close(eventChan)
				return
			}
			if retry {
				logging.WarnPersist(fmt.Sprintf("Retrying due to rate limit... attempt %d of %d", attempts, maxRetries), logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
				select {
				case <-ctx.Done():
					// Context cancelled: surface the cancellation as an error event.
					if ctx.Err() != nil {
						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
					}
					close(eventChan)
					return
				case <-time.After(time.Duration(after) * time.Millisecond):
					continue
				}
			}
			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
			close(eventChan)
			return
		}
	}()

	return eventChan
}

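// shouldRetry reports whether a failed request should be retried and, if so,
// how many milliseconds to wait first. Only HTTP 429 and 500 responses are
// retried, up to maxRetries attempts. The delay doubles each attempt with a
// 20% margin added on top: attempt 1 waits 2000+400 = 2400ms, attempt 2 waits
// 4000+800 = 4800ms, and so on, unless the server's Retry-After header
// specifies a delay directly.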
func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
	var apierr *openai.Error
	if !errors.As(err, &apierr) {
		return false, 0, err
	}

	if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
		return false, 0, err
	}

	if attempts > maxRetries {
		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
	}

	// Default delay: exponential backoff with a fixed 20% jitter margin.
	backoffMs := 2000 * (1 << (attempts - 1))
	jitterMs := int(float64(backoffMs) * 0.2)
	retryMs := backoffMs + jitterMs

	// A Retry-After header (given in seconds) overrides the computed backoff.
	if retryAfterValues := apierr.Response.Header.Values("Retry-After"); len(retryAfterValues) > 0 {
		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
			retryMs = retryMs * 1000
		}
	}
	return true, int64(retryMs), nil
}

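// toolCalls extracts the tool calls, if any, from the first choice of a
// completed response.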
func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
	var toolCalls []message.ToolCall

	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
		for _, call := range completion.Choices[0].Message.ToolCalls {
			toolCall := message.ToolCall{
				ID:       call.ID,
				Name:     call.Function.Name,
				Input:    call.Function.Arguments,
				Type:     "function",
				Finished: true,
			}
			toolCalls = append(toolCalls, toolCall)
		}
	}

	return toolCalls
}

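// usage converts OpenAI's token accounting into TokenUsage. Cached prompt
// tokens are subtracted from InputTokens and reported as CacheReadTokens so
// the two buckets don't double-count.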
func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
	inputTokens := completion.Usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        completion.Usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

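// The functional options below customize the client at construction time.
// A minimal sketch of how they compose (illustrative only; the field names
// follow their use in newOpenAIClient above):
//
//	client := newOpenAIClient(providerClientOptions{
//		apiKey: os.Getenv("OPENAI_API_KEY"),
//		openaiOptions: []OpenAIOption{
//			WithOpenAIBaseURL("https://my-proxy.example.com/v1"),
//			WithReasoningEffort("high"),
//		},
//	})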
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(options *openaiOptions) {
		options.baseURL = baseURL
	}
}

func WithOpenAIExtraHeaders(headers map[string]string) OpenAIOption {
	return func(options *openaiOptions) {
		options.extraHeaders = headers
	}
}

func WithOpenAIDisableCache() OpenAIOption {
	return func(options *openaiOptions) {
		options.disableCache = true
	}
}

func WithReasoningEffort(effort string) OpenAIOption {
	return func(options *openaiOptions) {
		defaultReasoningEffort := "medium"
		switch effort {
		case "low", "medium", "high":
			defaultReasoningEffort = effort
		default:
			logging.Warn("Invalid reasoning effort, using default: medium")
		}
		options.reasoningEffort = defaultReasoningEffort
	}
}