1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "runtime/debug"
9 "strings"
10 "sync"
11
12 "github.com/kujtimiihoxha/termai/internal/config"
13 "github.com/kujtimiihoxha/termai/internal/llm/models"
14 "github.com/kujtimiihoxha/termai/internal/llm/prompt"
15 "github.com/kujtimiihoxha/termai/internal/llm/provider"
16 "github.com/kujtimiihoxha/termai/internal/llm/tools"
17 "github.com/kujtimiihoxha/termai/internal/logging"
18 "github.com/kujtimiihoxha/termai/internal/message"
19 "github.com/kujtimiihoxha/termai/internal/session"
20)
21
// Sentinel errors reported by the agent package.
var (
	// ErrProviderNotEnabled is returned when the model's provider is absent
	// from the configuration or is marked as disabled.
	ErrProviderNotEnabled = errors.New("provider is not enabled")

	// ErrRequestCancelled is returned when generation stops because its
	// context was cancelled.
	ErrRequestCancelled = errors.New("request cancelled by user")

	// ErrSessionBusy is returned by Generate when the session already has a
	// request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
28
// Service is the public surface of the agent: it starts response generation
// for a session and lets callers cancel a request that is in flight.
type Service interface {
	// Generate starts generating a response to content within the given
	// session. The implementation in this file runs the work in a background
	// goroutine and returns ErrSessionBusy when the session already has an
	// active request.
	Generate(ctx context.Context, sessionID, content string) error

	// Cancel stops the active request of the given session. The
	// implementation in this file returns an error when the session has no
	// active request.
	Cancel(sessionID string) error
}
34
35type agent struct {
36 sessions session.Service
37 messages message.Service
38 model models.Model
39 tools []tools.BaseTool
40 agent provider.Provider
41 titleGenerator provider.Provider
42 activeRequests sync.Map // map[sessionID]context.CancelFunc
43}
44
45// NewAgent creates a new agent instance with the given model and tools
46func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
47 agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
48 if err != nil {
49 return nil, fmt.Errorf("failed to initialize providers: %w", err)
50 }
51
52 return &agent{
53 model: model,
54 tools: tools,
55 sessions: sessions,
56 messages: messages,
57 agent: agentProvider,
58 titleGenerator: titleGenerator,
59 activeRequests: sync.Map{},
60 }, nil
61}
62
63// Cancel cancels an active request by session ID
64func (a *agent) Cancel(sessionID string) error {
65 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
66 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
67 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
68 cancel()
69 return nil
70 }
71 }
72 return errors.New("no active request found for this session")
73}
74
75// Generate starts the generation process
76func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
77 // Check if this session already has an active request
78 if _, busy := a.activeRequests.Load(sessionID); busy {
79 return ErrSessionBusy
80 }
81
82 // Create a cancellable context
83 genCtx, cancel := context.WithCancel(ctx)
84
85 // Store cancel function to allow user cancellation
86 a.activeRequests.Store(sessionID, cancel)
87
88 // Launch the generation in a goroutine
89 go func() {
90 defer func() {
91 if r := recover(); r != nil {
92 logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
93
94 // dump stack trace into a file
95 file, err := os.Create("panic.log")
96 if err != nil {
97 logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
98 return
99 }
100
101 defer file.Close()
102
103 stackTrace := debug.Stack()
104 if _, err := file.Write(stackTrace); err != nil {
105 logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err))
106 }
107
108 }
109 }()
110 defer a.activeRequests.Delete(sessionID)
111 defer cancel()
112
113 if err := a.generate(genCtx, sessionID, content); err != nil {
114 if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
115 // Log the error (avoid logging cancellations as they're expected)
116 logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
117
118 // You may want to create an error message in the chat
119 bgCtx := context.Background()
120 errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
121 _, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
122 Role: message.System,
123 Parts: []message.ContentPart{
124 message.TextContent{
125 Text: errorMsg,
126 },
127 },
128 })
129 if createErr != nil {
130 logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
131 }
132 }
133 }
134 }()
135
136 return nil
137}
138
139// IsSessionBusy checks if a session currently has an active request
140func (a *agent) IsSessionBusy(sessionID string) bool {
141 _, busy := a.activeRequests.Load(sessionID)
142 return busy
143} // handleTitleGeneration asynchronously generates a title for new sessions
144func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
145 response, err := a.titleGenerator.SendMessages(
146 ctx,
147 []message.Message{
148 {
149 Role: message.User,
150 Parts: []message.ContentPart{
151 message.TextContent{
152 Text: content,
153 },
154 },
155 },
156 },
157 nil,
158 )
159 if err != nil {
160 logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
161 return
162 }
163
164 session, err := a.sessions.Get(ctx, sessionID)
165 if err != nil {
166 logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
167 return
168 }
169
170 if response.Content != "" {
171 session.Title = strings.TrimSpace(response.Content)
172 session.Title = strings.ReplaceAll(session.Title, "\n", " ")
173 if _, err := a.sessions.Save(ctx, session); err != nil {
174 logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
175 }
176 }
177}
178
179// TrackUsage updates token usage statistics for the session
180func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
181 session, err := a.sessions.Get(ctx, sessionID)
182 if err != nil {
183 return fmt.Errorf("failed to get session: %w", err)
184 }
185
186 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
187 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
188 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
189 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
190
191 session.Cost += cost
192 session.CompletionTokens += usage.OutputTokens
193 session.PromptTokens += usage.InputTokens
194
195 _, err = a.sessions.Save(ctx, session)
196 if err != nil {
197 return fmt.Errorf("failed to save session: %w", err)
198 }
199 return nil
200}
201
202// processEvent handles different types of events during generation
203func (a *agent) processEvent(
204 ctx context.Context,
205 sessionID string,
206 assistantMsg *message.Message,
207 event provider.ProviderEvent,
208) error {
209 select {
210 case <-ctx.Done():
211 return ctx.Err()
212 default:
213 // Continue processing
214 }
215
216 switch event.Type {
217 case provider.EventThinkingDelta:
218 assistantMsg.AppendReasoningContent(event.Content)
219 return a.messages.Update(ctx, *assistantMsg)
220 case provider.EventContentDelta:
221 assistantMsg.AppendContent(event.Content)
222 return a.messages.Update(ctx, *assistantMsg)
223 case provider.EventError:
224 if errors.Is(event.Error, context.Canceled) {
225 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
226 return context.Canceled
227 }
228 logging.ErrorPersist(event.Error.Error())
229 return event.Error
230 case provider.EventWarning:
231 logging.WarnPersist(event.Info)
232 case provider.EventInfo:
233 logging.InfoPersist(event.Info)
234 case provider.EventComplete:
235 assistantMsg.SetToolCalls(event.Response.ToolCalls)
236 assistantMsg.AddFinish(event.Response.FinishReason)
237 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
238 return fmt.Errorf("failed to update message: %w", err)
239 }
240 return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
241 }
242
243 return nil
244}
245
246// ExecuteTools runs all tool calls sequentially and returns the results
247func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
248 toolResults := make([]message.ToolResult, len(toolCalls))
249
250 // Create a child context that can be canceled
251 ctx, cancel := context.WithCancel(ctx)
252 defer cancel()
253
254 // Check if already canceled before starting any execution
255 if ctx.Err() != nil {
256 // Mark all tools as canceled
257 for i, toolCall := range toolCalls {
258 toolResults[i] = message.ToolResult{
259 ToolCallID: toolCall.ID,
260 Content: "Tool execution canceled by user",
261 IsError: true,
262 }
263 }
264 return toolResults, ctx.Err()
265 }
266
267 for i, toolCall := range toolCalls {
268 // Check for cancellation before executing each tool
269 select {
270 case <-ctx.Done():
271 // Mark this and all remaining tools as canceled
272 for j := i; j < len(toolCalls); j++ {
273 toolResults[j] = message.ToolResult{
274 ToolCallID: toolCalls[j].ID,
275 Content: "Tool execution canceled by user",
276 IsError: true,
277 }
278 }
279 return toolResults, ctx.Err()
280 default:
281 // Continue processing
282 }
283
284 response := ""
285 isError := false
286 found := false
287
288 // Find and execute the appropriate tool
289 for _, tool := range tls {
290 if tool.Info().Name == toolCall.Name {
291 found = true
292 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
293 ID: toolCall.ID,
294 Name: toolCall.Name,
295 Input: toolCall.Input,
296 })
297
298 if toolErr != nil {
299 if errors.Is(toolErr, context.Canceled) {
300 response = "Tool execution canceled by user"
301 } else {
302 response = fmt.Sprintf("Error running tool: %s", toolErr)
303 }
304 isError = true
305 } else {
306 response = toolResult.Content
307 isError = toolResult.IsError
308 }
309 break
310 }
311 }
312
313 if !found {
314 response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
315 isError = true
316 }
317
318 toolResults[i] = message.ToolResult{
319 ToolCallID: toolCall.ID,
320 Content: response,
321 IsError: isError,
322 }
323 }
324
325 return toolResults, nil
326}
327
328// handleToolExecution processes tool calls and creates tool result messages
329func (a *agent) handleToolExecution(
330 ctx context.Context,
331 assistantMsg message.Message,
332) (*message.Message, error) {
333 select {
334 case <-ctx.Done():
335 // If cancelled, create tool results that indicate cancellation
336 if len(assistantMsg.ToolCalls()) > 0 {
337 toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
338 for _, tc := range assistantMsg.ToolCalls() {
339 toolResults = append(toolResults, message.ToolResult{
340 ToolCallID: tc.ID,
341 Content: "Tool execution canceled by user",
342 IsError: true,
343 })
344 }
345
346 // Use background context to ensure the message is created even if original context is cancelled
347 bgCtx := context.Background()
348 parts := make([]message.ContentPart, 0)
349 for _, toolResult := range toolResults {
350 parts = append(parts, toolResult)
351 }
352 msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
353 Role: message.Tool,
354 Parts: parts,
355 })
356 if err != nil {
357 return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
358 }
359 return &msg, ctx.Err()
360 }
361 return nil, ctx.Err()
362 default:
363 // Continue processing
364 }
365
366 if len(assistantMsg.ToolCalls()) == 0 {
367 return nil, nil
368 }
369
370 toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
371 if err != nil {
372 // If error is from cancellation, still return the partial results we have
373 if errors.Is(err, context.Canceled) {
374 // Use background context to ensure the message is created even if original context is cancelled
375 bgCtx := context.Background()
376 parts := make([]message.ContentPart, 0)
377 for _, toolResult := range toolResults {
378 parts = append(parts, toolResult)
379 }
380
381 msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
382 Role: message.Tool,
383 Parts: parts,
384 })
385 if createErr != nil {
386 logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
387 return nil, err
388 }
389 return &msg, err
390 }
391 return nil, err
392 }
393
394 parts := make([]message.ContentPart, 0, len(toolResults))
395 for _, toolResult := range toolResults {
396 parts = append(parts, toolResult)
397 }
398
399 msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
400 Role: message.Tool,
401 Parts: parts,
402 })
403 if err != nil {
404 return nil, fmt.Errorf("failed to create tool message: %w", err)
405 }
406
407 return &msg, nil
408}
409
410// generate handles the main generation workflow
411func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
412 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
413
414 // Handle context cancellation at any point
415 if err := ctx.Err(); err != nil {
416 return ErrRequestCancelled
417 }
418
419 messages, err := a.messages.List(ctx, sessionID)
420 if err != nil {
421 return fmt.Errorf("failed to list messages: %w", err)
422 }
423
424 if len(messages) == 0 {
425 titleCtx := context.Background()
426 go a.handleTitleGeneration(titleCtx, sessionID, content)
427 }
428
429 userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
430 Role: message.User,
431 Parts: []message.ContentPart{
432 message.TextContent{
433 Text: content,
434 },
435 },
436 })
437 if err != nil {
438 return fmt.Errorf("failed to create user message: %w", err)
439 }
440
441 messages = append(messages, userMsg)
442
443 for {
444 // Check for cancellation before each iteration
445 select {
446 case <-ctx.Done():
447 return ErrRequestCancelled
448 default:
449 // Continue processing
450 }
451
452 eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
453 if err != nil {
454 if errors.Is(err, context.Canceled) {
455 return ErrRequestCancelled
456 }
457 return fmt.Errorf("failed to stream response: %w", err)
458 }
459
460 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
461 Role: message.Assistant,
462 Parts: []message.ContentPart{},
463 Model: a.model.ID,
464 })
465 if err != nil {
466 return fmt.Errorf("failed to create assistant message: %w", err)
467 }
468
469 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470
471 // Process events from the LLM provider
472 for event := range eventChan {
473 if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
474 if errors.Is(err, context.Canceled) {
475 // Mark as canceled but don't create separate message
476 assistantMsg.AddFinish("canceled")
477 _ = a.messages.Update(context.Background(), assistantMsg)
478 return ErrRequestCancelled
479 }
480 assistantMsg.AddFinish("error:" + err.Error())
481 _ = a.messages.Update(ctx, assistantMsg)
482 return fmt.Errorf("event processing error: %w", err)
483 }
484
485 // Check for cancellation during event processing
486 select {
487 case <-ctx.Done():
488 // Mark as canceled
489 assistantMsg.AddFinish("canceled")
490 _ = a.messages.Update(context.Background(), assistantMsg)
491 return ErrRequestCancelled
492 default:
493 }
494 }
495
496 // Check for cancellation before tool execution
497 select {
498 case <-ctx.Done():
499 assistantMsg.AddFinish("canceled_by_user")
500 _ = a.messages.Update(context.Background(), assistantMsg)
501 return ErrRequestCancelled
502 default:
503 }
504
505 // Execute any tool calls
506 toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
507 if err != nil {
508 if errors.Is(err, context.Canceled) {
509 assistantMsg.AddFinish("canceled_by_user")
510 _ = a.messages.Update(context.Background(), assistantMsg)
511 return ErrRequestCancelled
512 }
513 return fmt.Errorf("tool execution error: %w", err)
514 }
515
516 if err := a.messages.Update(ctx, assistantMsg); err != nil {
517 return fmt.Errorf("failed to update assistant message: %w", err)
518 }
519
520 // If no tool calls, we're done
521 if len(assistantMsg.ToolCalls()) == 0 {
522 break
523 }
524
525 // Add messages for next iteration
526 messages = append(messages, assistantMsg)
527 if toolMsg != nil {
528 messages = append(messages, *toolMsg)
529 }
530
531 // Check for cancellation after tool execution
532 select {
533 case <-ctx.Done():
534 return ErrRequestCancelled
535 default:
536 }
537 }
538
539 return nil
540}
541
542// getAgentProviders initializes the LLM providers based on the chosen model
543func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
544 maxTokens := config.Get().Model.CoderMaxTokens
545
546 providerConfig, ok := config.Get().Providers[model.Provider]
547 if !ok || providerConfig.Disabled {
548 return nil, nil, ErrProviderNotEnabled
549 }
550
551 var agentProvider provider.Provider
552 var titleGenerator provider.Provider
553 var err error
554
555 switch model.Provider {
556 case models.ProviderOpenAI:
557 agentProvider, err = provider.NewOpenAIProvider(
558 provider.WithOpenAISystemMessage(
559 prompt.CoderOpenAISystemPrompt(),
560 ),
561 provider.WithOpenAIMaxTokens(maxTokens),
562 provider.WithOpenAIModel(model),
563 provider.WithOpenAIKey(providerConfig.APIKey),
564 )
565 if err != nil {
566 return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
567 }
568
569 titleGenerator, err = provider.NewOpenAIProvider(
570 provider.WithOpenAISystemMessage(
571 prompt.TitlePrompt(),
572 ),
573 provider.WithOpenAIMaxTokens(80),
574 provider.WithOpenAIModel(model),
575 provider.WithOpenAIKey(providerConfig.APIKey),
576 )
577 if err != nil {
578 return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
579 }
580
581 case models.ProviderAnthropic:
582 agentProvider, err = provider.NewAnthropicProvider(
583 provider.WithAnthropicSystemMessage(
584 prompt.CoderAnthropicSystemPrompt(),
585 ),
586 provider.WithAnthropicMaxTokens(maxTokens),
587 provider.WithAnthropicKey(providerConfig.APIKey),
588 provider.WithAnthropicModel(model),
589 )
590 if err != nil {
591 return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
592 }
593
594 titleGenerator, err = provider.NewAnthropicProvider(
595 provider.WithAnthropicSystemMessage(
596 prompt.TitlePrompt(),
597 ),
598 provider.WithAnthropicMaxTokens(80),
599 provider.WithAnthropicKey(providerConfig.APIKey),
600 provider.WithAnthropicModel(model),
601 )
602 if err != nil {
603 return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
604 }
605
606 case models.ProviderGemini:
607 agentProvider, err = provider.NewGeminiProvider(
608 ctx,
609 provider.WithGeminiSystemMessage(
610 prompt.CoderOpenAISystemPrompt(),
611 ),
612 provider.WithGeminiMaxTokens(int32(maxTokens)),
613 provider.WithGeminiKey(providerConfig.APIKey),
614 provider.WithGeminiModel(model),
615 )
616 if err != nil {
617 return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
618 }
619
620 titleGenerator, err = provider.NewGeminiProvider(
621 ctx,
622 provider.WithGeminiSystemMessage(
623 prompt.TitlePrompt(),
624 ),
625 provider.WithGeminiMaxTokens(80),
626 provider.WithGeminiKey(providerConfig.APIKey),
627 provider.WithGeminiModel(model),
628 )
629 if err != nil {
630 return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
631 }
632
633 case models.ProviderGROQ:
634 agentProvider, err = provider.NewOpenAIProvider(
635 provider.WithOpenAISystemMessage(
636 prompt.CoderAnthropicSystemPrompt(),
637 ),
638 provider.WithOpenAIMaxTokens(maxTokens),
639 provider.WithOpenAIModel(model),
640 provider.WithOpenAIKey(providerConfig.APIKey),
641 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
642 )
643 if err != nil {
644 return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
645 }
646
647 titleGenerator, err = provider.NewOpenAIProvider(
648 provider.WithOpenAISystemMessage(
649 prompt.TitlePrompt(),
650 ),
651 provider.WithOpenAIMaxTokens(80),
652 provider.WithOpenAIModel(model),
653 provider.WithOpenAIKey(providerConfig.APIKey),
654 provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
655 )
656 if err != nil {
657 return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
658 }
659
660 case models.ProviderBedrock:
661 agentProvider, err = provider.NewBedrockProvider(
662 provider.WithBedrockSystemMessage(
663 prompt.CoderAnthropicSystemPrompt(),
664 ),
665 provider.WithBedrockMaxTokens(maxTokens),
666 provider.WithBedrockModel(model),
667 )
668 if err != nil {
669 return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
670 }
671
672 titleGenerator, err = provider.NewBedrockProvider(
673 provider.WithBedrockSystemMessage(
674 prompt.TitlePrompt(),
675 ),
676 provider.WithBedrockMaxTokens(80),
677 provider.WithBedrockModel(model),
678 )
679 if err != nil {
680 return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
681 }
682 default:
683 return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
684 }
685
686 return agentProvider, titleGenerator, nil
687}