package provider

import (
	"context"
	"errors"

	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

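// openaiProvider implements Provider on top of the official OpenAI Go SDK.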
type openaiProvider struct {
	client        openai.Client
	model         models.Model
	maxTokens     int64
	baseURL       string
	apiKey        string
	systemMessage string
}

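// OpenAIOption configures an openaiProvider during construction.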
type OpenAIOption func(*openaiProvider)

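// NewOpenAIProvider builds a Provider from the given options. A system
// message is required; maxTokens defaults to 5000 when not set.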
func NewOpenAIProvider(opts ...OpenAIOption) (Provider, error) {
	provider := &openaiProvider{
		maxTokens: 5000,
	}

	for _, opt := range opts {
		opt(provider)
	}

	// Validate before constructing the client so a misconfigured provider
	// fails fast.
	if provider.systemMessage == "" {
		return nil, errors.New("system message is required")
	}

	clientOpts := []option.RequestOption{
		option.WithAPIKey(provider.apiKey),
	}
	if provider.baseURL != "" {
		clientOpts = append(clientOpts, option.WithBaseURL(provider.baseURL))
	}

	provider.client = openai.NewClient(clientOpts...)

	return provider, nil
}

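// WithOpenAISystemMessage sets the system prompt prepended to every request.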
func WithOpenAISystemMessage(message string) OpenAIOption {
	return func(p *openaiProvider) {
		p.systemMessage = message
	}
}

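// WithOpenAIMaxTokens caps the number of tokens the model may generate.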
func WithOpenAIMaxTokens(maxTokens int64) OpenAIOption {
	return func(p *openaiProvider) {
		p.maxTokens = maxTokens
	}
}

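// WithOpenAIModel selects the model used for completions.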
func WithOpenAIModel(model models.Model) OpenAIOption {
	return func(p *openaiProvider) {
		p.model = model
	}
}

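// WithOpenAIBaseURL overrides the default API base URL, which is useful for
// OpenAI-compatible endpoints.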
func WithOpenAIBaseURL(baseURL string) OpenAIOption {
	return func(p *openaiProvider) {
		p.baseURL = baseURL
	}
}

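// WithOpenAIKey sets the API key used to authenticate requests.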
func WithOpenAIKey(apiKey string) OpenAIOption {
	return func(p *openaiProvider) {
		p.apiKey = apiKey
	}
}

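// convertToOpenAIMessages maps internal messages to the OpenAI chat format,
// prepending the system message and expanding assistant tool calls and tool
// results into their respective message types.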
func (p *openaiProvider) convertToOpenAIMessages(messages []message.Message) []openai.ChatCompletionMessageParamUnion {
	var chatMessages []openai.ChatCompletionMessageParamUnion

	chatMessages = append(chatMessages, openai.SystemMessage(p.systemMessage))

	for _, msg := range messages {
		switch msg.Role {
		case message.User:
			chatMessages = append(chatMessages, openai.UserMessage(msg.Content().String()))

		case message.Assistant:
			assistantMsg := openai.ChatCompletionAssistantMessageParam{
				Role: "assistant",
			}

			if msg.Content().String() != "" {
				assistantMsg.Content = openai.ChatCompletionAssistantMessageParamContentUnion{
					OfString: openai.String(msg.Content().String()),
				}
			}

			if len(msg.ToolCalls()) > 0 {
				assistantMsg.ToolCalls = make([]openai.ChatCompletionMessageToolCallParam, len(msg.ToolCalls()))
				for i, call := range msg.ToolCalls() {
					assistantMsg.ToolCalls[i] = openai.ChatCompletionMessageToolCallParam{
						ID:   call.ID,
						Type: "function",
						Function: openai.ChatCompletionMessageToolCallFunctionParam{
							Name:      call.Name,
							Arguments: call.Input,
						},
					}
				}
			}

			chatMessages = append(chatMessages, openai.ChatCompletionMessageParamUnion{
				OfAssistant: &assistantMsg,
			})

		case message.Tool:
			for _, result := range msg.ToolResults() {
				chatMessages = append(chatMessages,
					openai.ToolMessage(result.Content, result.ToolCallID),
				)
			}
		}
	}

	return chatMessages
}

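// convertToOpenAITools maps tool definitions to OpenAI function declarations,
// wrapping each tool's parameters in a JSON Schema object.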
func (p *openaiProvider) convertToOpenAITools(tools []tools.BaseTool) []openai.ChatCompletionToolParam {
	openaiTools := make([]openai.ChatCompletionToolParam, len(tools))

	for i, tool := range tools {
		info := tool.Info()
		openaiTools[i] = openai.ChatCompletionToolParam{
			Function: openai.FunctionDefinitionParam{
				Name:        info.Name,
				Description: openai.String(info.Description),
				Parameters: openai.FunctionParameters{
					"type":       "object",
					"properties": info.Parameters,
					"required":   info.Required,
				},
			},
		}
	}

	return openaiTools
}

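// extractTokenUsage converts OpenAI usage counters into TokenUsage. Cached
// prompt tokens are reported separately, so they are subtracted from the
// input count.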
func (p *openaiProvider) extractTokenUsage(usage openai.CompletionUsage) TokenUsage {
	cachedTokens := usage.PromptTokensDetails.CachedTokens
	inputTokens := usage.PromptTokens - cachedTokens

	return TokenUsage{
		InputTokens:         inputTokens,
		OutputTokens:        usage.CompletionTokens,
		CacheCreationTokens: 0, // OpenAI doesn't provide this directly
		CacheReadTokens:     cachedTokens,
	}
}

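// SendMessages performs a single blocking chat completion and returns the
// assistant's content, any tool calls, and token usage.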
func (p *openaiProvider) SendMessages(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
	messages = cleanupMessages(messages)
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
	}

	response, err := p.client.Chat.Completions.New(ctx, params)
	if err != nil {
		return nil, err
	}
	// Guard against an empty choice list before indexing into it.
	if len(response.Choices) == 0 {
		return nil, errors.New("no completion choices returned")
	}

	content := response.Choices[0].Message.Content

	var toolCalls []message.ToolCall
	if len(response.Choices[0].Message.ToolCalls) > 0 {
		toolCalls = make([]message.ToolCall, len(response.Choices[0].Message.ToolCalls))
		for i, call := range response.Choices[0].Message.ToolCalls {
			toolCalls[i] = message.ToolCall{
				ID:    call.ID,
				Name:  call.Function.Name,
				Input: call.Function.Arguments,
				Type:  "function",
			}
		}
	}

	tokenUsage := p.extractTokenUsage(response.Usage)

	return &ProviderResponse{
		Content:   content,
		ToolCalls: toolCalls,
		Usage:     tokenUsage,
	}, nil
}

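// StreamResponse streams a chat completion, emitting a content-delta event
// for each chunk and a final EventComplete carrying the accumulated response.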
func (p *openaiProvider) StreamResponse(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (<-chan ProviderEvent, error) {
	messages = cleanupMessages(messages)
	chatMessages := p.convertToOpenAIMessages(messages)
	openaiTools := p.convertToOpenAITools(tools)

	params := openai.ChatCompletionNewParams{
		Model:     openai.ChatModel(p.model.APIModel),
		Messages:  chatMessages,
		MaxTokens: openai.Int(p.maxTokens),
		Tools:     openaiTools,
		StreamOptions: openai.ChatCompletionStreamOptionsParam{
			IncludeUsage: openai.Bool(true),
		},
	}

	stream := p.client.Chat.Completions.NewStreaming(ctx, params)

	eventChan := make(chan ProviderEvent)

	toolCalls := make([]message.ToolCall, 0)
	go func() {
		defer close(eventChan)

		acc := openai.ChatCompletionAccumulator{}
		currentContent := ""

		for stream.Next() {
			chunk := stream.Current()
			acc.AddChunk(chunk)

			// The accumulator reassembles tool-call fragments across chunks
			// and reports each call once it is complete.
			if tool, ok := acc.JustFinishedToolCall(); ok {
				toolCalls = append(toolCalls, message.ToolCall{
					ID:    tool.Id,
					Name:  tool.Name,
					Input: tool.Arguments,
					Type:  "function",
				})
			}

			for _, choice := range chunk.Choices {
				if choice.Delta.Content != "" {
					eventChan <- ProviderEvent{
						Type:    EventContentDelta,
						Content: choice.Delta.Content,
					}
					currentContent += choice.Delta.Content
				}
			}
		}

		if err := stream.Err(); err != nil {
			eventChan <- ProviderEvent{
				Type:  EventError,
				Error: err,
			}
			return
		}

		// Usage is populated by the final chunk because IncludeUsage is set.
		tokenUsage := p.extractTokenUsage(acc.Usage)

		eventChan <- ProviderEvent{
			Type: EventComplete,
			Response: &ProviderResponse{
				Content:   currentContent,
				ToolCalls: toolCalls,
				Usage:     tokenUsage,
			},
		}
	}()

	return eventChan, nil
}