agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"strings"
  8	"sync"
  9
 10	"github.com/kujtimiihoxha/termai/internal/config"
 11	"github.com/kujtimiihoxha/termai/internal/llm/models"
 12	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 13	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 14	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 15	"github.com/kujtimiihoxha/termai/internal/logging"
 16	"github.com/kujtimiihoxha/termai/internal/message"
 17	"github.com/kujtimiihoxha/termai/internal/session"
 18)
 19
 20// Common errors
 21var (
 22	ErrProviderNotEnabled = errors.New("provider is not enabled")
 23	ErrRequestCancelled   = errors.New("request cancelled by user")
 24	ErrSessionBusy        = errors.New("session is currently processing another request")
 25)
 26
 27// Service defines the interface for generating responses
 28type Service interface {
 29	Generate(ctx context.Context, sessionID string, content string) error
 30	Cancel(sessionID string) error
 31}
 32
 33type agent struct {
 34	sessions       session.Service
 35	messages       message.Service
 36	model          models.Model
 37	tools          []tools.BaseTool
 38	agent          provider.Provider
 39	titleGenerator provider.Provider
 40	activeRequests sync.Map // map[sessionID]context.CancelFunc
 41}
 42
 43// NewAgent creates a new agent instance with the given model and tools
 44func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
 45	agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
 46	if err != nil {
 47		return nil, fmt.Errorf("failed to initialize providers: %w", err)
 48	}
 49
 50	return &agent{
 51		model:          model,
 52		tools:          tools,
 53		sessions:       sessions,
 54		messages:       messages,
 55		agent:          agentProvider,
 56		titleGenerator: titleGenerator,
 57		activeRequests: sync.Map{},
 58	}, nil
 59}
 60
 61// Cancel cancels an active request by session ID
 62func (a *agent) Cancel(sessionID string) error {
 63	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
 64		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
 65			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
 66			cancel()
 67			return nil
 68		}
 69	}
 70	return errors.New("no active request found for this session")
 71}
 72
 73// Generate starts the generation process
 74func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
 75	// Check if this session already has an active request
 76	if _, busy := a.activeRequests.Load(sessionID); busy {
 77		return ErrSessionBusy
 78	}
 79
 80	// Create a cancellable context
 81	genCtx, cancel := context.WithCancel(ctx)
 82
 83	// Store cancel function to allow user cancellation
 84	a.activeRequests.Store(sessionID, cancel)
 85
 86	// Launch the generation in a goroutine
 87	go func() {
 88		defer func() {
 89			if r := recover(); r != nil {
 90				logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
 91			}
 92		}()
 93		defer a.activeRequests.Delete(sessionID)
 94		defer cancel()
 95
 96		if err := a.generate(genCtx, sessionID, content); err != nil {
 97			if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
 98				// Log the error (avoid logging cancellations as they're expected)
 99				logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
100
101				// You may want to create an error message in the chat
102				bgCtx := context.Background()
103				errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
104				_, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
105					Role: message.System,
106					Parts: []message.ContentPart{
107						message.TextContent{
108							Text: errorMsg,
109						},
110					},
111				})
112				if createErr != nil {
113					logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
114				}
115			}
116		}
117	}()
118
119	return nil
120}
121
122// IsSessionBusy checks if a session currently has an active request
123func (a *agent) IsSessionBusy(sessionID string) bool {
124	_, busy := a.activeRequests.Load(sessionID)
125	return busy
126} // handleTitleGeneration asynchronously generates a title for new sessions
127func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
128	response, err := a.titleGenerator.SendMessages(
129		ctx,
130		[]message.Message{
131			{
132				Role: message.User,
133				Parts: []message.ContentPart{
134					message.TextContent{
135						Text: content,
136					},
137				},
138			},
139		},
140		nil,
141	)
142	if err != nil {
143		logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
144		return
145	}
146
147	session, err := a.sessions.Get(ctx, sessionID)
148	if err != nil {
149		logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
150		return
151	}
152
153	if response.Content != "" {
154		session.Title = strings.TrimSpace(response.Content)
155		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
156		if _, err := a.sessions.Save(ctx, session); err != nil {
157			logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
158		}
159	}
160}
161
162// TrackUsage updates token usage statistics for the session
163func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
164	session, err := a.sessions.Get(ctx, sessionID)
165	if err != nil {
166		return fmt.Errorf("failed to get session: %w", err)
167	}
168
169	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
170		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
171		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
172		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
173
174	session.Cost += cost
175	session.CompletionTokens += usage.OutputTokens
176	session.PromptTokens += usage.InputTokens
177
178	_, err = a.sessions.Save(ctx, session)
179	if err != nil {
180		return fmt.Errorf("failed to save session: %w", err)
181	}
182	return nil
183}
184
185// processEvent handles different types of events during generation
186func (a *agent) processEvent(
187	ctx context.Context,
188	sessionID string,
189	assistantMsg *message.Message,
190	event provider.ProviderEvent,
191) error {
192	select {
193	case <-ctx.Done():
194		return ctx.Err()
195	default:
196		// Continue processing
197	}
198
199	switch event.Type {
200	case provider.EventThinkingDelta:
201		assistantMsg.AppendReasoningContent(event.Content)
202		return a.messages.Update(ctx, *assistantMsg)
203	case provider.EventContentDelta:
204		assistantMsg.AppendContent(event.Content)
205		return a.messages.Update(ctx, *assistantMsg)
206	case provider.EventError:
207		if errors.Is(event.Error, context.Canceled) {
208			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
209			return context.Canceled
210		}
211		logging.ErrorPersist(event.Error.Error())
212		return event.Error
213	case provider.EventWarning:
214		logging.WarnPersist(event.Info)
215	case provider.EventInfo:
216		logging.InfoPersist(event.Info)
217	case provider.EventComplete:
218		assistantMsg.SetToolCalls(event.Response.ToolCalls)
219		assistantMsg.AddFinish(event.Response.FinishReason)
220		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
221			return fmt.Errorf("failed to update message: %w", err)
222		}
223		return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
224	}
225
226	return nil
227}
228
229// ExecuteTools runs all tool calls sequentially and returns the results
230func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
231	toolResults := make([]message.ToolResult, len(toolCalls))
232
233	// Create a child context that can be canceled
234	ctx, cancel := context.WithCancel(ctx)
235	defer cancel()
236
237	// Check if already canceled before starting any execution
238	if ctx.Err() != nil {
239		// Mark all tools as canceled
240		for i, toolCall := range toolCalls {
241			toolResults[i] = message.ToolResult{
242				ToolCallID: toolCall.ID,
243				Content:    "Tool execution canceled by user",
244				IsError:    true,
245			}
246		}
247		return toolResults, ctx.Err()
248	}
249
250	for i, toolCall := range toolCalls {
251		// Check for cancellation before executing each tool
252		select {
253		case <-ctx.Done():
254			// Mark this and all remaining tools as canceled
255			for j := i; j < len(toolCalls); j++ {
256				toolResults[j] = message.ToolResult{
257					ToolCallID: toolCalls[j].ID,
258					Content:    "Tool execution canceled by user",
259					IsError:    true,
260				}
261			}
262			return toolResults, ctx.Err()
263		default:
264			// Continue processing
265		}
266
267		response := ""
268		isError := false
269		found := false
270
271		// Find and execute the appropriate tool
272		for _, tool := range tls {
273			if tool.Info().Name == toolCall.Name {
274				found = true
275				toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
276					ID:    toolCall.ID,
277					Name:  toolCall.Name,
278					Input: toolCall.Input,
279				})
280
281				if toolErr != nil {
282					if errors.Is(toolErr, context.Canceled) {
283						response = "Tool execution canceled by user"
284					} else {
285						response = fmt.Sprintf("Error running tool: %s", toolErr)
286					}
287					isError = true
288				} else {
289					response = toolResult.Content
290					isError = toolResult.IsError
291				}
292				break
293			}
294		}
295
296		if !found {
297			response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
298			isError = true
299		}
300
301		toolResults[i] = message.ToolResult{
302			ToolCallID: toolCall.ID,
303			Content:    response,
304			IsError:    isError,
305		}
306	}
307
308	return toolResults, nil
309}
310
311// handleToolExecution processes tool calls and creates tool result messages
312func (a *agent) handleToolExecution(
313	ctx context.Context,
314	assistantMsg message.Message,
315) (*message.Message, error) {
316	select {
317	case <-ctx.Done():
318		// If cancelled, create tool results that indicate cancellation
319		if len(assistantMsg.ToolCalls()) > 0 {
320			toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
321			for _, tc := range assistantMsg.ToolCalls() {
322				toolResults = append(toolResults, message.ToolResult{
323					ToolCallID: tc.ID,
324					Content:    "Tool execution canceled by user",
325					IsError:    true,
326				})
327			}
328
329			// Use background context to ensure the message is created even if original context is cancelled
330			bgCtx := context.Background()
331			parts := make([]message.ContentPart, 0)
332			for _, toolResult := range toolResults {
333				parts = append(parts, toolResult)
334			}
335			msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
336				Role:  message.Tool,
337				Parts: parts,
338			})
339			if err != nil {
340				return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
341			}
342			return &msg, ctx.Err()
343		}
344		return nil, ctx.Err()
345	default:
346		// Continue processing
347	}
348
349	if len(assistantMsg.ToolCalls()) == 0 {
350		return nil, nil
351	}
352
353	toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
354	if err != nil {
355		// If error is from cancellation, still return the partial results we have
356		if errors.Is(err, context.Canceled) {
357			// Use background context to ensure the message is created even if original context is cancelled
358			bgCtx := context.Background()
359			parts := make([]message.ContentPart, 0)
360			for _, toolResult := range toolResults {
361				parts = append(parts, toolResult)
362			}
363
364			msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
365				Role:  message.Tool,
366				Parts: parts,
367			})
368			if createErr != nil {
369				logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
370				return nil, err
371			}
372			return &msg, err
373		}
374		return nil, err
375	}
376
377	parts := make([]message.ContentPart, 0, len(toolResults))
378	for _, toolResult := range toolResults {
379		parts = append(parts, toolResult)
380	}
381
382	msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
383		Role:  message.Tool,
384		Parts: parts,
385	})
386	if err != nil {
387		return nil, fmt.Errorf("failed to create tool message: %w", err)
388	}
389
390	return &msg, nil
391}
392
// generate handles the main generation workflow
//
// It records the user's content as a message in sessionID, then repeatedly
// streams a provider response over the conversation so far. Each round
// persists the assistant message, runs any tool calls it contains, and feeds
// the assistant and tool messages back into the next round. The loop ends
// when a response contains no tool calls.
//
// Cancellation of ctx at any checkpoint returns ErrRequestCancelled. Where an
// assistant message already exists, it is first marked with a "canceled" or
// "canceled_by_user" finish and saved under context.Background(). Other
// failures are returned wrapped.
func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
	// Make the session ID available to tools via the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)

	// Handle context cancellation at any point
	if err := ctx.Err(); err != nil {
		return ErrRequestCancelled
	}

	messages, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to list messages: %w", err)
	}

	// First message of the session: generate a title concurrently. It uses
	// a background context, so it is not stopped by cancelling this request.
	if len(messages) == 0 {
		titleCtx := context.Background()
		go a.handleTitleGeneration(titleCtx, sessionID, content)
	}

	userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role: message.User,
		Parts: []message.ContentPart{
			message.TextContent{
				Text: content,
			},
		},
	})
	if err != nil {
		return fmt.Errorf("failed to create user message: %w", err)
	}

	messages = append(messages, userMsg)

	// One iteration per provider round-trip; tool results trigger another round.
	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return ErrRequestCancelled
		default:
			// Continue processing
		}

		eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				return ErrRequestCancelled
			}
			return fmt.Errorf("failed to stream response: %w", err)
		}

		// Empty assistant message that processEvent fills in as events arrive.
		assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
			Role:  message.Assistant,
			Parts: []message.ContentPart{},
			Model: a.model.ID,
		})
		if err != nil {
			return fmt.Errorf("failed to create assistant message: %w", err)
		}

		// Make the current assistant message ID available to tools.
		// NOTE(review): ctx is re-wrapped on every round, so the value chain
		// grows with each tool round-trip. Lookups still return the newest ID.
		ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)

		// Process events from the LLM provider
		// NOTE(review): on an early return below, eventChan is no longer
		// drained. Whether the provider's sender exits depends on it honoring
		// ctx, which Generate cancels on return. Verify in the provider code.
		for event := range eventChan {
			if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
				if errors.Is(err, context.Canceled) {
					// Mark as canceled but don't create separate message
					assistantMsg.AddFinish("canceled")
					_ = a.messages.Update(context.Background(), assistantMsg)
					return ErrRequestCancelled
				}
				// Best-effort save of the error finish; the processing error
				// is what gets returned.
				assistantMsg.AddFinish("error:" + err.Error())
				_ = a.messages.Update(ctx, assistantMsg)
				return fmt.Errorf("event processing error: %w", err)
			}

			// Check for cancellation during event processing
			select {
			case <-ctx.Done():
				// Mark as canceled
				assistantMsg.AddFinish("canceled")
				_ = a.messages.Update(context.Background(), assistantMsg)
				return ErrRequestCancelled
			default:
			}
		}

		// Check for cancellation before tool execution
		select {
		case <-ctx.Done():
			assistantMsg.AddFinish("canceled_by_user")
			_ = a.messages.Update(context.Background(), assistantMsg)
			return ErrRequestCancelled
		default:
		}

		// Execute any tool calls
		toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				assistantMsg.AddFinish("canceled_by_user")
				_ = a.messages.Update(context.Background(), assistantMsg)
				return ErrRequestCancelled
			}
			return fmt.Errorf("tool execution error: %w", err)
		}

		if err := a.messages.Update(ctx, assistantMsg); err != nil {
			return fmt.Errorf("failed to update assistant message: %w", err)
		}

		// If no tool calls, we're done
		if len(assistantMsg.ToolCalls()) == 0 {
			break
		}

		// Add messages for next iteration
		messages = append(messages, assistantMsg)
		if toolMsg != nil {
			messages = append(messages, *toolMsg)
		}

		// Check for cancellation after tool execution
		select {
		case <-ctx.Done():
			return ErrRequestCancelled
		default:
		}
	}

	return nil
}
524
525// getAgentProviders initializes the LLM providers based on the chosen model
526func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
527	maxTokens := config.Get().Model.CoderMaxTokens
528
529	providerConfig, ok := config.Get().Providers[model.Provider]
530	if !ok || providerConfig.Disabled {
531		return nil, nil, ErrProviderNotEnabled
532	}
533
534	var agentProvider provider.Provider
535	var titleGenerator provider.Provider
536	var err error
537
538	switch model.Provider {
539	case models.ProviderOpenAI:
540		agentProvider, err = provider.NewOpenAIProvider(
541			provider.WithOpenAISystemMessage(
542				prompt.CoderOpenAISystemPrompt(),
543			),
544			provider.WithOpenAIMaxTokens(maxTokens),
545			provider.WithOpenAIModel(model),
546			provider.WithOpenAIKey(providerConfig.APIKey),
547		)
548		if err != nil {
549			return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
550		}
551
552		titleGenerator, err = provider.NewOpenAIProvider(
553			provider.WithOpenAISystemMessage(
554				prompt.TitlePrompt(),
555			),
556			provider.WithOpenAIMaxTokens(80),
557			provider.WithOpenAIModel(model),
558			provider.WithOpenAIKey(providerConfig.APIKey),
559		)
560		if err != nil {
561			return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
562		}
563
564	case models.ProviderAnthropic:
565		agentProvider, err = provider.NewAnthropicProvider(
566			provider.WithAnthropicSystemMessage(
567				prompt.CoderAnthropicSystemPrompt(),
568			),
569			provider.WithAnthropicMaxTokens(maxTokens),
570			provider.WithAnthropicKey(providerConfig.APIKey),
571			provider.WithAnthropicModel(model),
572		)
573		if err != nil {
574			return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
575		}
576
577		titleGenerator, err = provider.NewAnthropicProvider(
578			provider.WithAnthropicSystemMessage(
579				prompt.TitlePrompt(),
580			),
581			provider.WithAnthropicMaxTokens(80),
582			provider.WithAnthropicKey(providerConfig.APIKey),
583			provider.WithAnthropicModel(model),
584		)
585		if err != nil {
586			return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
587		}
588
589	case models.ProviderGemini:
590		agentProvider, err = provider.NewGeminiProvider(
591			ctx,
592			provider.WithGeminiSystemMessage(
593				prompt.CoderOpenAISystemPrompt(),
594			),
595			provider.WithGeminiMaxTokens(int32(maxTokens)),
596			provider.WithGeminiKey(providerConfig.APIKey),
597			provider.WithGeminiModel(model),
598		)
599		if err != nil {
600			return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
601		}
602
603		titleGenerator, err = provider.NewGeminiProvider(
604			ctx,
605			provider.WithGeminiSystemMessage(
606				prompt.TitlePrompt(),
607			),
608			provider.WithGeminiMaxTokens(80),
609			provider.WithGeminiKey(providerConfig.APIKey),
610			provider.WithGeminiModel(model),
611		)
612		if err != nil {
613			return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
614		}
615
616	case models.ProviderGROQ:
617		agentProvider, err = provider.NewOpenAIProvider(
618			provider.WithOpenAISystemMessage(
619				prompt.CoderAnthropicSystemPrompt(),
620			),
621			provider.WithOpenAIMaxTokens(maxTokens),
622			provider.WithOpenAIModel(model),
623			provider.WithOpenAIKey(providerConfig.APIKey),
624			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
625		)
626		if err != nil {
627			return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
628		}
629
630		titleGenerator, err = provider.NewOpenAIProvider(
631			provider.WithOpenAISystemMessage(
632				prompt.TitlePrompt(),
633			),
634			provider.WithOpenAIMaxTokens(80),
635			provider.WithOpenAIModel(model),
636			provider.WithOpenAIKey(providerConfig.APIKey),
637			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
638		)
639		if err != nil {
640			return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
641		}
642
643	case models.ProviderBedrock:
644		agentProvider, err = provider.NewBedrockProvider(
645			provider.WithBedrockSystemMessage(
646				prompt.CoderAnthropicSystemPrompt(),
647			),
648			provider.WithBedrockMaxTokens(maxTokens),
649			provider.WithBedrockModel(model),
650		)
651		if err != nil {
652			return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
653		}
654
655		titleGenerator, err = provider.NewBedrockProvider(
656			provider.WithBedrockSystemMessage(
657				prompt.TitlePrompt(),
658			),
659			provider.WithBedrockMaxTokens(80),
660			provider.WithBedrockModel(model),
661		)
662		if err != nil {
663			return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
664		}
665	default:
666		return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
667	}
668
669	return agentProvider, titleGenerator, nil
670}