1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/kujtimiihoxha/termai/internal/app"
11 "github.com/kujtimiihoxha/termai/internal/config"
12 "github.com/kujtimiihoxha/termai/internal/llm/models"
13 "github.com/kujtimiihoxha/termai/internal/llm/prompt"
14 "github.com/kujtimiihoxha/termai/internal/llm/provider"
15 "github.com/kujtimiihoxha/termai/internal/llm/tools"
16 "github.com/kujtimiihoxha/termai/internal/logging"
17 "github.com/kujtimiihoxha/termai/internal/message"
18)
19
// Agent produces model responses for chat sessions.
type Agent interface {
	// Generate runs one conversational turn for the given session: it
	// stores the user content, streams the model response, and executes
	// any requested tools until the model completes or ctx is canceled.
	Generate(ctx context.Context, sessionID string, content string) error
}
23
// agent is the concrete Agent implementation. It embeds *app.App for access
// to application services (Sessions, Messages) and holds two separately
// configured providers: one for the main coding conversation and a cheaper
// one dedicated to generating short session titles.
type agent struct {
	*app.App
	model models.Model
	tools []tools.BaseTool
	// agent streams the main conversation responses.
	agent provider.Provider
	// titleGenerator produces a session title from the first user message.
	titleGenerator provider.Provider
}
31
32func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
33 response, err := c.titleGenerator.SendMessages(
34 ctx,
35 []message.Message{
36 {
37 Role: message.User,
38 Parts: []message.ContentPart{
39 message.TextContent{
40 Text: content,
41 },
42 },
43 },
44 },
45 nil,
46 )
47 if err != nil {
48 return
49 }
50
51 session, err := c.Sessions.Get(sessionID)
52 if err != nil {
53 return
54 }
55 if response.Content != "" {
56 session.Title = response.Content
57 session.Title = strings.TrimSpace(session.Title)
58 session.Title = strings.ReplaceAll(session.Title, "\n", " ")
59 c.Sessions.Save(session)
60 }
61}
62
63func (c *agent) TrackUsage(sessionID string, model models.Model, usage provider.TokenUsage) error {
64 session, err := c.Sessions.Get(sessionID)
65 if err != nil {
66 return err
67 }
68
69 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
70 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
71 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
72 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
73
74 session.Cost += cost
75 session.CompletionTokens += usage.OutputTokens
76 session.PromptTokens += usage.InputTokens
77
78 _, err = c.Sessions.Save(session)
79 return err
80}
81
82func (c *agent) processEvent(
83 sessionID string,
84 assistantMsg *message.Message,
85 event provider.ProviderEvent,
86) error {
87 switch event.Type {
88 case provider.EventThinkingDelta:
89 assistantMsg.AppendReasoningContent(event.Content)
90 return c.Messages.Update(*assistantMsg)
91 case provider.EventContentDelta:
92 assistantMsg.AppendContent(event.Content)
93 return c.Messages.Update(*assistantMsg)
94 case provider.EventError:
95 if errors.Is(event.Error, context.Canceled) {
96 return nil
97 }
98 logging.ErrorPersist(event.Error.Error())
99 return event.Error
100 case provider.EventWarning:
101 logging.WarnPersist(event.Info)
102 return nil
103 case provider.EventInfo:
104 logging.InfoPersist(event.Info)
105 case provider.EventComplete:
106 assistantMsg.SetToolCalls(event.Response.ToolCalls)
107 assistantMsg.AddFinish(event.Response.FinishReason)
108 err := c.Messages.Update(*assistantMsg)
109 if err != nil {
110 return err
111 }
112 return c.TrackUsage(sessionID, c.model, event.Response.Usage)
113 }
114
115 return nil
116}
117
118func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
119 var wg sync.WaitGroup
120 toolResults := make([]message.ToolResult, len(toolCalls))
121 mutex := &sync.Mutex{}
122 errChan := make(chan error, 1)
123
124 // Create a child context that can be canceled
125 ctx, cancel := context.WithCancel(ctx)
126 defer cancel()
127
128 for i, tc := range toolCalls {
129 wg.Add(1)
130 go func(index int, toolCall message.ToolCall) {
131 defer wg.Done()
132
133 // Check if context is already canceled
134 select {
135 case <-ctx.Done():
136 mutex.Lock()
137 toolResults[index] = message.ToolResult{
138 ToolCallID: toolCall.ID,
139 Content: "Tool execution canceled",
140 IsError: true,
141 }
142 mutex.Unlock()
143
144 // Send cancellation error to error channel if it's empty
145 select {
146 case errChan <- ctx.Err():
147 default:
148 }
149 return
150 default:
151 }
152
153 response := ""
154 isError := false
155 found := false
156
157 for _, tool := range tls {
158 if tool.Info().Name == toolCall.Name {
159 found = true
160 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
161 ID: toolCall.ID,
162 Name: toolCall.Name,
163 Input: toolCall.Input,
164 })
165
166 if toolErr != nil {
167 if errors.Is(toolErr, context.Canceled) {
168 response = "Tool execution canceled"
169
170 // Send cancellation error to error channel if it's empty
171 select {
172 case errChan <- ctx.Err():
173 default:
174 }
175 } else {
176 response = fmt.Sprintf("error running tool: %s", toolErr)
177 }
178 isError = true
179 } else {
180 response = toolResult.Content
181 isError = toolResult.IsError
182 }
183 break
184 }
185 }
186
187 if !found {
188 response = fmt.Sprintf("tool not found: %s", toolCall.Name)
189 isError = true
190 }
191
192 mutex.Lock()
193 defer mutex.Unlock()
194
195 toolResults[index] = message.ToolResult{
196 ToolCallID: toolCall.ID,
197 Content: response,
198 IsError: isError,
199 }
200 }(i, tc)
201 }
202
203 // Wait for all goroutines to finish or context to be canceled
204 done := make(chan struct{})
205 go func() {
206 wg.Wait()
207 close(done)
208 }()
209
210 select {
211 case <-done:
212 // All tools completed successfully
213 case err := <-errChan:
214 // One of the tools encountered a cancellation
215 return toolResults, err
216 case <-ctx.Done():
217 // Context was canceled externally
218 return toolResults, ctx.Err()
219 }
220
221 return toolResults, nil
222}
223
224func (c *agent) handleToolExecution(
225 ctx context.Context,
226 assistantMsg message.Message,
227) (*message.Message, error) {
228 if len(assistantMsg.ToolCalls()) == 0 {
229 return nil, nil
230 }
231
232 toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
233 if err != nil {
234 return nil, err
235 }
236 parts := make([]message.ContentPart, 0)
237 for _, toolResult := range toolResults {
238 parts = append(parts, toolResult)
239 }
240 msg, err := c.Messages.Create(assistantMsg.SessionID, message.CreateMessageParams{
241 Role: message.Tool,
242 Parts: parts,
243 })
244
245 return &msg, err
246}
247
// generate runs one full conversational turn for the session: it persists the
// user message, streams the assistant response while saving deltas, executes
// any tool calls the model requested, and loops until the model finishes a
// response with no tool calls or the context is canceled. Update errors on
// the cancellation paths are deliberately best-effort (the cancellation
// result is returned regardless).
func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
	// Expose the session ID to tools via the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
	messages, err := c.Messages.List(sessionID)
	if err != nil {
		return err
	}

	// First message of a fresh session: generate a title in the background
	// (best effort; failures there do not affect generation).
	if len(messages) == 0 {
		go c.handleTitleGeneration(ctx, sessionID, content)
	}

	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
		Role: message.User,
		Parts: []message.ContentPart{
			message.TextContent{
				Text: content,
			},
		},
	})
	if err != nil {
		return err
	}

	messages = append(messages, userMsg)
	for {
		// Before starting a new stream, bail out if already canceled,
		// leaving a placeholder assistant message marked "canceled".
		select {
		case <-ctx.Done():
			assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
				Role:  message.Assistant,
				Parts: []message.ContentPart{},
			})
			if err != nil {
				return err
			}
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}

		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
		if err != nil {
			// Cancellation during stream setup also gets a "canceled"
			// placeholder message for the session history.
			if errors.Is(err, context.Canceled) {
				assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
					Role:  message.Assistant,
					Parts: []message.ContentPart{},
				})
				if err != nil {
					return err
				}
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			}
			return err
		}

		// Empty assistant message, filled in as stream events arrive.
		assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
			Role:  message.Assistant,
			Parts: []message.ContentPart{},
			Model: c.model.ID,
		})
		if err != nil {
			return err
		}

		// Expose the current assistant message ID to tools. NOTE(review):
		// this re-wraps ctx every loop iteration, stacking values — the
		// latest ID shadows earlier ones, so lookups still see the right
		// value.
		ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
		for event := range eventChan {
			err = c.processEvent(sessionID, &assistantMsg, event)
			if err != nil {
				if errors.Is(err, context.Canceled) {
					assistantMsg.AddFinish("canceled")
					c.Messages.Update(assistantMsg)
					return context.Canceled
				}
				assistantMsg.AddFinish("error:" + err.Error())
				c.Messages.Update(assistantMsg)
				return err
			}

			// Abort mid-stream as soon as cancellation is observed.
			select {
			case <-ctx.Done():
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			default:
			}
		}

		// Check for context cancellation before tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}

		msg, err := c.handleToolExecution(ctx, assistantMsg)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				assistantMsg.AddFinish("canceled")
				c.Messages.Update(assistantMsg)
				return context.Canceled
			}
			return err
		}

		c.Messages.Update(assistantMsg)

		// No tool calls means the model produced its final answer.
		if len(assistantMsg.ToolCalls()) == 0 {
			break
		}

		// Feed the assistant message plus the tool results back into the
		// conversation and loop for the model's follow-up turn.
		messages = append(messages, assistantMsg)
		if msg != nil {
			messages = append(messages, *msg)
		}

		// Check for context cancellation after tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled")
			c.Messages.Update(assistantMsg)
			return context.Canceled
		default:
			// Continue processing
		}
	}
	return nil
}
381
382func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
383 maxTokens := config.Get().Model.CoderMaxTokens
384
385 providerConfig, ok := config.Get().Providers[model.Provider]
386 if !ok || !providerConfig.Enabled {
387 return nil, nil, errors.New("provider is not enabled")
388 }
389 var agentProvider provider.Provider
390 var titleGenerator provider.Provider
391
392 switch model.Provider {
393 case models.ProviderOpenAI:
394 var err error
395 agentProvider, err = provider.NewOpenAIProvider(
396 provider.WithOpenAISystemMessage(
397 prompt.CoderOpenAISystemPrompt(),
398 ),
399 provider.WithOpenAIMaxTokens(maxTokens),
400 provider.WithOpenAIModel(model),
401 provider.WithOpenAIKey(providerConfig.APIKey),
402 )
403 if err != nil {
404 return nil, nil, err
405 }
406 titleGenerator, err = provider.NewOpenAIProvider(
407 provider.WithOpenAISystemMessage(
408 prompt.TitlePrompt(),
409 ),
410 provider.WithOpenAIMaxTokens(80),
411 provider.WithOpenAIModel(model),
412 provider.WithOpenAIKey(providerConfig.APIKey),
413 )
414 if err != nil {
415 return nil, nil, err
416 }
417 case models.ProviderAnthropic:
418 var err error
419 agentProvider, err = provider.NewAnthropicProvider(
420 provider.WithAnthropicSystemMessage(
421 prompt.CoderAnthropicSystemPrompt(),
422 ),
423 provider.WithAnthropicMaxTokens(maxTokens),
424 provider.WithAnthropicKey(providerConfig.APIKey),
425 provider.WithAnthropicModel(model),
426 )
427 if err != nil {
428 return nil, nil, err
429 }
430 titleGenerator, err = provider.NewAnthropicProvider(
431 provider.WithAnthropicSystemMessage(
432 prompt.TitlePrompt(),
433 ),
434 provider.WithAnthropicMaxTokens(80),
435 provider.WithAnthropicKey(providerConfig.APIKey),
436 provider.WithAnthropicModel(model),
437 )
438 if err != nil {
439 return nil, nil, err
440 }
441
442 case models.ProviderGemini:
443 var err error
444 agentProvider, err = provider.NewGeminiProvider(
445 ctx,
446 provider.WithGeminiSystemMessage(
447 prompt.CoderOpenAISystemPrompt(),
448 ),
449 provider.WithGeminiMaxTokens(int32(maxTokens)),
450 provider.WithGeminiKey(providerConfig.APIKey),
451 provider.WithGeminiModel(model),
452 )
453 if err != nil {
454 return nil, nil, err
455 }
456 titleGenerator, err = provider.NewGeminiProvider(
457 ctx,
458 provider.WithGeminiSystemMessage(
459 prompt.TitlePrompt(),
460 ),
461 provider.WithGeminiMaxTokens(80),
462 provider.WithGeminiKey(providerConfig.APIKey),
463 provider.WithGeminiModel(model),
464 )
465 if err != nil {
466 return nil, nil, err
467 }
468 case models.ProviderGROQ:
469 var err error
470 agentProvider, err = provider.NewOpenAIProvider(
471 provider.WithOpenAISystemMessage(
472 prompt.CoderAnthropicSystemPrompt(),
473 ),
474 provider.WithOpenAIMaxTokens(maxTokens),
475 provider.WithOpenAIModel(model),
476 provider.WithOpenAIKey(providerConfig.APIKey),
477 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
478 )
479 if err != nil {
480 return nil, nil, err
481 }
482 titleGenerator, err = provider.NewOpenAIProvider(
483 provider.WithOpenAISystemMessage(
484 prompt.TitlePrompt(),
485 ),
486 provider.WithOpenAIMaxTokens(80),
487 provider.WithOpenAIModel(model),
488 provider.WithOpenAIKey(providerConfig.APIKey),
489 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
490 )
491 if err != nil {
492 return nil, nil, err
493 }
494
495 case models.ProviderBedrock:
496 var err error
497 agentProvider, err = provider.NewBedrockProvider(
498 provider.WithBedrockSystemMessage(
499 prompt.CoderAnthropicSystemPrompt(),
500 ),
501 provider.WithBedrockMaxTokens(maxTokens),
502 provider.WithBedrockModel(model),
503 )
504 if err != nil {
505 return nil, nil, err
506 }
507 titleGenerator, err = provider.NewBedrockProvider(
508 provider.WithBedrockSystemMessage(
509 prompt.TitlePrompt(),
510 ),
511 provider.WithBedrockMaxTokens(maxTokens),
512 provider.WithBedrockModel(model),
513 )
514 if err != nil {
515 return nil, nil, err
516 }
517
518 }
519
520 return agentProvider, titleGenerator, nil
521}