package agent

import (
	"context"
	"errors"
	"fmt"
	"log"
	"sync"

	"github.com/kujtimiihoxha/termai/internal/app"
	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
	"github.com/kujtimiihoxha/termai/internal/llm/provider"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/message"
)
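// Agent generates assistant responses within a session from user-provided content.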
type Agent interface {
	Generate(sessionID string, content string) error
}
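// agent wires the application services to a chat provider, a title-generation
// provider, and the set of tools the model is allowed to call.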
type agent struct {
	*app.App
	model          models.Model
	tools          []tools.BaseTool
	agent          provider.Provider
	titleGenerator provider.Provider
}
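// handleTitleGeneration asks the title-generation provider for a short title
// based on the user's first message and stores it on the session. It is meant
// to run in its own goroutine, so failures are logged and otherwise ignored.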
func (c *agent) handleTitleGeneration(sessionID, content string) {
	response, err := c.titleGenerator.SendMessages(
		c.Context,
		[]message.Message{
			{
				Role:    message.User,
				Content: content,
			},
		},
		nil,
	)
	if err != nil {
		log.Println("title generation failed:", err)
		return
	}

	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		log.Println("title generation failed:", err)
		return
	}
	if response.Content != "" {
		session.Title = response.Content
		if _, err := c.Sessions.Save(session); err != nil {
			log.Println("title generation failed:", err)
		}
	}
}
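// TrackUsage converts the provider's token usage into cost using the model's
// per-million-token prices and accumulates cost and token counts on the session.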
func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		return err
	}

	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
		model.CostPer1MOut/1e6*float64(usage.OutputTokens)

	session.Cost += cost
	session.CompletionTokens += usage.OutputTokens
	session.PromptTokens += usage.InputTokens

	_, err = c.Sessions.Save(session)
	return err
}
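// processEvent applies a single streaming event from the provider to the
// in-progress assistant message: thinking and content deltas are appended and
// persisted, errors abort the stream, and completion records tool calls and usage.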
func (c *agent) processEvent(
	sessionID string,
	assistantMsg *message.Message,
	event provider.ProviderEvent,
) error {
	switch event.Type {
	case provider.EventThinkingDelta:
		assistantMsg.Thinking += event.Thinking
		return c.Messages.Update(*assistantMsg)
	case provider.EventContentDelta:
		assistantMsg.Content += event.Content
		return c.Messages.Update(*assistantMsg)
	case provider.EventError:
		log.Println("error", event.Error)
		return event.Error

	case provider.EventComplete:
		assistantMsg.ToolCalls = event.Response.ToolCalls
		err := c.Messages.Update(*assistantMsg)
		if err != nil {
			return err
		}
		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
	}

	return nil
}
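// ExecuteTools runs every tool call from the assistant concurrently, matching
// each call to a registered tool by name. Unknown tools and tool errors are
// reported back to the model as error results rather than failing the request.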
func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
	var wg sync.WaitGroup
	toolResults := make([]message.ToolResult, len(toolCalls))
	mutex := &sync.Mutex{}

	for i, tc := range toolCalls {
		wg.Add(1)
		go func(index int, toolCall message.ToolCall) {
			defer wg.Done()

			response := ""
			isError := false
			found := false

			// Look up the requested tool by name and run it.
			for _, tool := range tls {
				if tool.Info().Name == toolCall.Name {
					found = true
					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
						ID:    toolCall.ID,
						Name:  toolCall.Name,
						Input: toolCall.Input,
					})
					if toolErr != nil {
						response = fmt.Sprintf("error running tool: %s", toolErr)
						isError = true
					} else {
						response = toolResult.Content
						isError = toolResult.IsError
					}
					break
				}
			}

			if !found {
				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
				isError = true
			}

			// Serialize writes to the shared results slice; each goroutine
			// fills the slot matching its tool call's position.
			mutex.Lock()
			defer mutex.Unlock()

			toolResults[index] = message.ToolResult{
				ToolCallID: toolCall.ID,
				Content:    response,
				IsError:    isError,
			}
		}(i, tc)
	}

	wg.Wait()
	return toolResults, nil
}
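// handleToolExecution runs the tool calls attached to an assistant message, if
// any, and records their results as a single tool message in the session.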
func (c *agent) handleToolExecution(
	ctx context.Context,
	assistantMsg message.Message,
) (*message.Message, error) {
	if len(assistantMsg.ToolCalls) == 0 {
		return nil, nil
	}

	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls, c.tools)
	if err != nil {
		return nil, err
	}

	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
		Role:        message.Tool,
		ToolResults: toolResults,
	})

	return &msg, err
}
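// generate drives a single user turn: it stores the user message, streams the
// assistant response, executes any requested tools, and keeps looping until the
// model produces a response with no further tool calls. The first message in a
// session also kicks off title generation in the background.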
func (c *agent) generate(sessionID string, content string) error {
	messages, err := c.Messages.List(sessionID)
	if err != nil {
		return err
	}

	if len(messages) == 0 {
		go c.handleTitleGeneration(sessionID, content)
	}

	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
		Role:    message.User,
		Content: content,
	})
	if err != nil {
		return err
	}

	messages = append(messages, userMsg)
	for {
		eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
		if err != nil {
			return err
		}

		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
			Role:    message.Assistant,
			Content: "",
		})
		if err != nil {
			return err
		}

		// Apply streaming events to the assistant message as they arrive.
		for event := range eventChan {
			err = c.processEvent(sessionID, &assistantMsg, event)
			if err != nil {
				assistantMsg.Finished = true
				c.Messages.Update(assistantMsg)
				return err
			}
		}

		// Run any tool calls, then mark the assistant message as finished.
		msg, err := c.handleToolExecution(c.Context, assistantMsg)
		assistantMsg.Finished = true
		c.Messages.Update(assistantMsg)
		if err != nil {
			return err
		}

		if len(assistantMsg.ToolCalls) == 0 {
			break
		}

		messages = append(messages, assistantMsg)
		if msg != nil {
			messages = append(messages, *msg)
		}
	}
	return nil
}
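// getAgentProviders builds the chat provider and the title-generation provider
// for the configured model, using the API key and settings from the enabled
// provider configuration. The title generator reuses the same model with a
// short max-token budget.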
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
	maxTokens := config.Get().Model.CoderMaxTokens

	providerConfig, ok := config.Get().Providers[model.Provider]
	if !ok || !providerConfig.Enabled {
		return nil, nil, errors.New("provider is not enabled")
	}
	var agentProvider provider.Provider
	var titleGenerator provider.Provider

	switch model.Provider {
	case models.ProviderOpenAI:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderAnthropic:
		var err error
		agentProvider, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithAnthropicMaxTokens(maxTokens),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithAnthropicMaxTokens(80),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}

	case models.ProviderGemini:
		var err error
		agentProvider, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithGeminiMaxTokens(int32(maxTokens)),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithGeminiMaxTokens(80),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderGROQ:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}
	default:
		return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
	}

	return agentProvider, titleGenerator, nil
}