1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log"
8 "strings"
9 "sync"
10
11 "github.com/kujtimiihoxha/termai/internal/app"
12 "github.com/kujtimiihoxha/termai/internal/config"
13 "github.com/kujtimiihoxha/termai/internal/llm/models"
14 "github.com/kujtimiihoxha/termai/internal/llm/prompt"
15 "github.com/kujtimiihoxha/termai/internal/llm/provider"
16 "github.com/kujtimiihoxha/termai/internal/llm/tools"
17 "github.com/kujtimiihoxha/termai/internal/message"
18)
19
// Agent produces responses for chat sessions.
type Agent interface {
	// Generate processes content for the session identified by sessionID and
	// reports any failure as a non-nil error.
	Generate(sessionID, content string) error
}
23
24type agent struct {
25 *app.App
26 model models.Model
27 tools []tools.BaseTool
28 agent provider.Provider
29 titleGenerator provider.Provider
30}
31
32func (c *agent) handleTitleGeneration(sessionID, content string) {
33 response, err := c.titleGenerator.SendMessages(
34 c.Context,
35 []message.Message{
36 {
37 Role: message.User,
38 Parts: []message.ContentPart{
39 message.TextContent{
40 Text: content,
41 },
42 },
43 },
44 },
45 nil,
46 )
47 if err != nil {
48 return
49 }
50
51 session, err := c.Sessions.Get(sessionID)
52 if err != nil {
53 return
54 }
55 if response.Content != "" {
56 session.Title = response.Content
57 session.Title = strings.TrimSpace(session.Title)
58 session.Title = strings.ReplaceAll(session.Title, "\n", " ")
59 c.Sessions.Save(session)
60 }
61}
62
63func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
64 session, err := c.Sessions.Get(sessionID)
65 if err != nil {
66 return err
67 }
68
69 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
70 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
71 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
72 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
73
74 session.Cost += cost
75 session.CompletionTokens += usage.OutputTokens
76 session.PromptTokens += usage.InputTokens
77
78 _, err = c.Sessions.Save(session)
79 return err
80}
81
82func (c *agent) processEvent(
83 sessionID string,
84 assistantMsg *message.Message,
85 event provider.ProviderEvent,
86) error {
87 switch event.Type {
88 case provider.EventThinkingDelta:
89 assistantMsg.AppendReasoningContent(event.Content)
90 return c.Messages.Update(*assistantMsg)
91 case provider.EventContentDelta:
92 assistantMsg.AppendContent(event.Content)
93 return c.Messages.Update(*assistantMsg)
94 case provider.EventError:
95 log.Println("error", event.Error)
96 return event.Error
97
98 case provider.EventComplete:
99 assistantMsg.SetToolCalls(event.Response.ToolCalls)
100 assistantMsg.AddFinish(event.Response.FinishReason)
101 err := c.Messages.Update(*assistantMsg)
102 if err != nil {
103 return err
104 }
105 return c.TrackUsage(sessionID, c.model, event.Response.Usage)
106 }
107
108 return nil
109}
110
111func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
112 var wg sync.WaitGroup
113 toolResults := make([]message.ToolResult, len(toolCalls))
114 mutex := &sync.Mutex{}
115
116 for i, tc := range toolCalls {
117 wg.Add(1)
118 go func(index int, toolCall message.ToolCall) {
119 defer wg.Done()
120
121 response := ""
122 isError := false
123 found := false
124
125 for _, tool := range tls {
126 if tool.Info().Name == toolCall.Name {
127 found = true
128 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
129 ID: toolCall.ID,
130 Name: toolCall.Name,
131 Input: toolCall.Input,
132 })
133 if toolErr != nil {
134 response = fmt.Sprintf("error running tool: %s", toolErr)
135 isError = true
136 } else {
137 response = toolResult.Content
138 isError = toolResult.IsError
139 }
140 break
141 }
142 }
143
144 if !found {
145 response = fmt.Sprintf("tool not found: %s", toolCall.Name)
146 isError = true
147 }
148
149 mutex.Lock()
150 defer mutex.Unlock()
151
152 toolResults[index] = message.ToolResult{
153 ToolCallID: toolCall.ID,
154 Content: response,
155 IsError: isError,
156 }
157 }(i, tc)
158 }
159
160 wg.Wait()
161 return toolResults, nil
162}
163
164func (c *agent) handleToolExecution(
165 ctx context.Context,
166 assistantMsg message.Message,
167) (*message.Message, error) {
168 if len(assistantMsg.ToolCalls()) == 0 {
169 return nil, nil
170 }
171
172 toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
173 if err != nil {
174 return nil, err
175 }
176 parts := make([]message.ContentPart, 0)
177 for _, toolResult := range toolResults {
178 parts = append(parts, toolResult)
179 }
180 msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
181 Role: message.Tool,
182 Parts: parts,
183 })
184
185 return &msg, err
186}
187
188func (c *agent) generate(sessionID string, content string) error {
189 messages, err := c.Messages.List(sessionID)
190 if err != nil {
191 return err
192 }
193
194 if len(messages) == 0 {
195 go c.handleTitleGeneration(sessionID, content)
196 }
197
198 userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
199 Role: message.User,
200 Parts: []message.ContentPart{
201 message.TextContent{
202 Text: content,
203 },
204 },
205 })
206 if err != nil {
207 return err
208 }
209
210 messages = append(messages, userMsg)
211 for {
212
213 eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
214 if err != nil {
215 return err
216 }
217
218 assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
219 Role: message.Assistant,
220 Parts: []message.ContentPart{},
221 })
222 if err != nil {
223 return err
224 }
225 for event := range eventChan {
226 err = c.processEvent(sessionID, &assistantMsg, event)
227 if err != nil {
228 assistantMsg.AddFinish("error:" + err.Error())
229 c.Messages.Update(assistantMsg)
230 return err
231 }
232 }
233
234 msg, err := c.handleToolExecution(c.Context, assistantMsg)
235
236 c.Messages.Update(assistantMsg)
237 if err != nil {
238 return err
239 }
240
241 if len(assistantMsg.ToolCalls()) == 0 {
242 break
243 }
244
245 messages = append(messages, assistantMsg)
246 if msg != nil {
247 messages = append(messages, *msg)
248 }
249 }
250 return nil
251}
252
253func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
254 maxTokens := config.Get().Model.CoderMaxTokens
255
256 providerConfig, ok := config.Get().Providers[model.Provider]
257 if !ok || !providerConfig.Enabled {
258 return nil, nil, errors.New("provider is not enabled")
259 }
260 var agentProvider provider.Provider
261 var titleGenerator provider.Provider
262
263 switch model.Provider {
264 case models.ProviderOpenAI:
265 var err error
266 agentProvider, err = provider.NewOpenAIProvider(
267 provider.WithOpenAISystemMessage(
268 prompt.CoderOpenAISystemPrompt(),
269 ),
270 provider.WithOpenAIMaxTokens(maxTokens),
271 provider.WithOpenAIModel(model),
272 provider.WithOpenAIKey(providerConfig.APIKey),
273 )
274 if err != nil {
275 return nil, nil, err
276 }
277 titleGenerator, err = provider.NewOpenAIProvider(
278 provider.WithOpenAISystemMessage(
279 prompt.TitlePrompt(),
280 ),
281 provider.WithOpenAIMaxTokens(80),
282 provider.WithOpenAIModel(model),
283 provider.WithOpenAIKey(providerConfig.APIKey),
284 )
285 if err != nil {
286 return nil, nil, err
287 }
288 case models.ProviderAnthropic:
289 var err error
290 agentProvider, err = provider.NewAnthropicProvider(
291 provider.WithAnthropicSystemMessage(
292 prompt.CoderAnthropicSystemPrompt(),
293 ),
294 provider.WithAnthropicMaxTokens(maxTokens),
295 provider.WithAnthropicKey(providerConfig.APIKey),
296 provider.WithAnthropicModel(model),
297 )
298 if err != nil {
299 return nil, nil, err
300 }
301 titleGenerator, err = provider.NewAnthropicProvider(
302 provider.WithAnthropicSystemMessage(
303 prompt.TitlePrompt(),
304 ),
305 provider.WithAnthropicMaxTokens(80),
306 provider.WithAnthropicKey(providerConfig.APIKey),
307 provider.WithAnthropicModel(model),
308 )
309 if err != nil {
310 return nil, nil, err
311 }
312
313 case models.ProviderGemini:
314 var err error
315 agentProvider, err = provider.NewGeminiProvider(
316 ctx,
317 provider.WithGeminiSystemMessage(
318 prompt.CoderOpenAISystemPrompt(),
319 ),
320 provider.WithGeminiMaxTokens(int32(maxTokens)),
321 provider.WithGeminiKey(providerConfig.APIKey),
322 provider.WithGeminiModel(model),
323 )
324 if err != nil {
325 return nil, nil, err
326 }
327 titleGenerator, err = provider.NewGeminiProvider(
328 ctx,
329 provider.WithGeminiSystemMessage(
330 prompt.TitlePrompt(),
331 ),
332 provider.WithGeminiMaxTokens(80),
333 provider.WithGeminiKey(providerConfig.APIKey),
334 provider.WithGeminiModel(model),
335 )
336 if err != nil {
337 return nil, nil, err
338 }
339 case models.ProviderGROQ:
340 var err error
341 agentProvider, err = provider.NewOpenAIProvider(
342 provider.WithOpenAISystemMessage(
343 prompt.CoderAnthropicSystemPrompt(),
344 ),
345 provider.WithOpenAIMaxTokens(maxTokens),
346 provider.WithOpenAIModel(model),
347 provider.WithOpenAIKey(providerConfig.APIKey),
348 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
349 )
350 if err != nil {
351 return nil, nil, err
352 }
353 titleGenerator, err = provider.NewOpenAIProvider(
354 provider.WithOpenAISystemMessage(
355 prompt.TitlePrompt(),
356 ),
357 provider.WithOpenAIMaxTokens(80),
358 provider.WithOpenAIModel(model),
359 provider.WithOpenAIKey(providerConfig.APIKey),
360 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
361 )
362 if err != nil {
363 return nil, nil, err
364 }
365
366 }
367
368 return agentProvider, titleGenerator, nil
369}