agent.go

package agent

import (
	"context"
	"errors"
	"fmt"
	"os"
	"runtime/debug"
	"strings"
	"sync"

	"github.com/kujtimiihoxha/termai/internal/config"
	"github.com/kujtimiihoxha/termai/internal/llm/models"
	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
	"github.com/kujtimiihoxha/termai/internal/llm/provider"
	"github.com/kujtimiihoxha/termai/internal/llm/tools"
	"github.com/kujtimiihoxha/termai/internal/logging"
	"github.com/kujtimiihoxha/termai/internal/message"
	"github.com/kujtimiihoxha/termai/internal/session"
)

// Common errors
var (
	ErrProviderNotEnabled = errors.New("provider is not enabled")
	ErrRequestCancelled   = errors.New("request cancelled by user")
	ErrSessionBusy        = errors.New("session is currently processing another request")
)

// Service generates assistant responses for a session and allows an
// in-flight request to be cancelled.
type Service interface {
	Generate(ctx context.Context, sessionID string, content string) error
	Cancel(sessionID string) error
}
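
// Illustrative usage (a sketch only; caller-side names such as sessionSvc,
// messageSvc, selectedModel, and availableTools are hypothetical):
//
//	svc, err := NewAgent(ctx, sessionSvc, messageSvc, selectedModel, availableTools)
//	if err != nil {
//		// handle provider initialization failure
//	}
//	// Generate returns immediately; the response is produced asynchronously.
//	if err := svc.Generate(ctx, sessionID, "explain this repo"); errors.Is(err, ErrSessionBusy) {
//		// the session is still processing a previous request
//	}
//	// A user-initiated stop maps to Cancel.
//	_ = svc.Cancel(sessionID)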

// agent is the default Service implementation.
type agent struct {
	sessions       session.Service
	messages       message.Service
	model          models.Model
	tools          []tools.BaseTool
	agent          provider.Provider // provider that generates the actual responses
	titleGenerator provider.Provider // lightweight provider used to title new sessions
	activeRequests sync.Map          // map[sessionID]context.CancelFunc
}

// NewAgent creates a new agent instance with the given model and tools
func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
	agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize providers: %w", err)
	}

	return &agent{
		model:          model,
		tools:          tools,
		sessions:       sessions,
		messages:       messages,
		agent:          agentProvider,
		titleGenerator: titleGenerator,
		activeRequests: sync.Map{},
	}, nil
}

// Cancel cancels an active request by session ID
func (a *agent) Cancel(sessionID string) error {
	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
			cancel()
			return nil
		}
	}
	return errors.New("no active request found for this session")
}

// Generate starts generating a response to content in the given session. It
// returns immediately; the response is produced asynchronously and can be
// stopped with Cancel. If the session already has an active request,
// ErrSessionBusy is returned.
func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
	// Check if this session already has an active request
	if _, busy := a.activeRequests.Load(sessionID); busy {
		return ErrSessionBusy
	}

	// Create a cancellable context
	genCtx, cancel := context.WithCancel(ctx)

	// Store the cancel function to allow user cancellation
	a.activeRequests.Store(sessionID, cancel)

	// Launch the generation in a goroutine
	go func() {
		defer func() {
			if r := recover(); r != nil {
				logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))

				// Dump the stack trace to a file for post-mortem debugging
				file, err := os.Create("panic.log")
				if err != nil {
					logging.ErrorPersist(fmt.Sprintf("Failed to create panic log: %v", err))
					return
				}

				defer file.Close()

				stackTrace := debug.Stack()
				if _, err := file.Write(stackTrace); err != nil {
					logging.ErrorPersist(fmt.Sprintf("Failed to write panic log: %v", err))
				}
			}
		}()
		defer a.activeRequests.Delete(sessionID)
		defer cancel()

		if err := a.generate(genCtx, sessionID, content); err != nil {
			if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
				// Log the error (cancellations are expected and not logged)
				logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))

				// Surface the error to the user as a system message in the chat
				bgCtx := context.Background()
				errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
				_, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
					Role: message.System,
					Parts: []message.ContentPart{
						message.TextContent{
							Text: errorMsg,
						},
					},
				})
				if createErr != nil {
					logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
				}
			}
		}
	}()

	return nil
}

// IsSessionBusy checks if a session currently has an active request
func (a *agent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Load(sessionID)
	return busy
}

// handleTitleGeneration asynchronously generates a title for new sessions
func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
	response, err := a.titleGenerator.SendMessages(
		ctx,
		[]message.Message{
			{
				Role: message.User,
				Parts: []message.ContentPart{
					message.TextContent{
						Text: content,
					},
				},
			},
		},
		nil,
	)
	if err != nil {
		logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
		return
	}

	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
		return
	}

	if response.Content != "" {
		session.Title = strings.TrimSpace(response.Content)
		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
		if _, err := a.sessions.Save(ctx, session); err != nil {
			logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
		}
	}
}

// TrackUsage updates the session's token usage and cost statistics
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}

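	// Illustrative arithmetic (the prices here are hypothetical, not any model's
	// real pricing): prices are quoted per 1M tokens, so 2,000 input tokens at
	// $3.00/1M cost 3.00/1e6*2000 = $0.006, and 500 output tokens at $15.00/1M
	// cost 15.00/1e6*500 = $0.0075.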
	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
		model.CostPer1MOut/1e6*float64(usage.OutputTokens)

	session.Cost += cost
	session.CompletionTokens += usage.OutputTokens
	session.PromptTokens += usage.InputTokens

	_, err = a.sessions.Save(ctx, session)
	if err != nil {
		return fmt.Errorf("failed to save session: %w", err)
	}
	return nil
}

// processEvent handles different types of events during generation
func (a *agent) processEvent(
	ctx context.Context,
	sessionID string,
	assistantMsg *message.Message,
	event provider.ProviderEvent,
) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		// Continue processing
	}

	switch event.Type {
	case provider.EventThinkingDelta:
		assistantMsg.AppendReasoningContent(event.Content)
		return a.messages.Update(ctx, *assistantMsg)
	case provider.EventContentDelta:
		assistantMsg.AppendContent(event.Content)
		return a.messages.Update(ctx, *assistantMsg)
	case provider.EventError:
		if errors.Is(event.Error, context.Canceled) {
			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
			return context.Canceled
		}
		logging.ErrorPersist(event.Error.Error())
		return event.Error
	case provider.EventWarning:
		logging.WarnPersist(event.Info)
	case provider.EventInfo:
		logging.InfoPersist(event.Info)
	case provider.EventComplete:
		assistantMsg.SetToolCalls(event.Response.ToolCalls)
		assistantMsg.AddFinish(event.Response.FinishReason)
		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
			return fmt.Errorf("failed to update message: %w", err)
		}
		return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
	}

	return nil
}

// ExecuteTools runs all tool calls sequentially and returns the results
func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
	toolResults := make([]message.ToolResult, len(toolCalls))

	// Create a child context that can be canceled
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// Check if already canceled before starting any execution
	if ctx.Err() != nil {
		// Mark all tools as canceled
		for i, toolCall := range toolCalls {
			toolResults[i] = message.ToolResult{
				ToolCallID: toolCall.ID,
				Content:    "Tool execution canceled by user",
				IsError:    true,
			}
		}
		return toolResults, ctx.Err()
	}

	for i, toolCall := range toolCalls {
		// Check for cancellation before executing each tool
		select {
		case <-ctx.Done():
			// Mark this and all remaining tools as canceled
			for j := i; j < len(toolCalls); j++ {
				toolResults[j] = message.ToolResult{
					ToolCallID: toolCalls[j].ID,
					Content:    "Tool execution canceled by user",
					IsError:    true,
				}
			}
			return toolResults, ctx.Err()
		default:
			// Continue processing
		}

		response := ""
		isError := false
		found := false

		// Find and execute the appropriate tool
		for _, tool := range tls {
			if tool.Info().Name == toolCall.Name {
				found = true
				toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
					ID:    toolCall.ID,
					Name:  toolCall.Name,
					Input: toolCall.Input,
				})

				if toolErr != nil {
					if errors.Is(toolErr, context.Canceled) {
						response = "Tool execution canceled by user"
					} else {
						response = fmt.Sprintf("Error running tool: %s", toolErr)
					}
					isError = true
				} else {
					response = toolResult.Content
					isError = toolResult.IsError
				}
				break
			}
		}

		if !found {
			response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
			isError = true
		}

		toolResults[i] = message.ToolResult{
			ToolCallID: toolCall.ID,
			Content:    response,
			IsError:    isError,
		}
	}

	return toolResults, nil
}

// handleToolExecution processes tool calls and creates tool result messages
func (a *agent) handleToolExecution(
	ctx context.Context,
	assistantMsg message.Message,
) (*message.Message, error) {
	select {
	case <-ctx.Done():
		// If cancelled, create tool results that indicate cancellation
		if len(assistantMsg.ToolCalls()) > 0 {
			toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
			for _, tc := range assistantMsg.ToolCalls() {
				toolResults = append(toolResults, message.ToolResult{
					ToolCallID: tc.ID,
					Content:    "Tool execution canceled by user",
					IsError:    true,
				})
			}

			// Use background context to ensure the message is created even if original context is cancelled
			bgCtx := context.Background()
			parts := make([]message.ContentPart, 0)
			for _, toolResult := range toolResults {
				parts = append(parts, toolResult)
			}
			msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
				Role:  message.Tool,
				Parts: parts,
			})
			if err != nil {
				return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
			}
			return &msg, ctx.Err()
		}
		return nil, ctx.Err()
	default:
		// Continue processing
	}

	if len(assistantMsg.ToolCalls()) == 0 {
		return nil, nil
	}

	toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
	if err != nil {
		// If error is from cancellation, still return the partial results we have
		if errors.Is(err, context.Canceled) {
			// Use background context to ensure the message is created even if original context is cancelled
			bgCtx := context.Background()
			parts := make([]message.ContentPart, 0)
			for _, toolResult := range toolResults {
				parts = append(parts, toolResult)
			}

			msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
				Role:  message.Tool,
				Parts: parts,
			})
			if createErr != nil {
				logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
				return nil, err
			}
			return &msg, err
		}
		return nil, err
	}

	parts := make([]message.ContentPart, 0, len(toolResults))
	for _, toolResult := range toolResults {
		parts = append(parts, toolResult)
	}

	msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
		Role:  message.Tool,
		Parts: parts,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create tool message: %w", err)
	}

	return &msg, nil
}

// generate runs the core agent loop: it records the user message, streams the
// assistant response, executes any requested tool calls, and repeats until the
// model finishes a turn without tool calls.
func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
	// Make the session ID available to tools through the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)

	// Bail out early if the request has already been cancelled
	if err := ctx.Err(); err != nil {
		return ErrRequestCancelled
	}

	messages, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to list messages: %w", err)
	}

	// First message in the session: generate a title in the background. A
	// background context is used so the title can still be produced even if
	// this request is cancelled.
	if len(messages) == 0 {
		titleCtx := context.Background()
		go a.handleTitleGeneration(titleCtx, sessionID, content)
	}

	userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role: message.User,
		Parts: []message.ContentPart{
			message.TextContent{
				Text: content,
			},
		},
	})
	if err != nil {
		return fmt.Errorf("failed to create user message: %w", err)
	}

	messages = append(messages, userMsg)

	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return ErrRequestCancelled
		default:
			// Continue processing
		}

		eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				return ErrRequestCancelled
			}
			return fmt.Errorf("failed to stream response: %w", err)
		}

		assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
			Role:  message.Assistant,
			Parts: []message.ContentPart{},
			Model: a.model.ID,
		})
		if err != nil {
			return fmt.Errorf("failed to create assistant message: %w", err)
		}

		// Make the assistant message ID available to tools through the context
		ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)

		// Process events from the LLM provider
		for event := range eventChan {
			if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
				if errors.Is(err, context.Canceled) {
					// Mark as canceled but don't create a separate message
					assistantMsg.AddFinish("canceled")
					_ = a.messages.Update(context.Background(), assistantMsg)
					return ErrRequestCancelled
				}
				assistantMsg.AddFinish("error:" + err.Error())
				_ = a.messages.Update(ctx, assistantMsg)
				return fmt.Errorf("event processing error: %w", err)
			}

			// Check for cancellation during event processing
			select {
			case <-ctx.Done():
				// Mark as canceled
				assistantMsg.AddFinish("canceled")
				_ = a.messages.Update(context.Background(), assistantMsg)
				return ErrRequestCancelled
			default:
			}
		}

		// Check for cancellation before tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled_by_user")
			_ = a.messages.Update(context.Background(), assistantMsg)
			return ErrRequestCancelled
		default:
		}

		// Execute any tool calls
		toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				assistantMsg.AddFinish("canceled_by_user")
				_ = a.messages.Update(context.Background(), assistantMsg)
				return ErrRequestCancelled
			}
			return fmt.Errorf("tool execution error: %w", err)
		}

		if err := a.messages.Update(ctx, assistantMsg); err != nil {
			return fmt.Errorf("failed to update assistant message: %w", err)
		}

		// If no tool calls, we're done
		if len(assistantMsg.ToolCalls()) == 0 {
			break
		}

		// Feed the assistant and tool messages back in for the next iteration
		messages = append(messages, assistantMsg)
		if toolMsg != nil {
			messages = append(messages, *toolMsg)
		}

		// Check for cancellation after tool execution
		select {
		case <-ctx.Done():
			return ErrRequestCancelled
		default:
		}
	}

	return nil
}

// getAgentProviders initializes the main agent provider and the title-generator
// provider for the chosen model
func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
	maxTokens := config.Get().Model.CoderMaxTokens

	providerConfig, ok := config.Get().Providers[model.Provider]
	if !ok || providerConfig.Disabled {
		return nil, nil, ErrProviderNotEnabled
	}

	var agentProvider provider.Provider
	var titleGenerator provider.Provider
	var err error

	switch model.Provider {
	case models.ProviderOpenAI:
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
		}

		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
		}

	case models.ProviderAnthropic:
		agentProvider, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithAnthropicMaxTokens(maxTokens),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
		}

		titleGenerator, err = provider.NewAnthropicProvider(
			provider.WithAnthropicSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithAnthropicMaxTokens(80),
			provider.WithAnthropicKey(providerConfig.APIKey),
			provider.WithAnthropicModel(model),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
		}

	case models.ProviderGemini:
		agentProvider, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.CoderOpenAISystemPrompt(),
			),
			provider.WithGeminiMaxTokens(int32(maxTokens)),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
		}

		titleGenerator, err = provider.NewGeminiProvider(
			ctx,
			provider.WithGeminiSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithGeminiMaxTokens(80),
			provider.WithGeminiKey(providerConfig.APIKey),
			provider.WithGeminiModel(model),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
		}

	case models.ProviderGROQ:
		agentProvider, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithOpenAIMaxTokens(maxTokens),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
		}

		titleGenerator, err = provider.NewOpenAIProvider(
			provider.WithOpenAISystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithOpenAIMaxTokens(80),
			provider.WithOpenAIModel(model),
			provider.WithOpenAIKey(providerConfig.APIKey),
			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
		}

	case models.ProviderBedrock:
		agentProvider, err = provider.NewBedrockProvider(
			provider.WithBedrockSystemMessage(
				prompt.CoderAnthropicSystemPrompt(),
			),
			provider.WithBedrockMaxTokens(maxTokens),
			provider.WithBedrockModel(model),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
		}

		titleGenerator, err = provider.NewBedrockProvider(
			provider.WithBedrockSystemMessage(
				prompt.TitlePrompt(),
			),
			provider.WithBedrockMaxTokens(80),
			provider.WithBedrockModel(model),
		)
		if err != nil {
			return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
		}
	default:
		return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
	}

	return agentProvider, titleGenerator, nil
}