cleanup diff, cleanup agent

Kujtim Hoxha created

Change summary

README.md                              |  14 
cmd/diff/main.go                       | 102 -----
cmd/root.go                            |  12 
go.mod                                 |   2 
internal/app/app.go                    |  20 
internal/assets/diff/themes/dark.json  |  73 ---
internal/assets/embed.go               |   6 
internal/assets/write.go               |  60 ---
internal/git/diff.go                   |  35 
internal/llm/agent/agent-tool.go       |  34 +
internal/llm/agent/agent.go            | 522 +++++++++++++++++----------
internal/llm/agent/coder.go            |  83 +--
internal/llm/agent/task.go             |   7 
internal/llm/provider/provider.go      |   4 
internal/llm/tools/edit.go             |   7 
internal/llm/tools/tools.go            |   2 
internal/llm/tools/write.go            |   2 
internal/tui/components/repl/editor.go |   8 
internal/tui/page/chat.go              |  15 
19 files changed, 456 insertions(+), 552 deletions(-)

Detailed changes

README.md 🔗

@@ -53,6 +53,7 @@ termai -d
 ### Keyboard Shortcuts
 
 #### Global Shortcuts
+
 - `?`: Toggle help panel
 - `Ctrl+C` or `q`: Quit application
 - `L`: View logs
@@ -60,10 +61,12 @@ termai -d
 - `Esc`: Close current view/dialog or return to normal mode
 
 #### Session Management
+
 - `N`: Create new session
 - `Enter` or `Space`: Select session (in sessions list)
 
 #### Editor Shortcuts (Vim-like)
+
 - `i`: Enter insert mode
 - `Esc`: Enter normal mode
 - `v`: Enter visual mode
@@ -72,6 +75,7 @@ termai -d
 - `Ctrl+S`: Send message (in insert mode)
 
 #### Navigation
+
 - Arrow keys: Navigate through lists and content
 - Page Up/Down: Scroll through content
 
