package provider

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

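// openaiProvider implements the Provider interface on top of the
// official OpenAI Go SDK (github.com/openai/openai-go).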
type openaiProvider struct {
	client        openai.Client
	model         models.Model
	maxTokens     int64
	baseURL       string
	apiKey        string
	systemMessage string
}

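// OpenAIOption configures an openaiProvider during construction.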
type OpenAIOption func(*openaiProvider)

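// NewOpenAIProvider constructs a Provider backed by the OpenAI chat
// completions API. A system message is required; maxTokens defaults to
// 5000. A minimal usage sketch, where someModel stands for whichever
// models.Model value the caller has chosen:
//
//	p, err := NewOpenAIProvider(
//		WithOpenAISystemMessage("You are a helpful assistant."),
//		WithOpenAIModel(someModel),
//		WithOpenAIKey(os.Getenv("OPENAI_API_KEY")),
//	)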
func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
	provider := &openaiProvider{
		maxTokens: 5000,
	}

	for _, opt := range opts {
		opt(provider)
	}

	// Validate required configuration before constructing the client.
	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	clientOpts := []option.RequestOption{
		option.WithAPIKey(provider.apiKey),
	}
	if provider.baseURL != "" {
		clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
	}
	provider.client = openai.NewClient(clientOpts...)

	return provider, nil
}

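// WithOpenAISystemMessage sets the required system prompt that leads
// every conversation.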
func WithOpenAISystemMessage(message string) OpenAIOption {
	return func(p *openaiProvider) {
		p.systemMessage = message
	}
}

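// WithOpenAIMaxTokens overrides the default completion token limit of 5000.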
func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
	return func(p *openaiProvider) {
		p.maxTokens = maxTokens
	}
}

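// WithOpenAIModel selects the model used for completions.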
func WithOpenAIModel(model models.Model) OpenAIOption {
	return func(p *openaiProvider) {
		p.model = model
	}
}

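// WithOpenAIBaseURL points the client at a non-default endpoint, such
// as a proxy or an OpenAI-compatible server.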
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(p *openaiProvider) {
		p.baseURL = baseURL
	}
}

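// WithOpenAIKey sets the API key used to authenticate requests.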
func WithOpenAIKey(apiKey string) OpenAIOption {
	return func(p *openaiProvider) {
		p.apiKey = apiKey
	}
}

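// convertToOpenAIMessages maps the internal message history onto the
// SDK's chat message union, prepending the provider's system message.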
func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
	var chatMessages []openai.ChatCompletionMessageParamUnion

	// The system message always leads the conversation.
	chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			chatMessages = append(chatMessages, openai.UserMessage(msg.Content))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content),
				}
			}

			if len(msg.ToolCalls) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls))
				for i, call := range msg.ToolCalls {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			// Each tool result becomes its own tool message, matched to
			// the originating call by ID.
			for _, result := range msg.ToolResults {
				chatMessages = append(chatMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return chatMessages
}

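// convertToOpenAITools maps the internal tool definitions onto OpenAI
// function-calling tool parameters, wrapping each tool's parameters in
// a JSON Schema object.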
func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

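// extractTokenUsage normalizes the SDK's usage counters. Cached prompt
// tokens are included in PromptTokens, so they are subtracted to get
// the uncached input count.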
func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
	cachedTokens := usage.PromptTokensDetails.CachedTokens
	inputTokens := usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

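// SendMessages performs a single non-streaming chat completion and
// returns the assistant's content, tool calls, and token usage.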
func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
	}

	response, err := p.client.Chat.Completions.New(ctx, params)
	if err != nil {
		return nil, err
	}
	// Guard against an empty Choices slice before indexing into it.
	if len(response.Choices) == 0 {
		return nil, errors.New("no completion choices returned")
	}

	content := response.Choices[0].Message.Content

	var toolCalls []message.ToolCall
	if len(response.Choices[0].Message.ToolCalls) > 0 {
		toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
		for i, call := range response.Choices[0].Message.ToolCalls {
			toolCalls[i] = message.ToolCall{
				ID:    call.ID,
				Name:  call.Function.Name,
				Input: call.Function.Arguments,
				Type:  "function",
			}
		}
	}

	tokenUsage := p.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

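// StreamResponse starts a streaming chat completion. The returned
// channel emits content deltas as they arrive, followed by either an
// EventError or a final EventComplete carrying the accumulated
// response and token usage.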
func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
		// Ask the API to append usage statistics to the final chunk.
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}

	stream := p.client.Chat.Completions.NewStreaming(ctx, params)

	eventChan := make(chan ProviderEvent)

	toolCalls := make([]message.ToolCall, 0)
	go func() {
		defer close(eventChan)

		// The accumulator stitches streamed chunks back into a full
		// completion so tool calls and usage can be read at the end.
		acc := openai.ChatCompletionAccumulator{}
		currentContent := ""

		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)

			// A tool call is complete only once all of its argument
			// chunks have arrived; collect it when the accumulator
			// reports it finished.
			if tool, ok := acc.JustFinishedToolCall(); ok {
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    tool.Id,
					Name:  tool.Name,
					Input: tool.Arguments,
					Type:  "function",
				})
			}

			// Forward text deltas to the caller as they arrive.
			for _, choice := range chunk.Choices {
				if choice.Delta.Content != "" {
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: choice.Delta.Content,
					}
					currentContent += choice.Delta.Content
				}
			}
		}

		if err := stream.Err(); err != nil {
			eventChan <- ProviderEvent{
				Type:  EventError,
				Error: err,
			}
			return
		}

		tokenUsage := p.extractTokenUsage(acc.Usage)

		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:   currentContent,
				ToolCalls: toolCalls,
				Usage:     tokenUsage,
			},
		}
	}()

	return eventChan, nil
}