fix logs and add cancellation

Kujtim Hoxha created

Change summary

cmd/root.go                              |  19 ++
go.mod                                   |   4 
internal/app/services.go                 |  14 +-
internal/db/connect.go                   |  18 +-
internal/llm/agent/agent-tool.go         |   2 
internal/llm/agent/agent.go              | 148 +++++++++++++++++++++++--
internal/llm/agent/coder.go              |   5 
internal/llm/agent/mcp-tools.go          |  10 -
internal/llm/agent/task.go               |   5 
internal/logging/default.go              |  12 --
internal/logging/logger.go               | 136 ++--------------------
internal/logging/logging.go              |  23 ----
internal/logging/writer.go               |  49 +++++++
internal/lsp/client.go                   |   8 
internal/lsp/handlers.go                 |  13 +-
internal/lsp/transport.go                |  31 ++--
internal/lsp/watcher/watcher.go          |  58 ++++-----
internal/tui/components/core/status.go   |   8 
internal/tui/components/logs/details.go  |   2 
internal/tui/components/logs/table.go    |   6 
internal/tui/components/repl/editor.go   |  41 +++++-
internal/tui/components/repl/messages.go |  11 +
internal/tui/tui.go                      |   3 
23 files changed, 343 insertions(+), 283 deletions(-)

Detailed changes

cmd/root.go 🔗

@@ -2,6 +2,7 @@ package cmd
 
 import (
 	"context"
+	"log/slog"
 	"os"
 	"sync"
 
@@ -10,6 +11,7 @@ import (
 	"github.com/kujtimiihoxha/termai/internal/config"
 	"github.com/kujtimiihoxha/termai/internal/db"
 	"github.com/kujtimiihoxha/termai/internal/llm/agent"
+	"github.com/kujtimiihoxha/termai/internal/logging"
 	"github.com/kujtimiihoxha/termai/internal/tui"
 	zone "github.com/lrstanley/bubblezone"
 	"github.com/spf13/cobra"
@@ -26,6 +28,16 @@ var rootCmd = &cobra.Command{
 		}
 		debug, _ := cmd.Flags().GetBool("debug")
 		err := config.Load(debug)
+		cfg := config.Get()
+		defaultLevel := slog.LevelInfo
+		if cfg.Debug {
+			defaultLevel = slog.LevelDebug
+		}
+		logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{
+			Level: defaultLevel,
+		}))
+		slog.SetDefault(logger)
+
 		if err != nil {
 			return err
 		}