@@ -112,16 +116,6 @@ go build -o termai
 ./termai
 ```
 
-### Important: Building the Diff Script
-
-Before building or running the application, you must first build the diff script by running:
-
-```bash
-go run cmd/diff/main.go
-```
-
-This command generates the necessary JavaScript file (`index.mjs`) used by the diff functionality in the application.
-
 ## Acknowledgments
 
 TermAI builds upon the work of several open source projects and developers:

cmd/diff/main.go 🔗

@@ -1,102 +0,0 @@
-package main
-
-import (
-	"fmt"
-	"io"
-	"os"
-	"os/exec"
-	"path/filepath"
-)
-
-func main() {
-	// Create a temporary directory
-	tempDir, err := os.MkdirTemp("", "git-split-diffs")
-	if err != nil {
-		fmt.Printf("Error creating temp directory: %v\n", err)
-		os.Exit(1)
-	}
-	defer func() {
-		fmt.Printf("Cleaning up temporary directory: %s\n", tempDir)
-		os.RemoveAll(tempDir)
-	}()
-	fmt.Printf("Created temporary directory: %s\n", tempDir)
-
-	// Clone the repository with minimum depth
-	fmt.Println("Cloning git-split-diffs repository with minimum depth...")
-	cmd := exec.Command("git", "clone", "--depth=1", "https://github.com/kujtimiihoxha/git-split-diffs", tempDir)
-	cmd.Stdout = os.Stdout
-	cmd.Stderr = os.Stderr
-	if err := cmd.Run(); err != nil {
-		fmt.Printf("Error cloning repository: %v\n", err)
-		os.Exit(1)
-	}
-
-	// Run npm install
-	fmt.Println("Running npm install...")
-	cmdNpmInstall := exec.Command("npm", "install")
-	cmdNpmInstall.Dir = tempDir
-	cmdNpmInstall.Stdout = os.Stdout
-	cmdNpmInstall.Stderr = os.Stderr
-	if err := cmdNpmInstall.Run(); err != nil {
-		fmt.Printf("Error running npm install: %v\n", err)
-		os.Exit(1)
-	}
-
-	// Run npm run build
-	fmt.Println("Running npm run build...")
-	cmdNpmBuild := exec.Command("npm", "run", "build")
-	cmdNpmBuild.Dir = tempDir
-	cmdNpmBuild.Stdout = os.Stdout
-	cmdNpmBuild.Stderr = os.Stderr
-	if err := cmdNpmBuild.Run(); err != nil {
-		fmt.Printf("Error running npm run build: %v\n", err)
-		os.Exit(1)
-	}
-
-	destDir := filepath.Join(".", "internal", "assets", "diff")
-	destFile := filepath.Join(destDir, "index.mjs")
-
-	// Make sure the destination directory exists
-	if err := os.MkdirAll(destDir, 0o755); err != nil {
-		fmt.Printf("Error creating destination directory: %v\n", err)
-		os.Exit(1)
-	}
-
-	// Copy the file
-	srcFile := filepath.Join(tempDir, "build", "index.mjs")
-	fmt.Printf("Copying %s to %s\n", srcFile, destFile)
-	if err := copyFile(srcFile, destFile); err != nil {
-		fmt.Printf("Error copying file: %v\n", err)
-		os.Exit(1)
-	}
-
-	fmt.Println("Successfully completed the process!")
-}
-
-// copyFile copies a file from src to dst
-func copyFile(src, dst string) error {
-	sourceFile, err := os.Open(src)
-	if err != nil {
-		return err
-	}
-	defer sourceFile.Close()
-
-	destFile, err := os.Create(dst)
-	if err != nil {
-		return err
-	}
-	defer destFile.Close()
-
-	_, err = io.Copy(destFile, sourceFile)
-	if err != nil {
-		return err
-	}
-
-	// Make sure the file is written to disk
-	err = destFile.Sync()
-	if err != nil {
-		return err
-	}
-
-	return nil
-}

cmd/root.go 🔗

@@ -9,7 +9,6 @@ import (
 
 	tea "github.com/charmbracelet/bubbletea"
 	"github.com/kujtimiihoxha/termai/internal/app"
-	"github.com/kujtimiihoxha/termai/internal/assets"
 	"github.com/kujtimiihoxha/termai/internal/config"
 	"github.com/kujtimiihoxha/termai/internal/db"
 	"github.com/kujtimiihoxha/termai/internal/llm/agent"
@@ -52,11 +51,6 @@ var rootCmd = &cobra.Command{
 			return err
 		}
 
-		err = assets.WriteAssets()
-		if err != nil {
-			logging.Error("Error writing assets: %v", err)
-		}
-
 		// Connect DB, this will also run migrations
 		conn, err := db.Connect()
 		if err != nil {
@@ -67,7 +61,11 @@ var rootCmd = &cobra.Command{
 		ctx, cancel := context.WithCancel(context.Background())
 		defer cancel()
 
-		app := app.New(ctx, conn)
+		app, err := app.New(ctx, conn)
+		if err != nil {
+			logging.Error("Failed to create app: %v", err)
+			return err
+		}
 
 		// Set up the TUI
 		zone.NewGlobal()

go.mod 🔗

@@ -31,7 +31,6 @@ require (
 	github.com/muesli/reflow v0.3.0
 	github.com/muesli/termenv v0.16.0
 	github.com/openai/openai-go v0.1.0-beta.2
-	github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3
 	github.com/spf13/cobra v1.9.1
 	github.com/spf13/viper v1.20.0
 	github.com/stretchr/testify v1.10.0
@@ -107,6 +106,7 @@ require (
 	github.com/rivo/uniseg v0.4.7 // indirect
 	github.com/sagikazarmark/locafero v0.7.0 // indirect
 	github.com/sahilm/fuzzy v0.1.1 // indirect
+	github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect
 	github.com/skeema/knownhosts v1.3.1 // indirect
 	github.com/sourcegraph/conc v0.3.0 // indirect
 	github.com/spf13/afero v1.12.0 // indirect

internal/app/app.go 🔗

@@ -9,6 +9,7 @@ import (
 
 	"github.com/kujtimiihoxha/termai/internal/db"
 	"github.com/kujtimiihoxha/termai/internal/history"
+	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 	"github.com/kujtimiihoxha/termai/internal/logging"
 	"github.com/kujtimiihoxha/termai/internal/lsp"
 	"github.com/kujtimiihoxha/termai/internal/message"
@@ -22,6 +23,8 @@ type App struct {
 	Files       history.Service
 	Permissions permission.Service
 
+	CoderAgent agent.Service
+
 	LSPClients map[string]*lsp.Client
 
 	clientsMutex sync.RWMutex
@@ -31,7 +34,7 @@ type App struct {
 	watcherWG          sync.WaitGroup
 }
 
-func New(ctx context.Context, conn *sql.DB) *App {
+func New(ctx context.Context, conn *sql.DB) (*App, error) {
 	q := db.New(conn)
 	sessions := session.NewService(q)
 	messages := message.NewService(q)
@@ -45,9 +48,22 @@ func New(ctx context.Context, conn *sql.DB) *App {
 		LSPClients:  make(map[string]*lsp.Client),
 	}
 
+	var err error
+	app.CoderAgent, err = agent.NewCoderAgent(
+
+		app.Permissions,
+		app.Sessions,
+		app.Messages,
+		app.LSPClients,
+	)
+	if err != nil {
+		logging.Error("Failed to create coder agent", err)
+		return nil, err
+	}
+
 	app.initLSPClients(ctx)
 
-	return app
+	return app, nil
 }
 
 // Shutdown performs a clean shutdown of the application

internal/assets/diff/themes/dark.json 🔗

@@ -1,73 +0,0 @@
-{
-  "SYNTAX_HIGHLIGHTING_THEME": "dark-plus",
-  "DEFAULT_COLOR": {
-    "color": "#ffffff",
-    "backgroundColor": "#212121"
-  },
-  "COMMIT_HEADER_COLOR": {
-    "color": "#cccccc"
-  },
-  "COMMIT_HEADER_LABEL_COLOR": {
-    "color": "#00000022"
-  },
-  "COMMIT_SHA_COLOR": {
-    "color": "#00eeaa"
-  },
-  "COMMIT_AUTHOR_COLOR": {
-    "color": "#00aaee"
-  },
-  "COMMIT_DATE_COLOR": {
-    "color": "#cccccc"
-  },
-  "COMMIT_MESSAGE_COLOR": {
-    "color": "#cccccc"
-  },
-  "COMMIT_TITLE_COLOR": {
-    "modifiers": [
-      "bold"
-    ]
-  },
-  "FILE_NAME_COLOR": {
-    "color": "#ffdd99"
-  },
-  "BORDER_COLOR": {
-    "color": "#ffdd9966",
-    "modifiers": [
-      "dim"
-    ]
-  },
-  "HUNK_HEADER_COLOR": {
-    "modifiers": [
-      "dim"
-    ]
-  },
-  "DELETED_WORD_COLOR": {
-    "color": "#ffcccc",
-    "backgroundColor": "#ff000033"
-  },
-  "INSERTED_WORD_COLOR": {
-    "color": "#ccffcc",
-    "backgroundColor": "#00ff0033"
-  },
-  "DELETED_LINE_NO_COLOR": {
-    "color": "#00000022",
-    "backgroundColor": "#00000022"
-  },
-  "INSERTED_LINE_NO_COLOR": {
-    "color": "#00000022",
-    "backgroundColor": "#00000022"
-  },
-  "UNMODIFIED_LINE_NO_COLOR": {
-    "color": "#666666"
-  },
-  "DELETED_LINE_COLOR": {
-    "color": "#cc6666",
-    "backgroundColor": "#3a3030"
-  },
-  "INSERTED_LINE_COLOR": {
-    "color": "#66cc66",
-    "backgroundColor": "#303a30"
-  },
-  "UNMODIFIED_LINE_COLOR": {},
-  "MISSING_LINE_COLOR": {}
-}

internal/assets/write.go 🔗

@@ -1,60 +0,0 @@
-package assets
-
-import (
-	"os"
-	"path/filepath"
-
-	"github.com/kujtimiihoxha/termai/internal/config"
-)
-
-func WriteAssets() error {
-	appCfg := config.Get()
-	appWd := config.WorkingDirectory()
-	scriptDir := filepath.Join(
-		appWd,
-		appCfg.Data.Directory,
-		"diff",
-	)
-	scriptPath := filepath.Join(scriptDir, "index.mjs")
-	// Before, run the script in cmd/diff/main.go to build this file
-	if _, err := os.Stat(scriptPath); err != nil {
-		scriptData, err := FS.ReadFile("diff/index.mjs")
-		if err != nil {
-			return err
-		}
-
-		err = os.MkdirAll(scriptDir, 0o755)
-		if err != nil {
-			return err
-		}
-		err = os.WriteFile(scriptPath, scriptData, 0o755)
-		if err != nil {
-			return err
-		}
-	}
-
-	themeDir := filepath.Join(
-		appWd,
-		appCfg.Data.Directory,
-		"themes",
-	)
-
-	themePath := filepath.Join(themeDir, "dark.json")
-
-	if _, err := os.Stat(themePath); err != nil {
-		themeData, err := FS.ReadFile("diff/themes/dark.json")
-		if err != nil {
-			return err
-		}
-
-		err = os.MkdirAll(themeDir, 0o755)
-		if err != nil {
-			return err
-		}
-		err = os.WriteFile(themePath, themeData, 0o755)
-		if err != nil {
-			return err
-		}
-	}
-	return nil
-}

internal/git/diff.go 🔗

@@ -11,7 +11,6 @@ import (
 
 	"github.com/go-git/go-git/v5"
 	"github.com/go-git/go-git/v5/plumbing/object"
-	"github.com/kujtimiihoxha/termai/internal/config"
 )
 
 type DiffStats struct {
@@ -197,32 +196,32 @@ func isSplitDiffsAvailable() bool {
 }
 
 func formatWithSplitDiffs(diffText string, width int) (string, error) {
-	var cmd *exec.Cmd
+	args := []string{
+		"--color",
+	}
 
-	appCfg := config.Get()
-	appWd := config.WorkingDirectory()
-	script := filepath.Join(
-		appWd,
-		appCfg.Data.Directory,
-		"diff",
-		"index.mjs",
-	)
+	var diffCmd *exec.Cmd
 
-	cmd = exec.Command("node", script, "--color")
+	if _, err := exec.LookPath("git-split-diffs-opencode"); err == nil {
+		fullArgs := append([]string{"git-split-diffs-opencode"}, args...)
+		diffCmd = exec.Command(fullArgs[0], fullArgs[1:]...)
+	} else {
+		npxArgs := append([]string{"git-split-diffs-opencode"}, args...)
+		diffCmd = exec.Command("npx", npxArgs...)
+	}
 
-	cmd.Env = append(os.Environ(), fmt.Sprintf("COLUMNS=%d", width))
+	diffCmd.Env = append(os.Environ(), fmt.Sprintf("DIFF_COLUMNS=%d", width))
 
-	cmd.Stdin = strings.NewReader(diffText)
+	diffCmd.Stdin = strings.NewReader(diffText)
 
 	var out bytes.Buffer
-	cmd.Stdout = &out
+	diffCmd.Stdout = &out
 
 	var stderr bytes.Buffer
-	cmd.Stderr = &stderr
+	diffCmd.Stderr = &stderr
 
-	err := cmd.Run()
-	if err != nil {
-		return "", fmt.Errorf("git-split-diffs error: %v, stderr: %s", err, stderr.String())
+	if err := diffCmd.Run(); err != nil {
+		return "", fmt.Errorf("git-split-diffs-opencode error: %w, stderr: %s", err, stderr.String())
 	}
 
 	return out.String(), nil

internal/llm/agent/agent-tool.go 🔗

@@ -5,14 +5,16 @@ import (
 	"encoding/json"
 	"fmt"
 
-	"github.com/kujtimiihoxha/termai/internal/app"
 	"github.com/kujtimiihoxha/termai/internal/llm/tools"
+	"github.com/kujtimiihoxha/termai/internal/lsp"
 	"github.com/kujtimiihoxha/termai/internal/message"
+	"github.com/kujtimiihoxha/termai/internal/session"
 )
 
 type agentTool struct {
-	parentSessionID string
-	app             *app.App
+	sessions   session.Service
+	messages   message.Service
+	lspClients map[string]*lsp.Client
 }
 
 const (
@@ -46,12 +48,17 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
 		return tools.NewTextErrorResponse("prompt is required"), nil
 	}
 
-	agent, err := NewTaskAgent(b.app)
+	sessionID, messageID := tools.GetContextValues(ctx)
+	if sessionID == "" || messageID == "" {
+		return tools.NewTextErrorResponse("session ID and message ID are required"), nil
+	}
+
+	agent, err := NewTaskAgent(b.lspClients)
 	if err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error creating agent: %s", err)), nil
 	}
 
-	session, err := b.app.Sessions.CreateTaskSession(ctx, call.ID, b.parentSessionID, "New Agent Session")
+	session, err := b.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
 	if err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
 	}
@@ -61,7 +68,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
 		return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
 	}
 
-	messages, err := b.app.Messages.List(ctx, session.ID)
+	messages, err := b.messages.List(ctx, session.ID)
 	if err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error listing messages: %s", err)), nil
 	}
@@ -74,11 +81,11 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
 		return tools.NewTextErrorResponse("no assistant message found"), nil
 	}
 
-	updatedSession, err := b.app.Sessions.Get(ctx, session.ID)
+	updatedSession, err := b.sessions.Get(ctx, session.ID)
 	if err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
 	}
-	parentSession, err := b.app.Sessions.Get(ctx, b.parentSessionID)
+	parentSession, err := b.sessions.Get(ctx, sessionID)
 	if err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
 	}
@@ -87,16 +94,19 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
 	parentSession.PromptTokens += updatedSession.PromptTokens
 	parentSession.CompletionTokens += updatedSession.CompletionTokens
 
-	_, err = b.app.Sessions.Save(ctx, parentSession)
+	_, err = b.sessions.Save(ctx, parentSession)
 	if err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error: %s", err)), nil
 	}
 	return tools.NewTextResponse(response.Content().String()), nil
 }
 
-func NewAgentTool(parentSessionID string, app *app.App) tools.BaseTool {
+func NewAgentTool(
+	Sessions session.Service,
+	Messages message.Service,
+) tools.BaseTool {
 	return &agentTool{
-		parentSessionID: parentSessionID,
-		app:             app,
+		sessions: Sessions,
+		messages: Messages,
 	}
 }

internal/llm/agent/agent.go 🔗

@@ -7,7 +7,6 @@ import (
 	"strings"
 	"sync"
 
-	"github.com/kujtimiihoxha/termai/internal/app"
 	"github.com/kujtimiihoxha/termai/internal/config"
 	"github.com/kujtimiihoxha/termai/internal/llm/models"
 	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
@@ -15,22 +14,118 @@ import (
 	"github.com/kujtimiihoxha/termai/internal/llm/tools"
 	"github.com/kujtimiihoxha/termai/internal/logging"
 	"github.com/kujtimiihoxha/termai/internal/message"
+	"github.com/kujtimiihoxha/termai/internal/session"
 )
 
-type Agent interface {
+// Common errors
+var (
+	ErrProviderNotEnabled = errors.New("provider is not enabled")
+	ErrRequestCancelled   = errors.New("request cancelled by user")
+	ErrSessionBusy        = errors.New("session is currently processing another request")
+)
+
+// Service defines the interface for generating responses
+type Service interface {
 	Generate(ctx context.Context, sessionID string, content string) error
+	Cancel(sessionID string) error
 }
 
 type agent struct {
-	*app.App
+	sessions       session.Service
+	messages       message.Service
 	model          models.Model
 	tools          []tools.BaseTool
 	agent          provider.Provider
 	titleGenerator provider.Provider
+	activeRequests sync.Map // map[sessionID]context.CancelFunc
+}
+
+// NewAgent creates a new agent instance with the given model and tools
+func NewAgent(ctx context.Context, sessions session.Service, messages message.Service, model models.Model, tools []tools.BaseTool) (Service, error) {
+	agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
+	if err != nil {
+		return nil, fmt.Errorf("failed to initialize providers: %w", err)
+	}
+
+	return &agent{
+		model:          model,
+		tools:          tools,
+		sessions:       sessions,
+		messages:       messages,
+		agent:          agentProvider,
+		titleGenerator: titleGenerator,
+		activeRequests: sync.Map{},
+	}, nil
+}
+
+// Cancel cancels an active request by session ID
+func (a *agent) Cancel(sessionID string) error {
+	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
+		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
+			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
+			cancel()
+			return nil
+		}
+	}
+	return errors.New("no active request found for this session")
 }
 
-func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
-	response, err := c.titleGenerator.SendMessages(
+// Generate starts the generation process
+func (a *agent) Generate(ctx context.Context, sessionID string, content string) error {
+	// Check if this session already has an active request
+	if _, busy := a.activeRequests.Load(sessionID); busy {
+		return ErrSessionBusy
+	}
+
+	// Create a cancellable context
+	genCtx, cancel := context.WithCancel(ctx)
+
+	// Store cancel function to allow user cancellation
+	a.activeRequests.Store(sessionID, cancel)
+
+	// Launch the generation in a goroutine
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				logging.ErrorPersist(fmt.Sprintf("Panic in Generate: %v", r))
+			}
+		}()
+		defer a.activeRequests.Delete(sessionID)
+		defer cancel()
+
+		if err := a.generate(genCtx, sessionID, content); err != nil {
+			if !errors.Is(err, ErrRequestCancelled) && !errors.Is(err, context.Canceled) {
+				// Log the error (avoid logging cancellations as they're expected)
+				logging.ErrorPersist(fmt.Sprintf("Generation error for session %s: %v", sessionID, err))
+
+				// You may want to create an error message in the chat
+				bgCtx := context.Background()
+				errorMsg := fmt.Sprintf("Sorry, an error occurred: %v", err)
+				_, createErr := a.messages.Create(bgCtx, sessionID, message.CreateMessageParams{
+					Role: message.System,
+					Parts: []message.ContentPart{
+						message.TextContent{
+							Text: errorMsg,
+						},
+					},
+				})
+				if createErr != nil {
+					logging.ErrorPersist(fmt.Sprintf("Failed to create error message: %v", createErr))
+				}
+			}
+		}
+	}()
+
+	return nil
+}
+
+// IsSessionBusy checks if a session currently has an active request
+func (a *agent) IsSessionBusy(sessionID string) bool {
+	_, busy := a.activeRequests.Load(sessionID)
+	return busy
+} // handleTitleGeneration asynchronously generates a title for new sessions
+func (a *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
+	response, err := a.titleGenerator.SendMessages(
 		ctx,
 		[]message.Message{
 			{
@@ -45,25 +140,30 @@ func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content st
 		nil,
 	)
 	if err != nil {
+		logging.ErrorPersist(fmt.Sprintf("Failed to generate title: %v", err))
 		return
 	}
 
-	session, err := c.Sessions.Get(ctx, sessionID)
+	session, err := a.sessions.Get(ctx, sessionID)
 	if err != nil {
+		logging.ErrorPersist(fmt.Sprintf("Failed to get session: %v", err))
 		return
 	}
+
 	if response.Content != "" {
-		session.Title = response.Content
-		session.Title = strings.TrimSpace(session.Title)
+		session.Title = strings.TrimSpace(response.Content)
 		session.Title = strings.ReplaceAll(session.Title, "\n", " ")
-		c.Sessions.Save(ctx, session)
+		if _, err := a.sessions.Save(ctx, session); err != nil {
+			logging.ErrorPersist(fmt.Sprintf("Failed to save session title: %v", err))
+		}
 	}
 }
 
-func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
-	session, err := c.Sessions.Get(ctx, sessionID)
+// TrackUsage updates token usage statistics for the session
+func (a *agent) TrackUsage(ctx context.Context, sessionID string, model models.Model, usage provider.TokenUsage) error {
+	session, err := a.sessions.Get(ctx, sessionID)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to get session: %w", err)
 	}
 
 	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
@@ -75,189 +175,241 @@ func (c *agent) TrackUsage(ctx context.Context, sessionID string, model models.M
 	session.CompletionTokens += usage.OutputTokens
 	session.PromptTokens += usage.InputTokens
 
-	_, err = c.Sessions.Save(ctx, session)
-	return err
+	_, err = a.sessions.Save(ctx, session)
+	if err != nil {
+		return fmt.Errorf("failed to save session: %w", err)
+	}
+	return nil
 }
 
-func (c *agent) processEvent(
+// processEvent handles different types of events during generation
+func (a *agent) processEvent(
 	ctx context.Context,
 	sessionID string,
 	assistantMsg *message.Message,
 	event provider.ProviderEvent,
 ) error {
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	default:
+		// Continue processing
+	}
+
 	switch event.Type {
 	case provider.EventThinkingDelta:
 		assistantMsg.AppendReasoningContent(event.Content)
-		return c.Messages.Update(ctx, *assistantMsg)
+		return a.messages.Update(ctx, *assistantMsg)
 	case provider.EventContentDelta:
 		assistantMsg.AppendContent(event.Content)
-		return c.Messages.Update(ctx, *assistantMsg)
+		return a.messages.Update(ctx, *assistantMsg)
 	case provider.EventError:
 		if errors.Is(event.Error, context.Canceled) {
-			return nil
+			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
+			return context.Canceled
 		}
 		logging.ErrorPersist(event.Error.Error())
 		return event.Error
 	case provider.EventWarning:
 		logging.WarnPersist(event.Info)
-		return nil
 	case provider.EventInfo:
 		logging.InfoPersist(event.Info)
 	case provider.EventComplete:
 		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 		assistantMsg.AddFinish(event.Response.FinishReason)
-		err := c.Messages.Update(ctx, *assistantMsg)
-		if err != nil {
-			return err
+		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
+			return fmt.Errorf("failed to update message: %w", err)
 		}
-		return c.TrackUsage(ctx, sessionID, c.model, event.Response.Usage)
+		return a.TrackUsage(ctx, sessionID, a.model, event.Response.Usage)
 	}
 
 	return nil
 }
 
-func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
-	var wg sync.WaitGroup
+// ExecuteTools runs all tool calls sequentially and returns the results
+func (a *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, tls []tools.BaseTool) ([]message.ToolResult, error) {
 	toolResults := make([]message.ToolResult, len(toolCalls))
-	mutex := &sync.Mutex{}
-	errChan := make(chan error, 1)
 
 	// Create a child context that can be canceled
 	ctx, cancel := context.WithCancel(ctx)
 	defer cancel()
 
-	for i, tc := range toolCalls {
-		wg.Add(1)
-		go func(index int, toolCall message.ToolCall) {
-			defer wg.Done()
+	// Check if already canceled before starting any execution
+	if ctx.Err() != nil {
+		// Mark all tools as canceled
+		for i, toolCall := range toolCalls {
+			toolResults[i] = message.ToolResult{
+				ToolCallID: toolCall.ID,
+				Content:    "Tool execution canceled by user",
+				IsError:    true,
+			}
+		}
+		return toolResults, ctx.Err()
+	}
 
-			// Check if context is already canceled
-			select {
-			case <-ctx.Done():
-				mutex.Lock()
-				toolResults[index] = message.ToolResult{
-					ToolCallID: toolCall.ID,
-					Content:    "Tool execution canceled",
+	for i, toolCall := range toolCalls {
+		// Check for cancellation before executing each tool
+		select {
+		case <-ctx.Done():
+			// Mark this and all remaining tools as canceled
+			for j := i; j < len(toolCalls); j++ {
+				toolResults[j] = message.ToolResult{
+					ToolCallID: toolCalls[j].ID,
+					Content:    "Tool execution canceled by user",
 					IsError:    true,
 				}
-				mutex.Unlock()
-
-				// Send cancellation error to error channel if it's empty
-				select {
-				case errChan <- ctx.Err():
-				default:
-				}
-				return
-			default:
 			}
+			return toolResults, ctx.Err()
+		default:
+			// Continue processing
+		}
 
-			response := ""
-			isError := false
-			found := false
-
-			for _, tool := range tls {
-				if tool.Info().Name == toolCall.Name {
-					found = true
-					toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
-						ID:    toolCall.ID,
-						Name:  toolCall.Name,
-						Input: toolCall.Input,
-					})
-
-					if toolErr != nil {
-						if errors.Is(toolErr, context.Canceled) {
-							response = "Tool execution canceled"
-
-							// Send cancellation error to error channel if it's empty
-							select {
-							case errChan <- ctx.Err():
-							default:
-							}
-						} else {
-							response = fmt.Sprintf("error running tool: %s", toolErr)
-						}
-						isError = true
+		response := ""
+		isError := false
+		found := false
+
+		// Find and execute the appropriate tool
+		for _, tool := range tls {
+			if tool.Info().Name == toolCall.Name {
+				found = true
+				toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
+					ID:    toolCall.ID,
+					Name:  toolCall.Name,
+					Input: toolCall.Input,
+				})
+
+				if toolErr != nil {
+					if errors.Is(toolErr, context.Canceled) {
+						response = "Tool execution canceled by user"
 					} else {
-						response = toolResult.Content
-						isError = toolResult.IsError
+						response = fmt.Sprintf("Error running tool: %s", toolErr)
 					}
-					break
+					isError = true
+				} else {
+					response = toolResult.Content
+					isError = toolResult.IsError
 				}
+				break
 			}
+		}
 
-			if !found {
-				response = fmt.Sprintf("tool not found: %s", toolCall.Name)
-				isError = true
-			}
-
-			mutex.Lock()
-			defer mutex.Unlock()
-
-			toolResults[index] = message.ToolResult{
-				ToolCallID: toolCall.ID,
-				Content:    response,
-				IsError:    isError,
-			}
-		}(i, tc)
-	}
-
-	// Wait for all goroutines to finish or context to be canceled
-	done := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(done)
-	}()
+		if !found {
+			response = fmt.Sprintf("Tool not found: %s", toolCall.Name)
+			isError = true
+		}
 
-	select {
-	case <-done:
-		// All tools completed successfully
-	case err := <-errChan:
-		// One of the tools encountered a cancellation
-		return toolResults, err
-	case <-ctx.Done():
-		// Context was canceled externally
-		return toolResults, ctx.Err()
+		toolResults[i] = message.ToolResult{
+			ToolCallID: toolCall.ID,
+			Content:    response,
+			IsError:    isError,
+		}
 	}
 
 	return toolResults, nil
 }
 
-func (c *agent) handleToolExecution(
+// handleToolExecution processes tool calls and creates tool result messages
+func (a *agent) handleToolExecution(
 	ctx context.Context,
 	assistantMsg message.Message,
 ) (*message.Message, error) {
+	select {
+	case <-ctx.Done():
+		// If cancelled, create tool results that indicate cancellation
+		if len(assistantMsg.ToolCalls()) > 0 {
+			toolResults := make([]message.ToolResult, 0, len(assistantMsg.ToolCalls()))
+			for _, tc := range assistantMsg.ToolCalls() {
+				toolResults = append(toolResults, message.ToolResult{
+					ToolCallID: tc.ID,
+					Content:    "Tool execution canceled by user",
+					IsError:    true,
+				})
+			}
+
+			// Use background context to ensure the message is created even if original context is cancelled
+			bgCtx := context.Background()
+			parts := make([]message.ContentPart, 0)
+			for _, toolResult := range toolResults {
+				parts = append(parts, toolResult)
+			}
+			msg, err := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
+				Role:  message.Tool,
+				Parts: parts,
+			})
+			if err != nil {
+				return nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
+			}
+			return &msg, ctx.Err()
+		}
+		return nil, ctx.Err()
+	default:
+		// Continue processing
+	}
+
 	if len(assistantMsg.ToolCalls()) == 0 {
 		return nil, nil
 	}
 
-	toolResults, err := c.ExecuteTools(ctx, assistantMsg.ToolCalls(), c.tools)
+	toolResults, err := a.ExecuteTools(ctx, assistantMsg.ToolCalls(), a.tools)
 	if err != nil {
+		// If error is from cancellation, still return the partial results we have
+		if errors.Is(err, context.Canceled) {
+			// Use background context to ensure the message is created even if original context is cancelled
+			bgCtx := context.Background()
+			parts := make([]message.ContentPart, 0)
+			for _, toolResult := range toolResults {
+				parts = append(parts, toolResult)
+			}
+
+			msg, createErr := a.messages.Create(bgCtx, assistantMsg.SessionID, message.CreateMessageParams{
+				Role:  message.Tool,
+				Parts: parts,
+			})
+			if createErr != nil {
+				logging.ErrorPersist(fmt.Sprintf("Failed to create tool message after cancellation: %v", createErr))
+				return nil, err
+			}
+			return &msg, err
+		}
 		return nil, err
 	}
-	parts := make([]message.ContentPart, 0)
+
+	parts := make([]message.ContentPart, 0, len(toolResults))
 	for _, toolResult := range toolResults {
 		parts = append(parts, toolResult)
 	}
-	msg, err := c.Messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
+
+	msg, err := a.messages.Create(ctx, assistantMsg.SessionID, message.CreateMessageParams{
 		Role:  message.Tool,
 		Parts: parts,
 	})
+	if err != nil {
+		return nil, fmt.Errorf("failed to create tool message: %w", err)
+	}
 
-	return &msg, err
+	return &msg, nil
 }
 
-func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
+// generate handles the main generation workflow
+func (a *agent) generate(ctx context.Context, sessionID string, content string) error {
 	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
-	messages, err := c.Messages.List(ctx, sessionID)
+
+	// Handle context cancellation at any point
+	if err := ctx.Err(); err != nil {
+		return ErrRequestCancelled
+	}
+
+	messages, err := a.messages.List(ctx, sessionID)
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to list messages: %w", err)
 	}
 
 	if len(messages) == 0 {
-		go c.handleTitleGeneration(ctx, sessionID, content)
+		titleCtx := context.Background()
+		go a.handleTitleGeneration(titleCtx, sessionID, content)
 	}
 
-	userMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
+	userMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 		Role: message.User,
 		Parts: []message.ContentPart{
 			message.TextContent{
@@ -266,133 +418,125 @@ func (c *agent) generate(ctx context.Context, sessionID string, content string)
 		},
 	})
 	if err != nil {
-		return err
+		return fmt.Errorf("failed to create user message: %w", err)
 	}
 
 	messages = append(messages, userMsg)
+
 	for {
+		// Check for cancellation before each iteration
 		select {
 		case <-ctx.Done():
-			assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
-				Role:  message.Assistant,
-				Parts: []message.ContentPart{},
-			})
-			if err != nil {
-				return err
-			}
-			assistantMsg.AddFinish("canceled")
-			c.Messages.Update(ctx, assistantMsg)
-			return context.Canceled
+			return ErrRequestCancelled
 		default:
 			// Continue processing
 		}
 
-		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
+		eventChan, err := a.agent.StreamResponse(ctx, messages, a.tools)
 		if err != nil {
 			if errors.Is(err, context.Canceled) {
-				assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
-					Role:  message.Assistant,
-					Parts: []message.ContentPart{},
-				})
-				if err != nil {
-					return err
-				}
-				assistantMsg.AddFinish("canceled")
-				c.Messages.Update(ctx, assistantMsg)
-				return context.Canceled
+				return ErrRequestCancelled
 			}
-			return err
+			return fmt.Errorf("failed to stream response: %w", err)
 		}
 
-		assistantMsg, err := c.Messages.Create(ctx, sessionID, message.CreateMessageParams{
+		assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 			Role:  message.Assistant,
 			Parts: []message.ContentPart{},
-			Model: c.model.ID,
+			Model: a.model.ID,
 		})
 		if err != nil {
-			return err
+			return fmt.Errorf("failed to create assistant message: %w", err)
 		}
 
 		ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
+
+		// Process events from the LLM provider
 		for event := range eventChan {
-			err = c.processEvent(ctx, sessionID, &assistantMsg, event)
-			if err != nil {
+			if err := a.processEvent(ctx, sessionID, &assistantMsg, event); err != nil {
 				if errors.Is(err, context.Canceled) {
+					// Mark as canceled but don't create separate message
 					assistantMsg.AddFinish("canceled")
-					c.Messages.Update(ctx, assistantMsg)
-					return context.Canceled
+					_ = a.messages.Update(context.Background(), assistantMsg)
+					return ErrRequestCancelled
 				}
 				assistantMsg.AddFinish("error:" + err.Error())
-				c.Messages.Update(ctx, assistantMsg)
-				return err
+				_ = a.messages.Update(ctx, assistantMsg)
+				return fmt.Errorf("event processing error: %w", err)
 			}
 
+			// Check for cancellation during event processing
 			select {
 			case <-ctx.Done():
+				// Mark as canceled
 				assistantMsg.AddFinish("canceled")
-				c.Messages.Update(ctx, assistantMsg)
-				return context.Canceled
+				_ = a.messages.Update(context.Background(), assistantMsg)
+				return ErrRequestCancelled
 			default:
 			}
 		}
 
-		// Check for context cancellation before tool execution
+		// Check for cancellation before tool execution
 		select {
 		case <-ctx.Done():
-			assistantMsg.AddFinish("canceled")
-			c.Messages.Update(ctx, assistantMsg)
-			return context.Canceled
+			assistantMsg.AddFinish("canceled_by_user")
+			_ = a.messages.Update(context.Background(), assistantMsg)
+			return ErrRequestCancelled
 		default:
-			// Continue processing
 		}
 
-		msg, err := c.handleToolExecution(ctx, assistantMsg)
+		// Execute any tool calls
+		toolMsg, err := a.handleToolExecution(ctx, assistantMsg)
 		if err != nil {
 			if errors.Is(err, context.Canceled) {
-				assistantMsg.AddFinish("canceled")
-				c.Messages.Update(ctx, assistantMsg)
-				return context.Canceled
+				assistantMsg.AddFinish("canceled_by_user")
+				_ = a.messages.Update(context.Background(), assistantMsg)
+				return ErrRequestCancelled
 			}
-			return err
+			return fmt.Errorf("tool execution error: %w", err)
 		}
 
-		c.Messages.Update(ctx, assistantMsg)
+		if err := a.messages.Update(ctx, assistantMsg); err != nil {
+			return fmt.Errorf("failed to update assistant message: %w", err)
+		}
 
+		// If no tool calls, we're done
 		if len(assistantMsg.ToolCalls()) == 0 {
 			break
 		}
 
+		// Add messages for next iteration
 		messages = append(messages, assistantMsg)
-		if msg != nil {
-			messages = append(messages, *msg)
+		if toolMsg != nil {
+			messages = append(messages, *toolMsg)
 		}
 
-		// Check for context cancellation after tool execution
+		// Check for cancellation after tool execution
 		select {
 		case <-ctx.Done():
-			assistantMsg.AddFinish("canceled")
-			c.Messages.Update(ctx, assistantMsg)
-			return context.Canceled
+			return ErrRequestCancelled
 		default:
-			// Continue processing
 		}
 	}
+
 	return nil
 }
 
+// getAgentProviders initializes the LLM providers based on the chosen model
 func getAgentProviders(ctx context.Context, model models.Model) (provider.Provider, provider.Provider, error) {
 	maxTokens := config.Get().Model.CoderMaxTokens
 
 	providerConfig, ok := config.Get().Providers[model.Provider]
 	if !ok || providerConfig.Disabled {
-		return nil, nil, errors.New("provider is not enabled")
+		return nil, nil, ErrProviderNotEnabled
 	}
+
 	var agentProvider provider.Provider
 	var titleGenerator provider.Provider
+	var err error
 
 	switch model.Provider {
 	case models.ProviderOpenAI:
-		var err error
 		agentProvider, err = provider.NewOpenAIProvider(
 			provider.WithOpenAISystemMessage(
 				prompt.CoderOpenAISystemPrompt(),
@@ -402,8 +546,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithOpenAIKey(providerConfig.APIKey),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create OpenAI agent provider: %w", err)
 		}
+
 		titleGenerator, err = provider.NewOpenAIProvider(
 			provider.WithOpenAISystemMessage(
 				prompt.TitlePrompt(),
@@ -413,10 +558,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithOpenAIKey(providerConfig.APIKey),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create OpenAI title generator: %w", err)
 		}
+
 	case models.ProviderAnthropic:
-		var err error
 		agentProvider, err = provider.NewAnthropicProvider(
 			provider.WithAnthropicSystemMessage(
 				prompt.CoderAnthropicSystemPrompt(),
@@ -426,8 +571,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithAnthropicModel(model),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create Anthropic agent provider: %w", err)
 		}
+
 		titleGenerator, err = provider.NewAnthropicProvider(
 			provider.WithAnthropicSystemMessage(
 				prompt.TitlePrompt(),
@@ -437,11 +583,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithAnthropicModel(model),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create Anthropic title generator: %w", err)
 		}
 
 	case models.ProviderGemini:
-		var err error
 		agentProvider, err = provider.NewGeminiProvider(
 			ctx,
 			provider.WithGeminiSystemMessage(
@@ -452,8 +597,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithGeminiModel(model),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create Gemini agent provider: %w", err)
 		}
+
 		titleGenerator, err = provider.NewGeminiProvider(
 			ctx,
 			provider.WithGeminiSystemMessage(
@@ -464,10 +610,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithGeminiModel(model),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create Gemini title generator: %w", err)
 		}
+
 	case models.ProviderGROQ:
-		var err error
 		agentProvider, err = provider.NewOpenAIProvider(
 			provider.WithOpenAISystemMessage(
 				prompt.CoderAnthropicSystemPrompt(),
@@ -478,8 +624,9 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create GROQ agent provider: %w", err)
 		}
+
 		titleGenerator, err = provider.NewOpenAIProvider(
 			provider.WithOpenAISystemMessage(
 				prompt.TitlePrompt(),
@@ -490,11 +637,10 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithOpenAIBaseURL("https://api.groq.com/openai/v1"),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create GROQ title generator: %w", err)
 		}
 
 	case models.ProviderBedrock:
-		var err error
 		agentProvider, err = provider.NewBedrockProvider(
 			provider.WithBedrockSystemMessage(
 				prompt.CoderAnthropicSystemPrompt(),
@@ -503,19 +649,21 @@ func getAgentProviders(ctx context.Context, model models.Model) (provider.Provid
 			provider.WithBedrockModel(model),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create Bedrock agent provider: %w", err)
 		}
+
 		titleGenerator, err = provider.NewBedrockProvider(
 			provider.WithBedrockSystemMessage(
 				prompt.TitlePrompt(),
 			),
-			provider.WithBedrockMaxTokens(maxTokens),
+			provider.WithBedrockMaxTokens(80),
 			provider.WithBedrockModel(model),
 		)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, fmt.Errorf("failed to create Bedrock title generator: %w", err)
 		}
-
+	default:
+		return nil, nil, fmt.Errorf("unsupported provider: %s", model.Provider)
 	}
 
 	return agentProvider, titleGenerator, nil

internal/llm/agent/coder.go 🔗

@@ -4,71 +4,60 @@ import (
 	"context"
 	"errors"
 
-	"github.com/kujtimiihoxha/termai/internal/app"
 	"github.com/kujtimiihoxha/termai/internal/config"
 	"github.com/kujtimiihoxha/termai/internal/llm/models"
 	"github.com/kujtimiihoxha/termai/internal/llm/tools"
+	"github.com/kujtimiihoxha/termai/internal/lsp"
+	"github.com/kujtimiihoxha/termai/internal/message"
+	"github.com/kujtimiihoxha/termai/internal/permission"
+	"github.com/kujtimiihoxha/termai/internal/session"
 )
 
 type coderAgent struct {
-	*agent
+	Service
 }
 
-func (c *coderAgent) setAgentTool(sessionID string) {
-	inx := -1
-	for i, tool := range c.tools {
-		if tool.Info().Name == AgentToolName {
-			inx = i
-			break
-		}
-	}
-	if inx == -1 {
-		c.tools = append(c.tools, NewAgentTool(sessionID, c.App))
-	} else {
-		c.tools[inx] = NewAgentTool(sessionID, c.App)
-	}
-}
-
-func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error {
-	c.setAgentTool(sessionID)
-	return c.generate(ctx, sessionID, content)
-}
-
-func NewCoderAgent(app *app.App) (Agent, error) {
+func NewCoderAgent(
+	permissions permission.Service,
+	sessions session.Service,
+	messages message.Service,
+	lspClients map[string]*lsp.Client,
+) (Service, error) {
 	model, ok := models.SupportedModels[config.Get().Model.Coder]
 	if !ok {
 		return nil, errors.New("model not supported")
 	}
 
 	ctx := context.Background()
-	agentProvider, titleGenerator, err := getAgentProviders(ctx, model)
+	otherTools := GetMcpTools(ctx, permissions)
+	if len(lspClients) > 0 {
+		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
+	}
+	agent, err := NewAgent(
+		ctx,
+		sessions,
+		messages,
+		model,
+		append(
+			[]tools.BaseTool{
+				tools.NewBashTool(permissions),
+				tools.NewEditTool(lspClients, permissions),
+				tools.NewFetchTool(permissions),
+				tools.NewGlobTool(),
+				tools.NewGrepTool(),
+				tools.NewLsTool(),
+				tools.NewSourcegraphTool(),
+				tools.NewViewTool(lspClients),
+				tools.NewWriteTool(lspClients, permissions),
+				NewAgentTool(sessions, messages),
+			}, otherTools...,
+		),
+	)
 	if err != nil {
 		return nil, err
 	}
 
-	otherTools := GetMcpTools(ctx, app.Permissions)
-	if len(app.LSPClients) > 0 {
-		otherTools = append(otherTools, tools.NewDiagnosticsTool(app.LSPClients))
-	}
 	return &coderAgent{
-		agent: &agent{
-			App: app,
-			tools: append(
-				[]tools.BaseTool{
-					tools.NewBashTool(app.Permissions),
-					tools.NewEditTool(app.LSPClients, app.Permissions),
-					tools.NewFetchTool(app.Permissions),
-					tools.NewGlobTool(),
-					tools.NewGrepTool(),
-					tools.NewLsTool(),
-					tools.NewSourcegraphTool(),
-					tools.NewViewTool(app.LSPClients),
-					tools.NewWriteTool(app.LSPClients, app.Permissions),
-				}, otherTools...,
-			),
-			model:          model,
-			agent:          agentProvider,
-			titleGenerator: titleGenerator,
-		},
+		agent,
 	}, nil
 }

internal/llm/agent/task.go 🔗

@@ -4,10 +4,10 @@ import (
 	"context"
 	"errors"
 
-	"github.com/kujtimiihoxha/termai/internal/app"
 	"github.com/kujtimiihoxha/termai/internal/config"
 	"github.com/kujtimiihoxha/termai/internal/llm/models"
 	"github.com/kujtimiihoxha/termai/internal/llm/tools"
+	"github.com/kujtimiihoxha/termai/internal/lsp"
 )
 
 type taskAgent struct {
@@ -18,7 +18,7 @@ func (c *taskAgent) Generate(ctx context.Context, sessionID string, content stri
 	return c.generate(ctx, sessionID, content)
 }
 
-func NewTaskAgent(app *app.App) (Agent, error) {
+func NewTaskAgent(lspClients map[string]*lsp.Client) (Service, error) {
 	model, ok := models.SupportedModels[config.Get().Model.Coder]
 	if !ok {
 		return nil, errors.New("model not supported")
@@ -31,13 +31,12 @@ func NewTaskAgent(app *app.App) (Agent, error) {
 	}
 	return &taskAgent{
 		agent: &agent{
-			App: app,
 			tools: []tools.BaseTool{
 				tools.NewGlobTool(),
 				tools.NewGrepTool(),
 				tools.NewLsTool(),
 				tools.NewSourcegraphTool(),
-				tools.NewViewTool(app.LSPClients),
+				tools.NewViewTool(lspClients),
 			},
 			model:          model,
 			agent:          agentProvider,

internal/llm/provider/provider.go 🔗

@@ -57,7 +57,9 @@ func cleanupMessages(messages []message.Message) []message.Message {
 	// First pass: filter out canceled messages
 	var cleanedMessages []message.Message
 	for _, msg := range messages {
-		if msg.FinishReason() != "canceled" {
+		if msg.FinishReason() != "canceled" || len(msg.ToolCalls()) > 0 {
+			// if there are toolCalls this means we want to return it to the LLM telling it that those tools have been
+			// cancelled
 			cleanedMessages = append(cleanedMessages, msg)
 		}
 	}

internal/llm/tools/edit.go 🔗

@@ -190,7 +190,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string)
 		return er, fmt.Errorf("failed to create parent directories: %w", err)
 	}
 
-	sessionID, messageID := getContextValues(ctx)
+	sessionID, messageID := GetContextValues(ctx)
 	if sessionID == "" || messageID == "" {
 		return er, fmt.Errorf("session ID and message ID are required for creating a new file")
 	}
@@ -277,7 +277,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string
 
 	newContent := oldContent[:index] + oldContent[index+len(oldString):]
 
-	sessionID, messageID := getContextValues(ctx)
+	sessionID, messageID := GetContextValues(ctx)
 
 	if sessionID == "" || messageID == "" {
 		return er, fmt.Errorf("session ID and message ID are required for creating a new file")
@@ -365,7 +365,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
 
 	newContent := oldContent[:index] + newString + oldContent[index+len(oldString):]
 
-	sessionID, messageID := getContextValues(ctx)
+	sessionID, messageID := GetContextValues(ctx)
 
 	if sessionID == "" || messageID == "" {
 		return er, fmt.Errorf("session ID and message ID are required for creating a new file")
@@ -409,4 +409,3 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS
 
 	return er, nil
 }
-

internal/llm/tools/tools.go 🔗

@@ -66,7 +66,7 @@ type BaseTool interface {
 	Run(ctx context.Context, params ToolCall) (ToolResponse, error)
 }
 
-func getContextValues(ctx context.Context) (string, string) {
+func GetContextValues(ctx context.Context) (string, string) {
 	sessionID := ctx.Value(SessionIDContextKey)
 	messageID := ctx.Value(MessageIDContextKey)
 	if sessionID == nil {

internal/llm/tools/write.go 🔗

@@ -144,7 +144,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error
 		}
 	}
 
-	sessionID, messageID := getContextValues(ctx)
+	sessionID, messageID := GetContextValues(ctx)
 	if sessionID == "" || messageID == "" {
 		return NewTextErrorResponse("session ID or message ID is missing"), nil
 	}

internal/tui/components/repl/editor.go 🔗

@@ -7,7 +7,6 @@ import (
 	tea "github.com/charmbracelet/bubbletea"
 	"github.com/charmbracelet/lipgloss"
 	"github.com/kujtimiihoxha/termai/internal/app"
-	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 	"github.com/kujtimiihoxha/termai/internal/tui/layout"
 	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 	"github.com/kujtimiihoxha/termai/internal/tui/util"
@@ -168,11 +167,6 @@ func (m *editorCmp) Send() tea.Cmd {
 		return util.ReportWarn("Assistant is still working on the previous message")
 	}
 
-	a, err := agent.NewCoderAgent(m.app)
-	if err != nil {
-		return util.ReportError(err)
-	}
-
 	content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
 	if len(content) == 0 {
 		return util.ReportWarn("Message is empty")
@@ -181,7 +175,7 @@ func (m *editorCmp) Send() tea.Cmd {
 	m.cancelMessage = cancel
 	go func() {
 		defer cancel()
-		a.Generate(ctx, m.sessionID, content)
+		m.app.CoderAgent.Generate(ctx, m.sessionID, content)
 		m.cancelMessage = nil
 	}()
 

internal/tui/page/chat.go 🔗

@@ -6,7 +6,6 @@ import (
 	"github.com/charmbracelet/bubbles/key"
 	tea "github.com/charmbracelet/bubbletea"
 	"github.com/kujtimiihoxha/termai/internal/app"
-	"github.com/kujtimiihoxha/termai/internal/llm/agent"
 	"github.com/kujtimiihoxha/termai/internal/session"
 	"github.com/kujtimiihoxha/termai/internal/tui/components/chat"
 	"github.com/kujtimiihoxha/termai/internal/tui/layout"
@@ -23,6 +22,7 @@ type chatPage struct {
 
 type ChatKeyMap struct {
 	NewSession key.Binding
+	Cancel     key.Binding
 }
 
 var keyMap = ChatKeyMap{
@@ -30,6 +30,10 @@ var keyMap = ChatKeyMap{
 		key.WithKeys("ctrl+n"),
 		key.WithHelp("ctrl+n", "new session"),
 	),
+	Cancel: key.NewBinding(
+		key.WithKeys("ctrl+x"),
+		key.WithHelp("ctrl+x", "cancel"),
+	),
 }
 
 func (p *chatPage) Init() tea.Cmd {
@@ -106,15 +110,8 @@ func (p *chatPage) sendMessage(text string) tea.Cmd {
 		}
 		cmds = append(cmds, util.CmdHandler(chat.SessionSelectedMsg(session)))
 	}
-	// TODO: move this to a service
-	a, err := agent.NewCoderAgent(p.app)
-	if err != nil {
-		return util.ReportError(err)
-	}
-	go func() {
-		a.Generate(context.Background(), p.session.ID, text)
-	}()
 
+	p.app.CoderAgent.Generate(context.Background(), p.session.ID, text)
 	return tea.Batch(cmds...)
 }