package agent

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"

	"github.com/kujtimiihoxha/termai/internal/app"
	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
	"github.com/kujtimiihoxha/termai/internal/llm/provider"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/logging"
	"github.com/kujtimiihoxha/termai/internal/message"
)

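// Agent produces an assistant response for a session from a single user
// prompt, streaming model output and running any requested tools to completion.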
type Agent interface {
	Generate(ctx context.Context, sessionID string, content string) error
}

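// agent is the default Agent implementation. It holds the main coder
// provider, a separate provider used only for session title generation,
// and the set of tools exposed to the model.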
type agent struct {
	*app.App
	model          models.Model
	tools          []tools.BaseTool
	agent          provider.Provider
	titleGenerator provider.Provider
}

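// handleTitleGeneration asks the title provider to summarize the first user
// message and stores the result as the session title. Errors are silently
// ignored so title generation never blocks the main conversation.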
func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
	response, err := c.titleGenerator.SendMessages(
		ctx,
		[]message.Message{
			{
				Role: message.User,
				Parts: []message.ContentPart{
					message.TextContent{
						Text: content,
					},
				},
			},
		},
		nil,
	)
	if err != nil {
		return
	}

	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		return
	}
	if response.Content != "" {
		session.Title = response.Content
		session.Title = strings.TrimSpace(session.Title)
		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
		c.Sessions.Save(session)
	}
}

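// TrackUsage adds the token counts and estimated cost of a completed request
// to the session. Per-token prices are derived from the model's
// cost-per-million-token rates.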
func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
	session, err := c.Sessions.Get(sessionID)
	if err != nil {
		return err
	}

	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
		model.CostPer1MOut/1e6*float64(usage.OutputTokens)

	session.Cost += cost
	session.CompletionTokens += usage.OutputTokens
	session.PromptTokens += usage.InputTokens

	_, err = c.Sessions.Save(session)
	return err
}

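// processEvent applies a single streaming provider event to the assistant
// message: reasoning and content deltas are appended and persisted, warnings
// and info are logged, and a completion records tool calls, the finish
// reason, and token usage.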
func (c *agent) processEvent(
	sessionID string,
	assistantMsg *message.Message,
	event provider.ProviderEvent,
) error {
	switch event.Type {
	case provider.EventThinkingDelta:
		assistantMsg.AppendReasoningContent(event.Content)
		return c.Messages.Update(*assistantMsg)
	case provider.EventContentDelta:
		assistantMsg.AppendContent(event.Content)
		return c.Messages.Update(*assistantMsg)
	case provider.EventError:
		if errors.Is(event.Error, context.Canceled) {
			return nil
		}
		logging.ErrorPersist(event.Error.Error())
		return event.Error
	case provider.EventWarning:
		logging.WarnPersist(event.Info)
		return nil
	case provider.EventInfo:
		logging.InfoPersist(event.Info)
	case provider.EventComplete:
		assistantMsg.SetToolCalls(event.Response.ToolCalls)
		assistantMsg.AddFinish(event.Response.FinishReason)
		err := c.Messages.Update(*assistantMsg)
		if err != nil {
			return err
		}
		return c.TrackUsage(sessionID, c.model, event.Response.Usage)
	}

	return nil
}

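// ExecuteTools runs every requested tool call concurrently and returns one
// result per call, in the same order as the input. If the context is
// canceled, outstanding calls are reported as canceled and the context error
// is returned alongside the partial results.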
func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
	var wg sync.WaitGroup
	toolResults := make([]message.ToolResult, len(toolCalls))
	mutex := &sync.Mutex{}
	errChan := make(chan error, 1)

	// Create a child context that can be canceled
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	for i, tc := range toolCalls {
		wg.Add(1)
		go func(index int, toolCall message.ToolCall) {
			defer wg.Done()

			// Check if context is already canceled
			select {
			case <-ctx.Done():
				mutex.Lock()
				toolResults[index] = message.ToolResult{
					ToolCallID: toolCall.ID,
					Content:    "Tool execution canceled",
					IsError:    true,
				}
				mutex.Unlock()

				// Send cancellation error to error channel if it's empty
				select {
				case errChan <- ctx.Err():
				default:
				}
				return
			default:
			}

			response := ""
			isError := false
			found := false

			for _, tool := range tls {
				if tool.Info().Name == toolCall.Name {
					found = true
					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
						ID:    toolCall.ID,
						Name:  toolCall.Name,
						Input: toolCall.Input,
					})

					if toolErr != nil {
						if errors.Is(toolErr, context.Canceled) {
							response = "Tool execution canceled"

							// Send cancellation error to error channel if it's empty
							select {
							case errChan <- ctx.Err():
							default:
							}
						} else {
							response = fmt.Sprintf("error running tool: %s", toolErr)
						}
						isError = true
					} else {
						response = toolResult.Content
						isError = toolResult.IsError
					}
					break
				}
			}

			if !found {
				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
				isError = true
			}

			mutex.Lock()
			defer mutex.Unlock()

			toolResults[index] = message.ToolResult{
				ToolCallID: toolCall.ID,
				Content:    response,
				IsError:    isError,
			}
		}(i, tc)
	}

	// Wait for all goroutines to finish or context to be canceled
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		// All tool calls finished (individual results may still carry errors)
	case err := <-errChan:
		// One of the tools encountered a cancellation
		return toolResults, err
	case <-ctx.Done():
		// Context was canceled externally
		return toolResults, ctx.Err()
	}

	return toolResults, nil
}

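// handleToolExecution runs the assistant's tool calls, if any, and records
// their results as a single tool-role message in the session.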
func (c *agent) handleToolExecution(
	ctx context.Context,
	assistantMsg message.Message,
) (*message.Message, error) {
	if len(assistantMsg.ToolCalls()) == 0 {
		return nil, nil
	}

	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
	if err != nil {
		return nil, err
	}
	parts := make([]message.ContentPart, 0)
	for _, toolResult := range toolResults {
		parts = append(parts, toolResult)
	}
	msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
		Role:  message.Tool,
		Parts: parts,
	})

	return &msg, err
}

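// generate drives one full agent turn: it stores the user message, streams
// the assistant response, executes any requested tools, and repeats until
// the model finishes without tool calls or the context is canceled. The
// first message of a session also kicks off asynchronous title generation.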
func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
	messages, err := c.Messages.List(sessionID)
	if err != nil {
		return err
	}

	if len(messages) == 0 {
		go c.handleTitleGeneration(ctx, sessionID, content)
	}

	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
		Role: message.User,
		Parts: []message.ContentPart{
			message.TextContent{
				Text: content,
			},
		},
	})
	if err != nil {
		return err
	}

	messages = append(messages, userMsg)
	for {
		select {
		case <-ctx.Done():
			assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
				Role:  message.Assistant,
				Parts: []message.ContentPart{},
			})
			if err != nil {
				return err
			}
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}

		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
					Role:  message.Assistant,
					Parts: []message.ContentPart{},
				})
				if err != nil {
					return err
				}
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			}
			return err
		}

		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
			Role:  message.Assistant,
			Parts: []message.ContentPart{},
		})
		if err != nil {
			return err
		}
		for event := range eventChan {
			err = c.processEvent(sessionID, &assistantMsg, event)
			if err != nil {
				if errors.Is(err, context.Canceled) {
					assistantMsg.AddFinish("canceled")
					c.Messages.Update(assistantMsg)
					return context.Canceled
				}
				assistantMsg.AddFinish("error:" + err.Error())
				c.Messages.Update(assistantMsg)
				return err
			}

			select {
			case <-ctx.Done():
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			default:
			}
		}

		// Check for context cancellation before tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}

		msg, err := c.handleToolExecution(ctx, assistantMsg)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			}
			return err
		}

		c.Messages.Update(assistantMsg)

		if len(assistantMsg.ToolCalls()) == 0 {
			break
		}

		messages = append(messages, assistantMsg)
		if msg != nil {
			messages = append(messages, *msg)
		}

		// Check for context cancellation after tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}
	}
	return nil
}

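// getAgentProviders builds the main coder provider and a companion
// title-generation provider for the given model. It returns an error when
// the model's provider is not enabled in the configuration; providers not
// handled by the switch yield nil providers without an error.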
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
	maxTokens := config.Get().Model.CoderMaxTokens

	providerConfig, ok := config.Get().Providers[model.Provider]
	if !ok || !providerConfig.Enabled {
		return nil, nil, errors.New("provider is not enabled")
	}
	var agentProvider provider.Provider
	var titleGenerator provider.Provider

	switch model.Provider {
	case models.ProviderOpenAI:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderAnthropic:
		var err error
		agentProvider, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithAnthropicMaxTokens(maxTokens),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithAnthropicMaxTokens(80),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, err
		}

	case models.ProviderGemini:
		var err error
		agentProvider, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithGeminiMaxTokens(int32(maxTokens)),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithGeminiMaxTokens(80),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
	case models.ProviderGROQ:
		var err error
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, err
		}

	case models.ProviderBedrock:
		var err error
		agentProvider, err = provider.NewBedrockProvider(
			provider.WithBedrockSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithBedrockMaxTokens(maxTokens),
			provider.WithBedrockModel(model),
		)
		if err != nil {
			return nil, nil, err
		}
		titleGenerator, err = provider.NewBedrockProvider(
			provider.WithBedrockSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithBedrockMaxTokens(maxTokens),
			provider.WithBedrockModel(model),
		)
		if err != nil {
			return nil, nil, err
		}

	}

	return agentProvider, titleGenerator, nil
}