@@ -37,14 +49,14 @@ var rootCmd = &cobra.Command{
 
 		app := app.New(ctx, conn)
 		defer app.Close()
-		app.Logger.Info("Starting termai...")
+		logging.Info("Starting termai...")
 		zone.NewGlobal()
 		tui := tea.NewProgram(
 			tui.New(app),
 			tea.WithAltScreen(),
 			tea.WithMouseCellMotion(),
 		)
-		app.Logger.Info("Setting up subscriptions...")
+		logging.Info("Setting up subscriptions...")
 		ch, unsub := setupSubscriptions(app)
 		defer unsub()
 
@@ -66,9 +78,8 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) {
 	ch := make(chan tea.Msg)
 	wg := sync.WaitGroup{}
 	ctx, cancel := context.WithCancel(app.Context)
-
 	{
-		sub := app.Logger.Subscribe(ctx)
+		sub := logging.Subscribe(ctx)
 		wg.Add(1)
 		go func() {
 			for ev := range sub {

go.mod 🔗

@@ -33,7 +33,7 @@ require (
 	github.com/spf13/cobra v1.9.1
 	github.com/spf13/viper v1.20.0
 	github.com/stretchr/testify v1.10.0
-	golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
+	golang.org/x/net v0.34.0
 	google.golang.org/api v0.215.0
 )
 
@@ -116,10 +116,10 @@ require (
 	go.uber.org/multierr v1.9.0 // indirect
 	golang.design/x/clipboard v0.7.0 // indirect
 	golang.org/x/crypto v0.33.0 // indirect
+	golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
 	golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect
 	golang.org/x/image v0.14.0 // indirect
 	golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect
-	golang.org/x/net v0.34.0 // indirect
 	golang.org/x/oauth2 v0.25.0 // indirect
 	golang.org/x/sync v0.12.0 // indirect
 	golang.org/x/sys v0.31.0 // indirect

internal/app/services.go 🔗

@@ -3,6 +3,7 @@ package app
 import (
 	"context"
 	"database/sql"
+	"log/slog"
 
 	"github.com/kujtimiihoxha/termai/internal/config"
 	"github.com/kujtimiihoxha/termai/internal/db"
@@ -23,16 +24,14 @@ type App struct {
 
 	LSPClients map[string]*lsp.Client
 
-	Logger logging.Interface
-
 	ceanups []func()
 }
 
 func New(ctx context.Context, conn *sql.DB) *App {
 	cfg := config.Get()
+	logging.Info("Debug mode enabled")
+
 	q := db.New(conn)
-	log := logging.Get()
-	log.SetLevel(cfg.Log.Level)
 	sessions := session.NewService(ctx, q)
 	messages := message.NewService(ctx, q)
 
@@ -41,7 +40,6 @@ func New(ctx context.Context, conn *sql.DB) *App {
 		Sessions:    sessions,
 		Messages:    messages,
 		Permissions: permission.NewPermissionService(),
-		Logger:      log,
 		LSPClients:  make(map[string]*lsp.Client),
 	}
 
@@ -52,13 +50,13 @@ func New(ctx context.Context, conn *sql.DB) *App {
 		})
 		workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient)
 		if err != nil {
-			log.Error("Failed to create LSP client for", name, err)
+			logging.Error("Failed to create LSP client for", name, err)
 			continue
 		}
 
 		_, err = lspClient.InitializeLSPClient(ctx, config.WorkingDirectory())
 		if err != nil {
-			log.Error("Initialize failed", "error", err)
+			logging.Error("Initialize failed", "error", err)
 			continue
 		}
 		go workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory())
@@ -74,5 +72,5 @@ func (a *App) Close() {
 	for _, client := range a.LSPClients {
 		client.Close()
 	}
-	a.Logger.Info("App closed")
+	slog.Info("App closed")
 }

internal/db/connect.go 🔗

@@ -16,8 +16,6 @@ import (
 	"github.com/kujtimiihoxha/termai/internal/logging"
 )
 
-var log = logging.Get()
-
 func Connect() (*sql.DB, error) {
 	dataDir := config.Get().Data.Directory
 	if dataDir == "" {
@@ -50,43 +48,43 @@ func Connect() (*sql.DB, error) {
 
 	for _, pragma := range pragmas {
 		if _, err = db.Exec(pragma); err != nil {
-			log.Warn("Failed to set pragma", pragma, err)
+			logging.Warn("Failed to set pragma", pragma, err)
 		} else {
-			log.Warn("Set pragma", "pragma", pragma)
+			logging.Warn("Set pragma", "pragma", pragma)
 		}
 	}
 
 	// Initialize schema from embedded file
 	d, err := iofs.New(FS, "migrations")
 	if err != nil {
-		log.Error("Failed to open embedded migrations", "error", err)
+		logging.Error("Failed to open embedded migrations", "error", err)
 		db.Close()
 		return nil, fmt.Errorf("failed to open embedded migrations: %w", err)
 	}
 
 	driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
 	if err != nil {
-		log.Error("Failed to create SQLite driver", "error", err)
+		logging.Error("Failed to create SQLite driver", "error", err)
 		db.Close()
 		return nil, fmt.Errorf("failed to create SQLite driver: %w", err)
 	}
 
 	m, err := migrate.NewWithInstance("iofs", d, "ql", driver)
 	if err != nil {
-		log.Error("Failed to create migration instance", "error", err)
+		logging.Error("Failed to create migration instance", "error", err)
 		db.Close()
 		return nil, fmt.Errorf("failed to create migration instance: %w", err)
 	}
 
 	err = m.Up()
 	if err != nil && err != migrate.ErrNoChange {
-		log.Error("Migration failed", "error", err)
+		logging.Error("Migration failed", "error", err)
 		db.Close()
 		return nil, fmt.Errorf("failed to apply schema: %w", err)
 	} else if err == migrate.ErrNoChange {
-		log.Info("No schema changes to apply")
+		logging.Info("No schema changes to apply")
 	} else {
-		log.Info("Schema migration applied successfully")
+		logging.Info("Schema migration applied successfully")
 	}
 
 	return db, nil

internal/llm/agent/agent-tool.go 🔗

@@ -56,7 +56,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes
 		return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil
 	}
 
-	err = agent.Generate(session.ID, params.Prompt)
+	err = agent.Generate(ctx, session.ID, params.Prompt)
 	if err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil
 	}

internal/llm/agent/agent.go 🔗

@@ -13,11 +13,12 @@ import (
 	"github.com/kujtimiihoxha/termai/internal/llm/prompt"
 	"github.com/kujtimiihoxha/termai/internal/llm/provider"
 	"github.com/kujtimiihoxha/termai/internal/llm/tools"
+	"github.com/kujtimiihoxha/termai/internal/logging"
 	"github.com/kujtimiihoxha/termai/internal/message"
 )
 
 type Agent interface {
-	Generate(sessionID string, content string) error
+	Generate(ctx context.Context, sessionID string, content string) error
 }
 
 type agent struct {
@@ -28,9 +29,9 @@ type agent struct {
 	titleGenerator provider.Provider
 }
 
-func (c *agent) handleTitleGeneration(sessionID, content string) {
+func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) {
 	response, err := c.titleGenerator.SendMessages(
-		c.Context,
+		ctx,
 		[]message.Message{
 			{
 				Role: message.User,
@@ -91,13 +92,16 @@ func (c *agent) processEvent(
 		assistantMsg.AppendContent(event.Content)
 		return c.Messages.Update(*assistantMsg)
 	case provider.EventError:
-		c.App.Logger.PersistError(event.Error.Error())
+		if errors.Is(event.Error, context.Canceled) {
+			return nil
+		}
+		logging.ErrorPersist(event.Error.Error())
 		return event.Error
 	case provider.EventWarning:
-		c.App.Logger.PersistWarn(event.Info)
+		logging.WarnPersist(event.Info)
 		return nil
 	case provider.EventInfo:
-		c.App.Logger.PersistInfo(event.Info)
+		logging.InfoPersist(event.Info)
 	case provider.EventComplete:
 		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 		assistantMsg.AddFinish(event.Response.FinishReason)
@@ -115,12 +119,37 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
 	var wg sync.WaitGroup
 	toolResults := make([]message.ToolResult, len(toolCalls))
 	mutex := &sync.Mutex{}
+	errChan := make(chan error, 1)
+
+	// Create a child context that can be canceled
+	ctx, cancel := context.WithCancel(ctx)
+	defer cancel()
 
 	for i, tc := range toolCalls {
 		wg.Add(1)
 		go func(index int, toolCall message.ToolCall) {
 			defer wg.Done()
 
+			// Check if context is already canceled
+			select {
+			case <-ctx.Done():
+				mutex.Lock()
+				toolResults[index] = message.ToolResult{
+					ToolCallID: toolCall.ID,
+					Content:    "Tool execution canceled",
+					IsError:    true,
+				}
+				mutex.Unlock()
+
+				// Send cancellation error to error channel if it's empty
+				select {
+				case errChan <- ctx.Err():
+				default:
+				}
+				return
+			default:
+			}
+
 			response := ""
 			isError := false
 			found := false
@@ -133,8 +162,19 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
 						Name:  toolCall.Name,
 						Input: toolCall.Input,
 					})
+
 					if toolErr != nil {
-						response = fmt.Sprintf("error running tool: %s", toolErr)
+						if errors.Is(toolErr, context.Canceled) {
+							response = "Tool execution canceled"
+
+							// Send cancellation error to error channel if it's empty
+							select {
+							case errChan <- ctx.Err():
+							default:
+							}
+						} else {
+							response = fmt.Sprintf("error running tool: %s", toolErr)
+						}
 						isError = true
 					} else {
 						response = toolResult.Content
@@ -160,7 +200,24 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall,
 		}(i, tc)
 	}
 
-	wg.Wait()
+	// Wait for all goroutines to finish or context to be canceled
+	done := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		// All tools completed successfully
+	case err := <-errChan:
+		// One of the tools encountered a cancellation
+		return toolResults, err
+	case <-ctx.Done():
+		// Context was canceled externally
+		return toolResults, ctx.Err()
+	}
+
 	return toolResults, nil
 }
 
@@ -188,14 +245,14 @@ func (c *agent) handleToolExecution(
 	return &msg, err
 }
 
-func (c *agent) generate(sessionID string, content string) error {
+func (c *agent) generate(ctx context.Context, sessionID string, content string) error {
 	messages, err := c.Messages.List(sessionID)
 	if err != nil {
 		return err
 	}
 
 	if len(messages) == 0 {
-		go c.handleTitleGeneration(sessionID, content)
+		go c.handleTitleGeneration(ctx, sessionID, content)
 	}
 
 	userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
@@ -212,9 +269,36 @@ func (c *agent) generate(sessionID string, content string) error {
 
 	messages = append(messages, userMsg)
 	for {
+		select {
+		case <-ctx.Done():
+			assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
+				Role:  message.Assistant,
+				Parts: []message.ContentPart{},
+			})
+			if err != nil {
+				return err
+			}
+			assistantMsg.AddFinish("canceled")
+			c.Messages.Update(assistantMsg)
+			return context.Canceled
+		default:
+			// Continue processing
+		}
 
-		eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools)
+		eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools)
 		if err != nil {
+			if errors.Is(err, context.Canceled) {
+				assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{
+					Role:  message.Assistant,
+					Parts: []message.ContentPart{},
+				})
+				if err != nil {
+					return err
+				}
+				assistantMsg.AddFinish("canceled")
+				c.Messages.Update(assistantMsg)
+				return context.Canceled
+			}
 			return err
 		}
 
@@ -228,19 +312,47 @@ func (c *agent) generate(sessionID string, content string) error {
 		for event := range eventChan {
 			err = c.processEvent(sessionID, &assistantMsg, event)
 			if err != nil {
+				if errors.Is(err, context.Canceled) {
+					assistantMsg.AddFinish("canceled")
+					c.Messages.Update(assistantMsg)
+					return context.Canceled
+				}
 				assistantMsg.AddFinish("error:" + err.Error())
 				c.Messages.Update(assistantMsg)
 				return err
 			}
+
+			select {
+			case <-ctx.Done():
+				assistantMsg.AddFinish("canceled")
+				c.Messages.Update(assistantMsg)
+				return context.Canceled
+			default:
+			}
 		}
 
-		msg, err := c.handleToolExecution(c.Context, assistantMsg)
+		// Check for context cancellation before tool execution
+		select {
+		case <-ctx.Done():
+			assistantMsg.AddFinish("canceled")
+			c.Messages.Update(assistantMsg)
+			return context.Canceled
+		default:
+			// Continue processing
+		}
 
-		c.Messages.Update(assistantMsg)
+		msg, err := c.handleToolExecution(ctx, assistantMsg)
 		if err != nil {
+			if errors.Is(err, context.Canceled) {
+				assistantMsg.AddFinish("canceled")
+				c.Messages.Update(assistantMsg)
+				return context.Canceled
+			}
 			return err
 		}
 
+		c.Messages.Update(assistantMsg)
+
 		if len(assistantMsg.ToolCalls()) == 0 {
 			break
 		}
@@ -249,6 +361,16 @@ func (c *agent) generate(sessionID string, content string) error {
 		if msg != nil {
 			messages = append(messages, *msg)
 		}
+
+		// Check for context cancellation after tool execution
+		select {
+		case <-ctx.Done():
+			assistantMsg.AddFinish("canceled")
+			c.Messages.Update(assistantMsg)
+			return context.Canceled
+		default:
+			// Continue processing
+		}
 	}
 	return nil
 }

internal/llm/agent/coder.go 🔗

@@ -1,6 +1,7 @@
 package agent
 
 import (
+	"context"
 	"errors"
 
 	"github.com/kujtimiihoxha/termai/internal/app"
@@ -28,9 +29,9 @@ func (c *coderAgent) setAgentTool(sessionID string) {
 	}
 }
 
-func (c *coderAgent) Generate(sessionID string, content string) error {
+func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error {
 	c.setAgentTool(sessionID)
-	return c.generate(sessionID, content)
+	return c.generate(ctx, sessionID, content)
 }
 
 func NewCoderAgent(app *app.App) (Agent, error) {

internal/llm/agent/mcp-tools.go 🔗

@@ -22,8 +22,6 @@ type mcpTool struct {
 	permissions permission.Service
 }
 
-var logger = logging.Get()
-
 type MCPClient interface {
 	Initialize(
 		ctx context.Context,
@@ -143,13 +141,13 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions
 
 	_, err := c.Initialize(ctx, initRequest)
 	if err != nil {
-		logger.Error("error initializing mcp client", "error", err)
+		logging.Error("error initializing mcp client", "error", err)
 		return stdioTools
 	}
 	toolsRequest := mcp.ListToolsRequest{}
 	tools, err := c.ListTools(ctx, toolsRequest)
 	if err != nil {
-		logger.Error("error listing tools", "error", err)
+		logging.Error("error listing tools", "error", err)
 		return stdioTools
 	}
 	for _, t := range tools.Tools {
@@ -172,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
 				m.Args...,
 			)
 			if err != nil {
-				logger.Error("error creating mcp client", "error", err)
+				logging.Error("error creating mcp client", "error", err)
 				continue
 			}
 
@@ -183,7 +181,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba
 				client.WithHeaders(m.Headers),
 			)
 			if err != nil {
-				logger.Error("error creating mcp client", "error", err)
+				logging.Error("error creating mcp client", "error", err)
 				continue
 			}
 			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)

internal/llm/agent/task.go 🔗

@@ -1,6 +1,7 @@
 package agent
 
 import (
+	"context"
 	"errors"
 
 	"github.com/kujtimiihoxha/termai/internal/app"
@@ -13,8 +14,8 @@ type taskAgent struct {
 	*agent
 }
 
-func (c *taskAgent) Generate(sessionID string, content string) error {
-	return c.generate(sessionID, content)
+func (c *taskAgent) Generate(ctx context.Context, sessionID string, content string) error {
+	return c.generate(ctx, sessionID, content)
 }
 
 func NewTaskAgent(app *app.App) (Agent, error) {

internal/logging/default.go 🔗

@@ -1,12 +0,0 @@
-package logging
-
-var defaultLogger Interface
-
-func Get() Interface {
-	if defaultLogger == nil {
-		defaultLogger = NewLogger(Options{
-			Level: "info",
-		})
-	}
-	return defaultLogger
-}

internal/logging/logger.go 🔗

@@ -1,141 +1,39 @@
 package logging
 
-import (
-	"context"
-	"io"
-	"log/slog"
-	"slices"
+import "log/slog"
 
-	"github.com/kujtimiihoxha/termai/internal/pubsub"
-	"golang.org/x/exp/maps"
-)
-
-const DefaultLevel = "info"
-
-const (
-	persistKeyArg  = "$persist"
-	PersistTimeArg = "$persist_time"
-)
-
-var levels = map[string]slog.Level{
-	"debug":      slog.LevelDebug,
-	DefaultLevel: slog.LevelInfo,
-	"warn":       slog.LevelWarn,
-	"error":      slog.LevelError,
-}
-
-func ValidLevels() []string {
-	keys := maps.Keys(levels)
-	slices.SortFunc(keys, func(a, b string) int {
-		if a == DefaultLevel {
-			return -1
-		}
-		if b == DefaultLevel {
-			return 1
-		}
-		if a < b {
-			return -1
-		}
-		return 1
-	})
-	return keys
-}
-
-func NewLogger(opts Options) Interface {
-	logger := &Logger{}
-	broker := pubsub.NewBroker[LogMessage]()
-	writer := &writer{
-		messages: []LogMessage{},
-		Broker:   broker,
-	}
-
-	handler := slog.NewTextHandler(
-		io.MultiWriter(writer),
-		&slog.HandlerOptions{
-			Level: slog.Level(levels[opts.Level]),
-		},
-	)
-	logger.logger = slog.New(handler)
-	logger.writer = writer
-
-	return logger
+func Info(msg string, args ...any) {
+	slog.Info(msg, args...)
 }
 
-type Options struct {
-	Level string
+func Debug(msg string, args ...any) {
+	slog.Debug(msg, args...)
 }
 
-type Logger struct {
-	logger *slog.Logger
-	writer *writer
+func Warn(msg string, args ...any) {
+	slog.Warn(msg, args...)
 }
 
-func (l *Logger) SetLevel(level string) {
-	if _, ok := levels[level]; !ok {
-		level = DefaultLevel
-	}
-	handler := slog.NewTextHandler(
-		io.MultiWriter(l.writer),
-		&slog.HandlerOptions{
-			Level: levels[level],
-		},
-	)
-	l.logger = slog.New(handler)
+func Error(msg string, args ...any) {
+	slog.Error(msg, args...)
 }
 
-// PersistDebug implements Interface.
-func (l *Logger) PersistDebug(msg string, args ...any) {
+func InfoPersist(msg string, args ...any) {
 	args = append(args, persistKeyArg, true)
-	l.Debug(msg, args...)
+	slog.Info(msg, args...)
 }
 
-// PersistError implements Interface.
-func (l *Logger) PersistError(msg string, args ...any) {
+func DebugPersist(msg string, args ...any) {
 	args = append(args, persistKeyArg, true)
-	l.Error(msg, args...)
+	slog.Debug(msg, args...)
 }
 
-// PersistInfo implements Interface.
-func (l *Logger) PersistInfo(msg string, args ...any) {
+func WarnPersist(msg string, args ...any) {
 	args = append(args, persistKeyArg, true)
-	l.Info(msg, args...)
+	slog.Warn(msg, args...)
 }
 
-// PersistWarn implements Interface.
-func (l *Logger) PersistWarn(msg string, args ...any) {
+func ErrorPersist(msg string, args ...any) {
 	args = append(args, persistKeyArg, true)
-	l.Warn(msg, args...)
-}
-
-func (l *Logger) Debug(msg string, args ...any) {
-	l.logger.Debug(msg, args...)
-}
-
-func (l *Logger) Info(msg string, args ...any) {
-	l.logger.Info(msg, args...)
-}
-
-func (l *Logger) Warn(msg string, args ...any) {
-	l.logger.Warn(msg, args...)
-}
-
-func (l *Logger) Error(msg string, args ...any) {
-	l.logger.Error(msg, args...)
-}
-
-func (l *Logger) List() []LogMessage {
-	return l.writer.messages
-}
-
-func (l *Logger) Get(id string) (LogMessage, error) {
-	for _, msg := range l.writer.messages {
-		if msg.ID == id {
-			return msg, nil
-		}
-	}
-	return LogMessage{}, io.EOF
-}
-
-func (l *Logger) Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] {
-	return l.writer.Subscribe(ctx)
+	slog.Error(msg, args...)
 }

internal/logging/logging.go 🔗

@@ -1,23 +0,0 @@
-package logging
-
-import (
-	"context"
-
-	"github.com/kujtimiihoxha/termai/internal/pubsub"
-)
-
-type Interface interface {
-	Debug(msg string, args ...any)
-	Info(msg string, args ...any)
-	Warn(msg string, args ...any)
-	Error(msg string, args ...any)
-	Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage]
-
-	PersistDebug(msg string, args ...any)
-	PersistInfo(msg string, args ...any)
-	PersistWarn(msg string, args ...any)
-	PersistError(msg string, args ...any)
-	List() []LogMessage
-
-	SetLevel(level string)
-}

internal/logging/writer.go 🔗

@@ -2,18 +2,47 @@ package logging
 
 import (
 	"bytes"
+	"context"
 	"fmt"
+	"strings"
+	"sync"
 	"time"
 
 	"github.com/go-logfmt/logfmt"
 	"github.com/kujtimiihoxha/termai/internal/pubsub"
 )
 
-type writer struct {
+const (
+	persistKeyArg  = "$_persist"
+	PersistTimeArg = "$_persist_time"
+)
+
+type LogData struct {
 	messages []LogMessage
 	*pubsub.Broker[LogMessage]
+	lock sync.Mutex
+}
+
+func (l *LogData) Add(msg LogMessage) {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+	l.messages = append(l.messages, msg)
+	l.Publish(pubsub.CreatedEvent, msg)
+}
+
+func (l *LogData) List() []LogMessage {
+	l.lock.Lock()
+	defer l.lock.Unlock()
+	return l.messages
+}
+
+var defaultLogData = &LogData{
+	messages: make([]LogMessage, 0),
+	Broker:   pubsub.NewBroker[LogMessage](),
 }
 
+type writer struct{}
+
 func (w *writer) Write(p []byte) (int, error) {
 	d := logfmt.NewDecoder(bytes.NewReader(p))
 	for d.ScanRecord() {
@@ -30,7 +59,7 @@ func (w *writer) Write(p []byte) (int, error) {
 				}
 				msg.Time = parsed
 			case "level":
-				msg.Level = string(d.Value())
+				msg.Level = strings.ToLower(string(d.Value()))
 			case "msg":
 				msg.Message = string(d.Value())
 			default:
@@ -50,11 +79,23 @@ func (w *writer) Write(p []byte) (int, error) {
 				}
 			}
 		}
-		w.messages = append(w.messages, msg)
-		w.Publish(pubsub.CreatedEvent, msg)
+		defaultLogData.Add(msg)
 	}
 	if d.Err() != nil {
 		return 0, d.Err()
 	}
 	return len(p), nil
 }
+
+func NewWriter() *writer {
+	w := &writer{}
+	return w
+}
+
+func Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] {
+	return defaultLogData.Subscribe(ctx)
+}
+
+func List() []LogMessage {
+	return defaultLogData.List()
+}

internal/lsp/client.go 🔗

@@ -18,8 +18,6 @@ import (
 	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 )
 
-var logger = logging.Get()
-
 type Client struct {
 	Cmd    *exec.Cmd
 	stdin  io.WriteCloser
@@ -377,7 +375,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error {
 	}
 
 	if cnf.Debug {
-		logger.Debug("Closing file", "file", filepath)
+		logging.Debug("Closing file", "file", filepath)
 	}
 	if err := c.Notify(ctx, "textDocument/didClose", params); err != nil {
 		return err
@@ -416,12 +414,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
 	for _, filePath := range filesToClose {
 		err := c.CloseFile(ctx, filePath)
 		if err != nil && cnf.Debug {
-			logger.Warn("Error closing file", "file", filePath, "error", err)
+			logging.Warn("Error closing file", "file", filePath, "error", err)
 		}
 	}
 
 	if cnf.Debug {
-		logger.Debug("Closed all files", "files", filesToClose)
+		logging.Debug("Closed all files", "files", filesToClose)
 	}
 }
 

internal/lsp/handlers.go 🔗

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 
 	"github.com/kujtimiihoxha/termai/internal/config"
+	"github.com/kujtimiihoxha/termai/internal/logging"
 	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 	"github.com/kujtimiihoxha/termai/internal/lsp/util"
 )
@@ -17,7 +18,7 @@ func HandleWorkspaceConfiguration(params json.RawMessage) (any, error) {
 func HandleRegisterCapability(params json.RawMessage) (any, error) {
 	var registerParams protocol.RegistrationParams
 	if err := json.Unmarshal(params, &registerParams); err != nil {
-		logger.Error("Error unmarshaling registration params", "error", err)
+		logging.Error("Error unmarshaling registration params", "error", err)
 		return nil, err
 	}
 
@@ -27,13 +28,13 @@ func HandleRegisterCapability(params json.RawMessage) (any, error) {
 			// Parse the registration options
 			optionsJSON, err := json.Marshal(reg.RegisterOptions)
 			if err != nil {
-				logger.Error("Error marshaling registration options", "error", err)
+				logging.Error("Error marshaling registration options", "error", err)
 				continue
 			}
 
 			var options protocol.DidChangeWatchedFilesRegistrationOptions
 			if err := json.Unmarshal(optionsJSON, &options); err != nil {
-				logger.Error("Error unmarshaling registration options", "error", err)
+				logging.Error("Error unmarshaling registration options", "error", err)
 				continue
 			}
 
@@ -53,7 +54,7 @@ func HandleApplyEdit(params json.RawMessage) (any, error) {
 
 	err := util.ApplyWorkspaceEdit(edit.Edit)
 	if err != nil {
-		logger.Error("Error applying workspace edit", "error", err)
+		logging.Error("Error applying workspace edit", "error", err)
 		return protocol.ApplyWorkspaceEditResult{Applied: false, FailureReason: err.Error()}, nil
 	}
 
@@ -88,7 +89,7 @@ func HandleServerMessage(params json.RawMessage) {
 	}
 	if err := json.Unmarshal(params, &msg); err == nil {
 		if cnf.Debug {
-			logger.Debug("Server message", "type", msg.Type, "message", msg.Message)
+			logging.Debug("Server message", "type", msg.Type, "message", msg.Message)
 		}
 	}
 }
@@ -96,7 +97,7 @@ func HandleServerMessage(params json.RawMessage) {
 func HandleDiagnostics(client *Client, params json.RawMessage) {
 	var diagParams protocol.PublishDiagnosticsParams
 	if err := json.Unmarshal(params, &diagParams); err != nil {
-		logger.Error("Error unmarshaling diagnostics params", "error", err)
+		logging.Error("Error unmarshaling diagnostics params", "error", err)
 		return
 	}
 

internal/lsp/transport.go 🔗

@@ -9,6 +9,7 @@ import (
 	"strings"
 
 	"github.com/kujtimiihoxha/termai/internal/config"
+	"github.com/kujtimiihoxha/termai/internal/logging"
 )
 
 // Write writes an LSP message to the given writer
@@ -20,7 +21,7 @@ func WriteMessage(w io.Writer, msg *Message) error {
 	cnf := config.Get()
 
 	if cnf.Debug {
-		logger.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
+		logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID)
 	}
 
 	_, err = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(data))
@@ -49,7 +50,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
 		line = strings.TrimSpace(line)
 
 		if cnf.Debug {
-			logger.Debug("Received header", "line", line)
+			logging.Debug("Received header", "line", line)
 		}
 
 		if line == "" {
@@ -65,7 +66,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
 	}
 
 	if cnf.Debug {
-		logger.Debug("Content-Length", "length", contentLength)
+		logging.Debug("Content-Length", "length", contentLength)
 	}
 
 	// Read content
@@ -76,7 +77,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) {
 	}
 
 	if cnf.Debug {
-		logger.Debug("Received content", "content", string(content))
+		logging.Debug("Received content", "content", string(content))
 	}
 
 	// Parse message
@@ -95,7 +96,7 @@ func (c *Client) handleMessages() {
 		msg, err := ReadMessage(c.stdout)
 		if err != nil {
 			if cnf.Debug {
-				logger.Error("Error reading message", "error", err)
+				logging.Error("Error reading message", "error", err)
 			}
 			return
 		}
@@ -103,7 +104,7 @@ func (c *Client) handleMessages() {
 		// Handle server->client request (has both Method and ID)
 		if msg.Method != "" && msg.ID != 0 {
 			if cnf.Debug {
-				logger.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
+				logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID)
 			}
 
 			response := &Message{
@@ -143,7 +144,7 @@ func (c *Client) handleMessages() {
 
 			// Send response back to server
 			if err := WriteMessage(c.stdin, response); err != nil {
-				logger.Error("Error sending response to server", "error", err)
+				logging.Error("Error sending response to server", "error", err)
 			}
 
 			continue
@@ -157,11 +158,11 @@ func (c *Client) handleMessages() {
 
 			if ok {
 				if cnf.Debug {
-					logger.Debug("Handling notification", "method", msg.Method)
+					logging.Debug("Handling notification", "method", msg.Method)
 				}
 				go handler(msg.Params)
 			} else if cnf.Debug {
-				logger.Debug("No handler for notification", "method", msg.Method)
+				logging.Debug("No handler for notification", "method", msg.Method)
 			}
 			continue
 		}
@@ -174,12 +175,12 @@ func (c *Client) handleMessages() {
 
 			if ok {
 				if cnf.Debug {
-					logger.Debug("Received response for request", "id", msg.ID)
+					logging.Debug("Received response for request", "id", msg.ID)
 				}
 				ch <- msg
 				close(ch)
 			} else if cnf.Debug {
-				logger.Debug("No handler for response", "id", msg.ID)
+				logging.Debug("No handler for response", "id", msg.ID)
 			}
 		}
 	}
@@ -191,7 +192,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
 	id := c.nextID.Add(1)
 
 	if cnf.Debug {
-		logger.Debug("Making call", "method", method, "id", id)
+		logging.Debug("Making call", "method", method, "id", id)
 	}
 
 	msg, err := NewRequest(id, method, params)
@@ -217,14 +218,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
 	}
 
 	if cnf.Debug {
-		logger.Debug("Request sent", "method", method, "id", id)
+		logging.Debug("Request sent", "method", method, "id", id)
 	}
 
 	// Wait for response
 	resp := <-ch
 
 	if cnf.Debug {
-		logger.Debug("Received response", "id", id)
+		logging.Debug("Received response", "id", id)
 	}
 
 	if resp.Error != nil {
@@ -250,7 +251,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any
 func (c *Client) Notify(ctx context.Context, method string, params any) error {
 	cnf := config.Get()
 	if cnf.Debug {
-		logger.Debug("Sending notification", "method", method)
+		logging.Debug("Sending notification", "method", method)
 	}
 
 	msg, err := NewNotification(method, params)

internal/lsp/watcher/watcher.go 🔗

@@ -16,8 +16,6 @@ import (
 	"github.com/kujtimiihoxha/termai/internal/lsp/protocol"
 )
 
-var logger = logging.Get()
-
 // WorkspaceWatcher manages LSP file watching
 type WorkspaceWatcher struct {
 	client        *lsp.Client
@@ -53,7 +51,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 
 	// Print detailed registration information for debugging
 	if cnf.Debug {
-		logger.Debug("Adding file watcher registrations",
+		logging.Debug("Adding file watcher registrations",
 			"id", id,
 			"watchers", len(watchers),
 			"total", len(w.registrations),
@@ -61,26 +59,26 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 		)
 
 		for i, watcher := range watchers {
-			logger.Debug("Registration", "index", i+1)
+			logging.Debug("Registration", "index", i+1)
 
 			// Log the GlobPattern
 			switch v := watcher.GlobPattern.Value.(type) {
 			case string:
-				logger.Debug("GlobPattern", "pattern", v)
+				logging.Debug("GlobPattern", "pattern", v)
 			case protocol.RelativePattern:
-				logger.Debug("GlobPattern", "pattern", v.Pattern)
+				logging.Debug("GlobPattern", "pattern", v.Pattern)
 
 				// Log BaseURI details
 				switch u := v.BaseURI.Value.(type) {
 				case string:
-					logger.Debug("BaseURI", "baseURI", u)
+					logging.Debug("BaseURI", "baseURI", u)
 				case protocol.DocumentUri:
-					logger.Debug("BaseURI", "baseURI", u)
+					logging.Debug("BaseURI", "baseURI", u)
 				default:
-					logger.Debug("BaseURI", "baseURI", u)
+					logging.Debug("BaseURI", "baseURI", u)
 				}
 			default:
-				logger.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
+				logging.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v))
 			}
 
 			// Log WatchKind
@@ -89,7 +87,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 				watchKind = *watcher.Kind
 			}
 
-			logger.Debug("WatchKind", "kind", watchKind)
+			logging.Debug("WatchKind", "kind", watchKind)
 
 			// Test match against some example paths
 			testPaths := []string{
@@ -99,7 +97,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 
 			for _, testPath := range testPaths {
 				isMatch := w.matchesPattern(testPath, watcher.GlobPattern)
-				logger.Debug("Test path", "path", testPath, "matches", isMatch)
+				logging.Debug("Test path", "path", testPath, "matches", isMatch)
 			}
 		}
 	}
@@ -119,7 +117,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 			if d.IsDir() {
 				if path != w.workspacePath && shouldExcludeDir(path) {
 					if cnf.Debug {
-						logger.Debug("Skipping excluded directory", "path", path)
+						logging.Debug("Skipping excluded directory", "path", path)
 					}
 					return filepath.SkipDir
 				}
@@ -139,7 +137,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 
 		elapsedTime := time.Since(startTime)
 		if cnf.Debug {
-			logger.Debug("Workspace scan complete",
+			logging.Debug("Workspace scan complete",
 				"filesOpened", filesOpened,
 				"elapsedTime", elapsedTime.Seconds(),
 				"workspacePath", w.workspacePath,
@@ -147,7 +145,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc
 		}
 
 		if err != nil && cnf.Debug {
-			logger.Debug("Error scanning workspace for files to open", "error", err)
+			logging.Debug("Error scanning workspace for files to open", "error", err)
 		}
 	}()
 }
@@ -164,7 +162,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 
 	watcher, err := fsnotify.NewWatcher()
 	if err != nil {
-		logger.Error("Error creating watcher", "error", err)
+		logging.Error("Error creating watcher", "error", err)
 	}
 	defer watcher.Close()
 
@@ -178,7 +176,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 		if d.IsDir() && path != workspacePath {
 			if shouldExcludeDir(path) {
 				if cnf.Debug {
-					logger.Debug("Skipping excluded directory", "path", path)
+					logging.Debug("Skipping excluded directory", "path", path)
 				}
 				return filepath.SkipDir
 			}
@@ -188,14 +186,14 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 		if d.IsDir() {
 			err = watcher.Add(path)
 			if err != nil {
-				logger.Error("Error watching path", "path", path, "error", err)
+				logging.Error("Error watching path", "path", path, "error", err)
 			}
 		}
 
 		return nil
 	})
 	if err != nil {
-		logger.Error("Error walking workspace", "error", err)
+		logging.Error("Error walking workspace", "error", err)
 	}
 
 	// Event loop
@@ -217,7 +215,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 						// Skip excluded directories
 						if !shouldExcludeDir(event.Name) {
 							if err := watcher.Add(event.Name); err != nil {
-								logger.Error("Error adding directory to watcher", "path", event.Name, "error", err)
+								logging.Error("Error adding directory to watcher", "path", event.Name, "error", err)
 							}
 						}
 					} else {
@@ -232,7 +230,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 			// Debug logging
 			if cnf.Debug {
 				matched, kind := w.isPathWatched(event.Name)
-				logger.Debug("File event",
+				logging.Debug("File event",
 					"path", event.Name,
 					"operation", event.Op.String(),
 					"watched", matched,
@@ -277,7 +275,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str
 			if !ok {
 				return
 			}
-			logger.Error("Error watching file", "error", err)
+			logging.Error("Error watching file", "error", err)
 		}
 	}
 }
@@ -402,7 +400,7 @@ func matchesSimpleGlob(pattern, path string) bool {
 	// Fall back to simple matching for simpler patterns
 	matched, err := filepath.Match(pattern, path)
 	if err != nil {
-		logger.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err)
+		logging.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err)
 		return false
 	}
 
@@ -413,7 +411,7 @@ func matchesSimpleGlob(pattern, path string) bool {
 func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPattern) bool {
 	patternInfo, err := pattern.AsPattern()
 	if err != nil {
-		logger.Error("Error parsing pattern", "pattern", pattern, "error", err)
+		logging.Error("Error parsing pattern", "pattern", pattern, "error", err)
 		return false
 	}
 
@@ -438,7 +436,7 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt
 	// Make path relative to basePath for matching
 	relPath, err := filepath.Rel(basePath, path)
 	if err != nil {
-		logger.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err)
+		logging.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err)
 		return false
 	}
 	relPath = filepath.ToSlash(relPath)
@@ -479,14 +477,14 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
 	if changeType == protocol.FileChangeType(protocol.Changed) && w.client.IsFileOpen(filePath) {
 		err := w.client.NotifyChange(ctx, filePath)
 		if err != nil {
-			logger.Error("Error notifying change", "error", err)
+			logging.Error("Error notifying change", "error", err)
 		}
 		return
 	}
 
 	// Notify LSP server about the file event using didChangeWatchedFiles
 	if err := w.notifyFileEvent(ctx, uri, changeType); err != nil {
-		logger.Error("Error notifying LSP server about file event", "error", err)
+		logging.Error("Error notifying LSP server about file event", "error", err)
 	}
 }
 
@@ -494,7 +492,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan
 func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error {
 	cnf := config.Get()
 	if cnf.Debug {
-		logger.Debug("Notifying file event",
+		logging.Debug("Notifying file event",
 			"uri", uri,
 			"changeType", changeType,
 		)
@@ -618,7 +616,7 @@ func shouldExcludeFile(filePath string) bool {
 	// Skip large files
 	if info.Size() > maxFileSize {
 		if cnf.Debug {
-			logger.Debug("Skipping large file",
+			logging.Debug("Skipping large file",
 				"path", filePath,
 				"size", info.Size(),
 				"maxSize", maxFileSize,
@@ -651,7 +649,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) {
 	if watched, _ := w.isPathWatched(path); watched {
 		// Don't need to check if it's already open - the client.OpenFile handles that
 		if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug {
-			logger.Error("Error opening file", "path", path, "error", err)
+			logging.Error("Error opening file", "path", path, "error", err)
 		}
 	}
 }

internal/tui/components/core/status.go 🔗

@@ -13,7 +13,7 @@ import (
 )
 
 type statusCmp struct {
-	info       *util.InfoMsg
+	info       util.InfoMsg
 	width      int
 	messageTTL time.Duration
 }
@@ -35,14 +35,14 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		m.width = msg.Width
 		return m, nil
 	case util.InfoMsg:
-		m.info = &msg
+		m.info = msg
 		ttl := msg.TTL
 		if ttl == 0 {
 			ttl = m.messageTTL
 		}
 		return m, m.clearMessageCmd(ttl)
 	case util.ClearStatusMsg:
-		m.info = nil
+		m.info = util.InfoMsg{}
 	}
 	return m, nil
 }
@@ -54,7 +54,7 @@ var (
 
 func (m statusCmp) View() string {
 	status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help")
-	if m.info != nil {
+	if m.info.Msg != "" {
 		infoStyle := styles.Padded.
 			Foreground(styles.Base).
 			Width(m.availableFooterMsgWidth())

internal/tui/components/logs/details.go 🔗

@@ -30,7 +30,7 @@ type detailCmp struct {
 }
 
 func (i *detailCmp) Init() tea.Cmd {
-	messages := logging.Get().List()
+	messages := logging.List()
 	if len(messages) == 0 {
 		return nil
 	}

internal/tui/components/logs/table.go 🔗

@@ -22,8 +22,6 @@ type TableComponent interface {
 	layout.Bordered
 }
 
-var logger = logging.Get()
-
 type tableCmp struct {
 	table table.Model
 }
@@ -57,7 +55,7 @@ func (i *tableCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		if selectedRow != nil {
 			if prevSelectedRow == nil || selectedRow[0] == prevSelectedRow[0] {
 				var log logging.LogMessage
-				for _, row := range logging.Get().List() {
+				for _, row := range logging.List() {
 					if row.ID == selectedRow[0] {
 						log = row
 						break
@@ -112,7 +110,7 @@ func (i *tableCmp) BindingKeys() []key.Binding {
 func (i *tableCmp) setRows() {
 	rows := []table.Row{}
 
-	logs := logger.List()
+	logs := logging.List()
 	slices.SortFunc(logs, func(a, b logging.LogMessage) int {
 		if a.Time.Before(b.Time) {
 			return 1

internal/tui/components/repl/editor.go 🔗

@@ -12,6 +12,7 @@ import (
 	"github.com/kujtimiihoxha/termai/internal/tui/styles"
 	"github.com/kujtimiihoxha/termai/internal/tui/util"
 	"github.com/kujtimiihoxha/vimtea"
+	"golang.org/x/net/context"
 )
 
 type EditorCmp interface {
@@ -23,18 +24,20 @@ type EditorCmp interface {
 }
 
 type editorCmp struct {
-	app        *app.App
-	editor     vimtea.Editor
-	editorMode vimtea.EditorMode
-	sessionID  string
-	focused    bool
-	width      int
-	height     int
+	app           *app.App
+	editor        vimtea.Editor
+	editorMode    vimtea.EditorMode
+	sessionID     string
+	focused       bool
+	width         int
+	height        int
+	cancelMessage context.CancelFunc
 }
 
 type editorKeyMap struct {
 	SendMessage    key.Binding
 	SendMessageI   key.Binding
+	CancelMessage  key.Binding
 	InsertMode     key.Binding
 	NormaMode      key.Binding
 	VisualMode     key.Binding
@@ -50,6 +53,10 @@ var editorKeyMapValue = editorKeyMap{
 		key.WithKeys("ctrl+s"),
 		key.WithHelp("ctrl+s", "send message insert mode"),
 	),
+	CancelMessage: key.NewBinding(
+		key.WithKeys("ctrl+x"),
+		key.WithHelp("ctrl+x", "cancel current message"),
+	),
 	InsertMode: key.NewBinding(
 		key.WithKeys("i"),
 		key.WithHelp("i", "insert mode"),
@@ -93,6 +100,8 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 				if m.editorMode == vimtea.ModeInsert {
 					return m, m.Send()
 				}
+			case key.Matches(msg, editorKeyMapValue.CancelMessage):
+				return m, m.Cancel()
 			}
 		}
 		u, cmd := m.editor.Update(msg)
@@ -136,6 +145,16 @@ func (m *editorCmp) SetSize(width int, height int) {
 	m.editor.SetSize(width, height)
 }
 
+func (m *editorCmp) Cancel() tea.Cmd {
+	if m.cancelMessage == nil {
+		return util.ReportWarn("No message to cancel")
+	}
+
+	m.cancelMessage()
+	m.cancelMessage = nil
+	return util.ReportWarn("Message cancelled")
+}
+
 func (m *editorCmp) Send() tea.Cmd {
 	return func() tea.Msg {
 		messages, err := m.app.Messages.List(m.sessionID)
@@ -151,7 +170,13 @@ func (m *editorCmp) Send() tea.Cmd {
 		}
 
 		content := strings.Join(m.editor.GetBuffer().Lines(), "\n")
-		go a.Generate(m.sessionID, content)
+		ctx, cancel := context.WithCancel(m.app.Context)
+		m.cancelMessage = cancel
+		go func() {
+			defer cancel()
+			a.Generate(ctx, m.sessionID, content)
+			m.cancelMessage = nil
+		}()
 
 		return m.editor.Reset()
 	}

internal/tui/components/repl/messages.go 🔗

@@ -309,7 +309,7 @@ func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message.
 	}
 
 	for _, msg := range futureMessages {
-		if msg.Content().String() != "" {
+		if msg.Content().String() != "" || msg.FinishReason() == "canceled" {
 			break
 		}
 
@@ -345,13 +345,18 @@ func (m *messagesCmp) renderView() {
 	prevMessageWasUser := false
 	for inx, msg := range m.messages {
 		content := msg.Content().String()
-		if content != "" || prevMessageWasUser {
+		if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" {
 			if msg.ReasoningContent().String() != "" && content == "" {
 				content = msg.ReasoningContent().String()
 			} else if content == "" {
 				content = "..."
 			}
-			content, _ = r.Render(content)
+			if msg.FinishReason() == "canceled" {
+				content, _ = r.Render(content)
+				content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled")
+			} else {
+				content, _ = r.Render(content)
+			}
 
 			isSelected := inx == m.selectedMsgIdx
 

internal/tui/tui.go 🔗

@@ -101,7 +101,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	// Status
 	case util.InfoMsg:
 		a.status, cmd = a.status.Update(msg)
-		return a, cmd
+		cmds = append(cmds, cmd)
+		return a, tea.Batch(cmds...)
 	case pubsub.Event[logging.LogMessage]:
 		if msg.Payload.Persist {
 			switch msg.Payload.Level {