openai.go

  1package provider
  2
import (
	"context"
	"errors"
	"strings"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)
 13
 14type openaiProvider struct {
 15	client        openai.Client
 16	model         models.Model
 17	maxTokens     int64
 18	baseURL       string
 19	apiKey        string
 20	systemMessage string
 21}
 22
 23type OpenAIOption func(*openaiProvider)
 24
 25func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
 26	provider := &openaiProvider{
 27		maxTokens: 5000,
 28	}
 29
 30	for _, opt := range opts {
 31		opt(provider)
 32	}
 33
 34	clientOpts := []option.RequestOption{
 35		option.WithAPIKey(provider.apiKey),
 36	}
 37	if provider.baseURL != "" {
 38		clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
 39	}
 40
 41	provider.client = openai.NewClient(clientOpts...)
 42	if provider.systemMessage == "" {
 43		return nil, errors.New("system message is required")
 44	}
 45
 46	return provider, nil
 47}
 48
 49func WithOpenAISystemMessage(message string) OpenAIOption {
 50	return func(p *openaiProvider) {
 51		p.systemMessage = message
 52	}
 53}
 54
 55func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
 56	return func(p *openaiProvider) {
 57		p.maxTokens = maxTokens
 58	}
 59}
 60
 61func WithOpenAIModel(model models.Model) OpenAIOption {
 62	return func(p *openaiProvider) {
 63		p.model = model
 64	}
 65}
 66
 67func WithOpenAIBaseURL(baseURL string) OpenAIOption {
 68	return func(p *openaiProvider) {
 69		p.baseURL = baseURL
 70	}
 71}
 72
 73func WithOpenAIKey(apiKey string) OpenAIOption {
 74	return func(p *openaiProvider) {
 75		p.apiKey = apiKey
 76	}
 77}
 78
 79func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
 80	var chatMessages []openai.ChatCompletionMessageParamUnion
 81
 82	chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))
 83
 84	for _, msg := range messages {
 85		switch msg.Role {
 86		case message.User:
 87			chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))
 88
 89		case message.Assistant:
 90			assistantMsg := openai.ChatCompletionAssistantMessageParam{
 91				Role: "assistant",
 92			}
 93
 94			if msg.Content().String() != "" {
 95				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
 96					OfString: openai.String(msg.Content().String()),
 97				}
 98			}
 99
100			if len(msg.ToolCalls()) > 0 {
101				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
102				for i, call := range msg.ToolCalls() {
103					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
104						ID:   call.ID,
105						Type: "function",
106						Function: openai.ChatCompletionMessageToolCallFunctionParam{
107							Name:      call.Name,
108							Arguments: call.Input,
109						},
110					}
111				}
112			}
113
114			chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
115				OfAssistant: &assistantMsg,
116			})
117
118		case message.Tool:
119			for _, result := range msg.ToolResults() {
120				chatMessages = append(chatMessages,
121					openai.ToolMessage(result.Content, result.ToolCallID),
122				)
123			}
124		}
125	}
126
127	return chatMessages
128}
129
130func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
131	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))
132
133	for i, tool := range tools {
134		info := tool.Info()
135		openaiTools[i] = openai.ChatCompletionToolParam{
136			Function: openai.FunctionDefinitionParam{
137				Name:        info.Name,
138				Description: openai.String(info.Description),
139				Parameters: openai.FunctionParameters{
140					"type":       "object",
141					"properties": info.Parameters,
142					"required":   info.Required,
143				},
144			},
145		}
146	}
147
148	return openaiTools
149}
150
151func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
152	cachedTokens := int64(0)
153
154	cachedTokens = usage.PromptTokensDetails.CachedTokens
155	inputTokens := usage.PromptTokens - cachedTokens
156
157	return TokenUsage{
158		InputTokens:         inputTokens,
159		OutputTokens:        usage.CompletionTokens,
160		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
161		CacheReadTokens:     cachedTokens,
162	}
163}
164
165func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
166	chatMessages := p.convertToOpenAIMessages(messages)
167	openaiTools := p.convertToOpenAITools(tools)
168
169	params := openai.ChatCompletionNewParams{
170		Model:     openai.ChatModel(p.model.APIModel),
171		Messages:  chatMessages,
172		MaxTokens: openai.Int(p.maxTokens),
173		Tools:     openaiTools,
174	}
175
176	response, err := p.client.Chat.Completions.New(ctx, params)
177	if err != nil {
178		return nil, err
179	}
180
181	content := ""
182	if response.Choices[0].Message.Content != "" {
183		content = response.Choices[0].Message.Content
184	}
185
186	var toolCalls []message.ToolCall
187	if len(response.Choices[0].Message.ToolCalls) > 0 {
188		toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
189		for i, call := range response.Choices[0].Message.ToolCalls {
190			toolCalls[i] = message.ToolCall{
191				ID:    call.ID,
192				Name:  call.Function.Name,
193				Input: call.Function.Arguments,
194				Type:  "function",
195			}
196		}
197	}
198
199	tokenUsage := p.extractTokenUsage(response.Usage)
200
201	return &ProviderResponse{
202		Content:   content,
203		ToolCalls: toolCalls,
204		Usage:     tokenUsage,
205	}, nil
206}
207
208func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
209	chatMessages := p.convertToOpenAIMessages(messages)
210	openaiTools := p.convertToOpenAITools(tools)
211
212	params := openai.ChatCompletionNewParams{
213		Model:     openai.ChatModel(p.model.APIModel),
214		Messages:  chatMessages,
215		MaxTokens: openai.Int(p.maxTokens),
216		Tools:     openaiTools,
217		StreamOptions: openai.ChatCompletionStreamOptionsParam{
218			IncludeUsage: openai.Bool(true),
219		},
220	}
221
222	stream := p.client.Chat.Completions.NewStreaming(ctx, params)
223
224	eventChan := make(chan ProviderEvent)
225
226	toolCalls := make([]message.ToolCall, 0)
227	go func() {
228		defer close(eventChan)
229
230		acc := openai.ChatCompletionAccumulator{}
231		currentContent := ""
232
233		for stream.Next() {
234			chunk := stream.Current()
235			acc.AddChunk(chunk)
236
237			if tool, ok := acc.JustFinishedToolCall(); ok {
238				toolCalls = append(toolCalls, message.ToolCall{
239					ID:    tool.Id,
240					Name:  tool.Name,
241					Input: tool.Arguments,
242					Type:  "function",
243				})
244			}
245
246			for _, choice := range chunk.Choices {
247				if choice.Delta.Content != "" {
248					eventChan <- ProviderEvent{
249						Type:    EventContentDelta,
250						Content: choice.Delta.Content,
251					}
252					currentContent += choice.Delta.Content
253				}
254			}
255		}
256
257		if err := stream.Err(); err != nil {
258			eventChan <- ProviderEvent{
259				Type:  EventError,
260				Error: err,
261			}
262			return
263		}
264
265		tokenUsage := p.extractTokenUsage(acc.Usage)
266
267		eventChan <- ProviderEvent{
268			Type: EventComplete,
269			Response: &ProviderResponse{
270				Content:   currentContent,
271				ToolCalls: toolCalls,
272				Usage:     tokenUsage,
273			},
274		}
275	}()
276
277	return eventChan, nil
278}
279