openai.go

  1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"time"
 10
 11	"github.com/kujtimiihoxha/termai/internal/config"
 12	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 13	"github.com/kujtimiihoxha/termai/internal/logging"
 14	"github.com/kujtimiihoxha/termai/internal/message"
 15	"github.com/openai/openai-go"
 16	"github.com/openai/openai-go/option"
 17)
 18
// openaiOptions holds OpenAI-specific client configuration, populated via
// OpenAIOption functional options before the client is constructed.
type openaiOptions struct {
	baseURL      string // custom API base URL; empty means the SDK default
	disableCache bool   // set by WithOpenAIDisableCache; not read in this file — presumably consumed elsewhere, verify
}
 23
 24type OpenAIOption func(*openaiOptions)
 25
// openaiClient implements a provider client backed by the official OpenAI Go SDK.
type openaiClient struct {
	providerOptions providerClientOptions // shared provider settings (model, max tokens, system message, API key)
	options         openaiOptions         // OpenAI-specific settings (base URL, cache flag)
	client          openai.Client         // underlying SDK client
}
 31
 32type OpenAIClient ProviderClient
 33
 34func newOpenAIClient(opts providerClientOptions) OpenAIClient {
 35	openaiOpts := openaiOptions{}
 36	for _, o := range opts.openaiOptions {
 37		o(&openaiOpts)
 38	}
 39
 40	openaiClientOptions := []option.RequestOption{}
 41	if opts.apiKey != "" {
 42		openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
 43	}
 44	if openaiOpts.baseURL != "" {
 45		openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(openaiOpts.baseURL))
 46	}
 47
 48	client := openai.NewClient(openaiClientOptions...)
 49	return &openaiClient{
 50		providerOptions: opts,
 51		options:         openaiOpts,
 52		client:          client,
 53	}
 54}
 55
 56func (o *openaiClient) convertMessages(messages []message.Message) (openaiMessages []openai.ChatCompletionMessageParamUnion) {
 57	// Add system message first
 58	openaiMessages = append(openaiMessages, openai.SystemMessage(o.providerOptions.systemMessage))
 59
 60	for _, msg := range messages {
 61		switch msg.Role {
 62		case message.User:
 63			openaiMessages = append(openaiMessages, openai.UserMessage(msg.Content().String()))
 64
 65		case message.Assistant:
 66			assistantMsg := openai.ChatCompletionAssistantMessageParam{
 67				Role: "assistant",
 68			}
 69
 70			if msg.Content().String() != "" {
 71				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
 72					OfString: openai.String(msg.Content().String()),
 73				}
 74			}
 75
 76			if len(msg.ToolCalls()) > 0 {
 77				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
 78				for i, call := range msg.ToolCalls() {
 79					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
 80						ID:   call.ID,
 81						Type: "function",
 82						Function: openai.ChatCompletionMessageToolCallFunctionParam{
 83							Name:      call.Name,
 84							Arguments: call.Input,
 85						},
 86					}
 87				}
 88			}
 89
 90			openaiMessages = append(openaiMessages, openai.ChatCompletionMessageParamUnion{
 91				OfAssistant: &assistantMsg,
 92			})
 93
 94		case message.Tool:
 95			for _, result := range msg.ToolResults() {
 96				openaiMessages = append(openaiMessages,
 97					openai.ToolMessage(result.Content, result.ToolCallID),
 98				)
 99			}
100		}
101	}
102
103	return
104}
105
106func (o *openaiClient) convertTools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
107	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
108
109	for i, tool := range tools {
110		info := tool.Info()
111		openaiTools[i] = openai.ChatCompletionToolParam{
112			Function: openai.FunctionDefinitionParam{
113				Name:        info.Name,
114				Description: openai.String(info.Description),
115				Parameters: openai.FunctionParameters{
116					"type":       "object",
117					"properties": info.Parameters,
118					"required":   info.Required,
119				},
120			},
121		}
122	}
123
124	return openaiTools
125}
126
127func (o *openaiClient) finishReason(reason string) message.FinishReason {
128	switch reason {
129	case "stop":
130		return message.FinishReasonEndTurn
131	case "length":
132		return message.FinishReasonMaxTokens
133	case "tool_calls":
134		return message.FinishReasonToolUse
135	default:
136		return message.FinishReasonUnknown
137	}
138}
139
140func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
141	return openai.ChatCompletionNewParams{
142		Model:     openai.ChatModel(o.providerOptions.model.APIModel),
143		Messages:  messages,
144		MaxTokens: openai.Int(o.providerOptions.maxTokens),
145		Tools:     tools,
146	}
147}
148
149func (o *openaiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (response *ProviderResponse, err error) {
150	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
151	cfg := config.Get()
152	if cfg.Debug {
153		jsonData, _ := json.Marshal(params)
154		logging.Debug("Prepared messages", "messages", string(jsonData))
155	}
156	attempts := 0
157	for {
158		attempts++
159		openaiResponse, err := o.client.Chat.Completions.New(
160			ctx,
161			params,
162		)
163		// If there is an error we are going to see if we can retry the call
164		if err != nil {
165			retry, after, retryErr := o.shouldRetry(attempts, err)
166			if retryErr != nil {
167				return nil, retryErr
168			}
169			if retry {
170				logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
171				select {
172				case <-ctx.Done():
173					return nil, ctx.Err()
174				case <-time.After(time.Duration(after) * time.Millisecond):
175					continue
176				}
177			}
178			return nil, retryErr
179		}
180
181		content := ""
182		if openaiResponse.Choices[0].Message.Content != "" {
183			content = openaiResponse.Choices[0].Message.Content
184		}
185
186		return &ProviderResponse{
187			Content:      content,
188			ToolCalls:    o.toolCalls(*openaiResponse),
189			Usage:        o.usage(*openaiResponse),
190			FinishReason: o.finishReason(string(openaiResponse.Choices[0].FinishReason)),
191		}, nil
192	}
193}
194
195func (o *openaiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
196	params := o.preparedParams(o.convertMessages(messages), o.convertTools(tools))
197	params.StreamOptions = openai.ChatCompletionStreamOptionsParam{
198		IncludeUsage: openai.Bool(true),
199	}
200
201	cfg := config.Get()
202	if cfg.Debug {
203		jsonData, _ := json.Marshal(params)
204		logging.Debug("Prepared messages", "messages", string(jsonData))
205	}
206
207	attempts := 0
208	eventChan := make(chan ProviderEvent)
209
210	go func() {
211		for {
212			attempts++
213			openaiStream := o.client.Chat.Completions.NewStreaming(
214				ctx,
215				params,
216			)
217
218			acc := openai.ChatCompletionAccumulator{}
219			currentContent := ""
220			toolCalls := make([]message.ToolCall, 0)
221
222			for openaiStream.Next() {
223				chunk := openaiStream.Current()
224				acc.AddChunk(chunk)
225
226				if tool, ok := acc.JustFinishedToolCall(); ok {
227					toolCalls = append(toolCalls, message.ToolCall{
228						ID:    tool.Id,
229						Name:  tool.Name,
230						Input: tool.Arguments,
231						Type:  "function",
232					})
233				}
234
235				for _, choice := range chunk.Choices {
236					if choice.Delta.Content != "" {
237						eventChan <- ProviderEvent{
238							Type:    EventContentDelta,
239							Content: choice.Delta.Content,
240						}
241						currentContent += choice.Delta.Content
242					}
243				}
244			}
245
246			err := openaiStream.Err()
247			if err == nil || errors.Is(err, io.EOF) {
248				// Stream completed successfully
249				eventChan <- ProviderEvent{
250					Type: EventComplete,
251					Response: &ProviderResponse{
252						Content:      currentContent,
253						ToolCalls:    toolCalls,
254						Usage:        o.usage(acc.ChatCompletion),
255						FinishReason: o.finishReason(string(acc.ChatCompletion.Choices[0].FinishReason)),
256					},
257				}
258				close(eventChan)
259				return
260			}
261
262			// If there is an error we are going to see if we can retry the call
263			retry, after, retryErr := o.shouldRetry(attempts, err)
264			if retryErr != nil {
265				eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
266				close(eventChan)
267				return
268			}
269			if retry {
270				logging.WarnPersist("Retrying due to rate limit... attempt %d of %d", logging.PersistTimeArg, time.Millisecond*time.Duration(after+100))
271				select {
272				case <-ctx.Done():
273					// context cancelled
274					if ctx.Err() == nil {
275						eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
276					}
277					close(eventChan)
278					return
279				case <-time.After(time.Duration(after) * time.Millisecond):
280					continue
281				}
282			}
283			eventChan <- ProviderEvent{Type: EventError, Error: retryErr}
284			close(eventChan)
285			return
286		}
287	}()
288
289	return eventChan
290}
291
292func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) {
293	var apierr *openai.Error
294	if !errors.As(err, &apierr) {
295		return false, 0, err
296	}
297
298	if apierr.StatusCode != 429 && apierr.StatusCode != 500 {
299		return false, 0, err
300	}
301
302	if attempts > maxRetries {
303		return false, 0, fmt.Errorf("maximum retry attempts reached for rate limit: %d retries", maxRetries)
304	}
305
306	retryMs := 0
307	retryAfterValues := apierr.Response.Header.Values("Retry-After")
308
309	backoffMs := 2000 * (1 << (attempts - 1))
310	jitterMs := int(float64(backoffMs) * 0.2)
311	retryMs = backoffMs + jitterMs
312	if len(retryAfterValues) > 0 {
313		if _, err := fmt.Sscanf(retryAfterValues[0], "%d", &retryMs); err == nil {
314			retryMs = retryMs * 1000
315		}
316	}
317	return true, int64(retryMs), nil
318}
319
320func (o *openaiClient) toolCalls(completion openai.ChatCompletion) []message.ToolCall {
321	var toolCalls []message.ToolCall
322
323	if len(completion.Choices) > 0 && len(completion.Choices[0].Message.ToolCalls) > 0 {
324		for _, call := range completion.Choices[0].Message.ToolCalls {
325			toolCall := message.ToolCall{
326				ID:    call.ID,
327				Name:  call.Function.Name,
328				Input: call.Function.Arguments,
329				Type:  "function",
330			}
331			toolCalls = append(toolCalls, toolCall)
332		}
333	}
334
335	return toolCalls
336}
337
338func (o *openaiClient) usage(completion openai.ChatCompletion) TokenUsage {
339	cachedTokens := completion.Usage.PromptTokensDetails.CachedTokens
340	inputTokens := completion.Usage.PromptTokens - cachedTokens
341
342	return TokenUsage{
343		InputTokens:         inputTokens,
344		OutputTokens:        completion.Usage.CompletionTokens,
345		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
346		CacheReadTokens:     cachedTokens,
347	}
348}
349
350func WithOpenAIBaseURL(baseURL string) OpenAIOption {
351	return func(options *openaiOptions) {
352		options.baseURL = baseURL
353	}
354}
355
356func WithOpenAIDisableCache() OpenAIOption {
357	return func(options *openaiOptions) {
358		options.disableCache = true
359	}
360}
361