1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "sync"
9
10 "github.com/kujtimiihoxha/termai/internal/config"
11 "github.com/kujtimiihoxha/termai/internal/llm/models"
12 "github.com/kujtimiihoxha/termai/internal/llm/prompt"
13 "github.com/kujtimiihoxha/termai/internal/llm/provider"
14 "github.com/kujtimiihoxha/termai/internal/llm/tools"
15 "github.com/kujtimiihoxha/termai/internal/logging"
16 "github.com/kujtimiihoxha/termai/internal/message"
17 "github.com/kujtimiihoxha/termai/internal/session"
18)
19
// Sentinel errors returned by the agent. Compare against them with errors.Is.
var (
	// ErrProviderNotEnabled is returned when the model's provider has no
	// configuration entry or that entry is disabled.
	ErrProviderNotEnabled = errors.New("provider is not enabled")

	// ErrRequestCancelled is returned by the generation workflow when the
	// request's context has been cancelled.
	ErrRequestCancelled = errors.New("request cancelled by user")

	// ErrSessionBusy is returned by Generate when the session already has a
	// request registered.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
26
// Service produces assistant responses for chat sessions and lets callers
// abort a response that is still being generated.
type Service interface {
	// Cancel aborts the request currently registered for sessionID and reports
	// an error when there is none.
	Cancel(sessionID string) error
	// Generate starts producing a response to content within sessionID.
	Generate(ctx context.Context, sessionID string, content string) error
}
32
33type agent struct {
34 sessions session.Service
35 messages message.Service
36 model models.Model
37 tools []tools.BaseTool
38 agent provider.Provider
39 titleGenerator provider.Provider
40 activeRequests sync.Map // map[sessionID]context.CancelFunc
41}
42
43// NewAgent creates a new agent instance with the given model and tools
44func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
45 agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
46 if err != nil {
47 return nil, fmt.Errorf("failed to initialize providers: %w", err)
48 }
49
50 return &agent{
51 model: model,
52 tools: tools,
53 sessions: sessions,
54 messages: messages,
55 agent: agentProvider,
56 titleGenerator: titleGenerator,
57 activeRequests: sync.Map{},
58 }, nil
59}
60
61// Cancel cancels an active request by session ID
62func (a *agent) Cancel(sessionID string) error {
63 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
64 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
65 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
66 cancel()
67 return nil
68 }
69 }
70 return errors.New("no active request found for this session")
71}
72
73// Generate starts the generation process
74func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
75 // Check if this session already has an active request
76 if _, busy := a.activeRequests.Load(sessionID); busy {
77 return ErrSessionBusy
78 }
79
80 // Create a cancellable context
81 genCtx, cancel := context.WithCancel(ctx)
82
83 // Store cancel function to allow user cancellation
84 a.activeRequests.Store(sessionID, cancel)
85
86 // Launch the generation in a goroutine
87 go func() {
88 defer func() {
89 if r := recover(); r != nil {
90 logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
91 }
92 }()
93 defer a.activeRequests.Delete(sessionID)
94 defer cancel()
95
96 if err := a.generate(genCtx, sessionID, content); err != nil {
97 if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
98 // Log the error (avoid logging cancellations as they're expected)
99 logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
100
101 // You may want to create an error message in the chat
102 bgCtx := context.Background()
103 errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
104 _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
105 Role: message.System,
106 Parts: []message.ContentPart{
107 message.TextContent{
108 Text: errorMsg,
109 },
110 },
111 })
112 if createErr != nil {
113 logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
114 }
115 }
116 }
117 }()
118
119 return nil
120}
121
122// IsSessionBusy checks if a session currently has an active request
123func (a *agent) IsSessionBusy(sessionID string) bool {
124 _, busy := a.activeRequests.Load(sessionID)
125 return busy
126} // handleTitleGeneration asynchronously generates a title for new sessions
127func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
128 response, err := a.titleGenerator.SendMessages(
129 ctx,
130 []message.Message{
131 {
132 Role: message.User,
133 Parts: []message.ContentPart{
134 message.TextContent{
135 Text: content,
136 },
137 },
138 },
139 },
140 nil,
141 )
142 if err != nil {
143 logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
144 return
145 }
146
147 session, err := a.sessions.Get(ctx, sessionID)
148 if err != nil {
149 logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
150 return
151 }
152
153 if response.Content != "" {
154 session.Title = strings.TrimSpace(response.Content)
155 session.Title = strings.ReplaceAll(session.Title, "\n", " ")
156 if _, err := a.sessions.Save(ctx, session); err != nil {
157 logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
158 }
159 }
160}
161
162// TrackUsage updates token usage statistics for the session
163func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
164 session, err := a.sessions.Get(ctx, sessionID)
165 if err != nil {
166 return fmt.Errorf("failed to get session: %w", err)
167 }
168
169 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
170 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
171 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
172 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
173
174 session.Cost += cost
175 session.CompletionTokens += usage.OutputTokens
176 session.PromptTokens += usage.InputTokens
177
178 _, err = a.sessions.Save(ctx, session)
179 if err != nil {
180 return fmt.Errorf("failed to save session: %w", err)
181 }
182 return nil
183}
184
185// processEvent handles different types of events during generation
186func (a *agent) processEvent(
187 ctx context.Context,
188 sessionID string,
189 assistantMsg *message.Message,
190 event provider.ProviderEvent,
191) error {
192 select {
193 case <-ctx.Done():
194 return ctx.Err()
195 default:
196 // Continue processing
197 }
198
199 switch event.Type {
200 case provider.EventThinkingDelta:
201 assistantMsg.AppendReasoningContent(event.Content)
202 return a.messages.Update(ctx, *assistantMsg)
203 case provider.EventContentDelta:
204 assistantMsg.AppendContent(event.Content)
205 return a.messages.Update(ctx, *assistantMsg)
206 case provider.EventError:
207 if errors.Is(event.Error, context.Canceled) {
208 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
209 return context.Canceled
210 }
211 logging.ErrorPersist(event.Error.Error())
212 return event.Error
213 case provider.EventWarning:
214 logging.WarnPersist(event.Info)
215 case provider.EventInfo:
216 logging.InfoPersist(event.Info)
217 case provider.EventComplete:
218 assistantMsg.SetToolCalls(event.Response.ToolCalls)
219 assistantMsg.AddFinish(event.Response.FinishReason)
220 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
221 return fmt.Errorf("failed to update message: %w", err)
222 }
223 return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
224 }
225
226 return nil
227}
228
229// ExecuteTools runs all tool calls sequentially and returns the results
230func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
231 toolResults := make([]message.ToolResult, len(toolCalls))
232
233 // Create a child context that can be canceled
234 ctx, cancel := context.WithCancel(ctx)
235 defer cancel()
236
237 // Check if already canceled before starting any execution
238 if ctx.Err() != nil {
239 // Mark all tools as canceled
240 for i, toolCall := range toolCalls {
241 toolResults[i] = message.ToolResult{
242 ToolCallID: toolCall.ID,
243 Content: "Tool execution canceled by user",
244 IsError: true,
245 }
246 }
247 return toolResults, ctx.Err()
248 }
249
250 for i, toolCall := range toolCalls {
251 // Check for cancellation before executing each tool
252 select {
253 case <-ctx.Done():
254 // Mark this and all remaining tools as canceled
255 for j := i; j < len(toolCalls); j++ {
256 toolResults[j] = message.ToolResult{
257 ToolCallID: toolCalls[j].ID,
258 Content: "Tool execution canceled by user",
259 IsError: true,
260 }
261 }
262 return toolResults, ctx.Err()
263 default:
264 // Continue processing
265 }
266
267 response := ""
268 isError := false
269 found := false
270
271 // Find and execute the appropriate tool
272 for _, tool := range tls {
273 if tool.Info().Name == toolCall.Name {
274 found = true
275 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
276 ID: toolCall.ID,
277 Name: toolCall.Name,
278 Input: toolCall.Input,
279 })
280
281 if toolErr != nil {
282 if errors.Is(toolErr, context.Canceled) {
283 response = "Tool execution canceled by user"
284 } else {
285 response = fmt.Sprintf("Error running tool: %s", toolErr)
286 }
287 isError = true
288 } else {
289 response = toolResult.Content
290 isError = toolResult.IsError
291 }
292 break
293 }
294 }
295
296 if !found {
297 response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
298 isError = true
299 }
300
301 toolResults[i] = message.ToolResult{
302 ToolCallID: toolCall.ID,
303 Content: response,
304 IsError: isError,
305 }
306 }
307
308 return toolResults, nil
309}
310
311// handleToolExecution processes tool calls and creates tool result messages
312func (a *agent) handleToolExecution(
313 ctx context.Context,
314 assistantMsg message.Message,
315) (*message.Message, error) {
316 select {
317 case <-ctx.Done():
318 // If cancelled, create tool results that indicate cancellation
319 if len(assistantMsg.ToolCalls()) > 0 {
320 toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
321 for _, tc := range assistantMsg.ToolCalls() {
322 toolResults = append(toolResults, message.ToolResult{
323 ToolCallID: tc.ID,
324 Content: "Tool execution canceled by user",
325 IsError: true,
326 })
327 }
328
329 // Use background context to ensure the message is created even if original context is cancelled
330 bgCtx := context.Background()
331 parts := make([]message.ContentPart, 0)
332 for _, toolResult := range toolResults {
333 parts = append(parts, toolResult)
334 }
335 msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
336 Role: message.Tool,
337 Parts: parts,
338 })
339 if err != nil {
340 return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
341 }
342 return &msg, ctx.Err()
343 }
344 return nil, ctx.Err()
345 default:
346 // Continue processing
347 }
348
349 if len(assistantMsg.ToolCalls()) == 0 {
350 return nil, nil
351 }
352
353 toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
354 if err != nil {
355 // If error is from cancellation, still return the partial results we have
356 if errors.Is(err, context.Canceled) {
357 // Use background context to ensure the message is created even if original context is cancelled
358 bgCtx := context.Background()
359 parts := make([]message.ContentPart, 0)
360 for _, toolResult := range toolResults {
361 parts = append(parts, toolResult)
362 }
363
364 msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
365 Role: message.Tool,
366 Parts: parts,
367 })
368 if createErr != nil {
369 logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
370 return nil, err
371 }
372 return &msg, err
373 }
374 return nil, err
375 }
376
377 parts := make([]message.ContentPart, 0, len(toolResults))
378 for _, toolResult := range toolResults {
379 parts = append(parts, toolResult)
380 }
381
382 msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
383 Role: message.Tool,
384 Parts: parts,
385 })
386 if err != nil {
387 return nil, fmt.Errorf("failed to create tool message: %w", err)
388 }
389
390 return &msg, nil
391}
392
393// generate handles the main generation workflow
394func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
395 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
396
397 // Handle context cancellation at any point
398 if err := ctx.Err(); err != nil {
399 return ErrRequestCancelled
400 }
401
402 messages, err := a.messages.List(ctx, sessionID)
403 if err != nil {
404 return fmt.Errorf("failed to list messages: %w", err)
405 }
406
407 if len(messages) == 0 {
408 titleCtx := context.Background()
409 go a.handleTitleGeneration(titleCtx, sessionID, content)
410 }
411
412 userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
413 Role: message.User,
414 Parts: []message.ContentPart{
415 message.TextContent{
416 Text: content,
417 },
418 },
419 })
420 if err != nil {
421 return fmt.Errorf("failed to create user message: %w", err)
422 }
423
424 messages = append(messages, userMsg)
425
426 for {
427 // Check for cancellation before each iteration
428 select {
429 case <-ctx.Done():
430 return ErrRequestCancelled
431 default:
432 // Continue processing
433 }
434
435 eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
436 if err != nil {
437 if errors.Is(err, context.Canceled) {
438 return ErrRequestCancelled
439 }
440 return fmt.Errorf("failed to stream response: %w", err)
441 }
442
443 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444 Role: message.Assistant,
445 Parts: []message.ContentPart{},
446 Model: a.model.ID,
447 })
448 if err != nil {
449 return fmt.Errorf("failed to create assistant message: %w", err)
450 }
451
452 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454 // Process events from the LLM provider
455 for event := range eventChan {
456 if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
457 if errors.Is(err, context.Canceled) {
458 // Mark as canceled but don't create separate message
459 assistantMsg.AddFinish("canceled")
460 _ = a.messages.Update(context.Background(), assistantMsg)
461 return ErrRequestCancelled
462 }
463 assistantMsg.AddFinish("error:" + err.Error())
464 _ = a.messages.Update(ctx, assistantMsg)
465 return fmt.Errorf("event processing error: %w", err)
466 }
467
468 // Check for cancellation during event processing
469 select {
470 case <-ctx.Done():
471 // Mark as canceled
472 assistantMsg.AddFinish("canceled")
473 _ = a.messages.Update(context.Background(), assistantMsg)
474 return ErrRequestCancelled
475 default:
476 }
477 }
478
479 // Check for cancellation before tool execution
480 select {
481 case <-ctx.Done():
482 assistantMsg.AddFinish("canceled_by_user")
483 _ = a.messages.Update(context.Background(), assistantMsg)
484 return ErrRequestCancelled
485 default:
486 }
487
488 // Execute any tool calls
489 toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
490 if err != nil {
491 if errors.Is(err, context.Canceled) {
492 assistantMsg.AddFinish("canceled_by_user")
493 _ = a.messages.Update(context.Background(), assistantMsg)
494 return ErrRequestCancelled
495 }
496 return fmt.Errorf("tool execution error: %w", err)
497 }
498
499 if err := a.messages.Update(ctx, assistantMsg); err != nil {
500 return fmt.Errorf("failed to update assistant message: %w", err)
501 }
502
503 // If no tool calls, we're done
504 if len(assistantMsg.ToolCalls()) == 0 {
505 break
506 }
507
508 // Add messages for next iteration
509 messages = append(messages, assistantMsg)
510 if toolMsg != nil {
511 messages = append(messages, *toolMsg)
512 }
513
514 // Check for cancellation after tool execution
515 select {
516 case <-ctx.Done():
517 return ErrRequestCancelled
518 default:
519 }
520 }
521
522 return nil
523}
524
525// getAgentProviders initializes the LLM providers based on the chosen model
526func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
527 maxTokens := config.Get().Model.CoderMaxTokens
528
529 providerConfig, ok := config.Get().Providers[model.Provider]
530 if !ok || providerConfig.Disabled {
531 return nil, nil, ErrProviderNotEnabled
532 }
533
534 var agentProvider provider.Provider
535 var titleGenerator provider.Provider
536 var err error
537
538 switch model.Provider {
539 case models.ProviderOpenAI:
540 agentProvider, err = provider.NewOpenAIProvider(
541 provider.WithOpenAISystemMessage(
542 prompt.CoderOpenAISystemPrompt(),
543 ),
544 provider.WithOpenAIMaxTokens(maxTokens),
545 provider.WithOpenAIModel(model),
546 provider.WithOpenAIKey(providerConfig.APIKey),
547 )
548 if err != nil {
549 return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
550 }
551
552 titleGenerator, err = provider.NewOpenAIProvider(
553 provider.WithOpenAISystemMessage(
554 prompt.TitlePrompt(),
555 ),
556 provider.WithOpenAIMaxTokens(80),
557 provider.WithOpenAIModel(model),
558 provider.WithOpenAIKey(providerConfig.APIKey),
559 )
560 if err != nil {
561 return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
562 }
563
564 case models.ProviderAnthropic:
565 agentProvider, err = provider.NewAnthropicProvider(
566 provider.WithAnthropicSystemMessage(
567 prompt.CoderAnthropicSystemPrompt(),
568 ),
569 provider.WithAnthropicMaxTokens(maxTokens),
570 provider.WithAnthropicKey(providerConfig.APIKey),
571 provider.WithAnthropicModel(model),
572 )
573 if err != nil {
574 return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
575 }
576
577 titleGenerator, err = provider.NewAnthropicProvider(
578 provider.WithAnthropicSystemMessage(
579 prompt.TitlePrompt(),
580 ),
581 provider.WithAnthropicMaxTokens(80),
582 provider.WithAnthropicKey(providerConfig.APIKey),
583 provider.WithAnthropicModel(model),
584 )
585 if err != nil {
586 return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
587 }
588
589 case models.ProviderGemini:
590 agentProvider, err = provider.NewGeminiProvider(
591 ctx,
592 provider.WithGeminiSystemMessage(
593 prompt.CoderOpenAISystemPrompt(),
594 ),
595 provider.WithGeminiMaxTokens(int32(maxTokens)),
596 provider.WithGeminiKey(providerConfig.APIKey),
597 provider.WithGeminiModel(model),
598 )
599 if err != nil {
600 return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
601 }
602
603 titleGenerator, err = provider.NewGeminiProvider(
604 ctx,
605 provider.WithGeminiSystemMessage(
606 prompt.TitlePrompt(),
607 ),
608 provider.WithGeminiMaxTokens(80),
609 provider.WithGeminiKey(providerConfig.APIKey),
610 provider.WithGeminiModel(model),
611 )
612 if err != nil {
613 return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
614 }
615
616 case models.ProviderGROQ:
617 agentProvider, err = provider.NewOpenAIProvider(
618 provider.WithOpenAISystemMessage(
619 prompt.CoderAnthropicSystemPrompt(),
620 ),
621 provider.WithOpenAIMaxTokens(maxTokens),
622 provider.WithOpenAIModel(model),
623 provider.WithOpenAIKey(providerConfig.APIKey),
624 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
625 )
626 if err != nil {
627 return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
628 }
629
630 titleGenerator, err = provider.NewOpenAIProvider(
631 provider.WithOpenAISystemMessage(
632 prompt.TitlePrompt(),
633 ),
634 provider.WithOpenAIMaxTokens(80),
635 provider.WithOpenAIModel(model),
636 provider.WithOpenAIKey(providerConfig.APIKey),
637 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
638 )
639 if err != nil {
640 return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
641 }
642
643 case models.ProviderBedrock:
644 agentProvider, err = provider.NewBedrockProvider(
645 provider.WithBedrockSystemMessage(
646 prompt.CoderAnthropicSystemPrompt(),
647 ),
648 provider.WithBedrockMaxTokens(maxTokens),
649 provider.WithBedrockModel(model),
650 )
651 if err != nil {
652 return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
653 }
654
655 titleGenerator, err = provider.NewBedrockProvider(
656 provider.WithBedrockSystemMessage(
657 prompt.TitlePrompt(),
658 ),
659 provider.WithBedrockMaxTokens(80),
660 provider.WithBedrockModel(model),
661 )
662 if err != nil {
663 return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
664 }
665 default:
666 return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
667 }
668
669 return agentProvider, titleGenerator, nil
670}