diff --git a/cmd/root.go b/cmd/root.go index c4e99985ac8e77f5b6eefac215181201c1508324..bd9336fcef7a090894194e146bbc426b4763816f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "log/slog" "os" "sync" @@ -10,6 +11,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" "github.com/kujtimiihoxha/termai/internal/llm/agent" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/tui" zone "github.com/lrstanley/bubblezone" "github.com/spf13/cobra" @@ -26,6 +28,16 @@ var rootCmd = &cobra.Command{ } debug, _ := cmd.Flags().GetBool("debug") err := config.Load(debug) + cfg := config.Get() + defaultLevel := slog.LevelInfo + if cfg.Debug { + defaultLevel = slog.LevelDebug + } + logger := slog.New(slog.NewTextHandler(logging.NewWriter(), &slog.HandlerOptions{ + Level: defaultLevel, + })) + slog.SetDefault(logger) + if err != nil { return err } @@ -37,14 +49,14 @@ var rootCmd = &cobra.Command{ app := app.New(ctx, conn) defer app.Close() - app.Logger.Info("Starting termai...") + logging.Info("Starting termai...") zone.NewGlobal() tui := tea.NewProgram( tui.New(app), tea.WithAltScreen(), tea.WithMouseCellMotion(), ) - app.Logger.Info("Setting up subscriptions...") + logging.Info("Setting up subscriptions...") ch, unsub := setupSubscriptions(app) defer unsub() @@ -66,9 +78,8 @@ func setupSubscriptions(app *app.App) (chan tea.Msg, func()) { ch := make(chan tea.Msg) wg := sync.WaitGroup{} ctx, cancel := context.WithCancel(app.Context) - { - sub := app.Logger.Subscribe(ctx) + sub := logging.Subscribe(ctx) wg.Add(1) go func() { for ev := range sub { diff --git a/go.mod b/go.mod index ab519a53d3afc44d865c26f9f2c825f3d385f062..63df37fba20eb74d1f2f24ccccb760acf9992bad 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.0 github.com/stretchr/testify v1.10.0 - golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 + golang.org/x/net v0.34.0 google.golang.org/api v0.215.0 ) @@ -116,10 +116,10 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.design/x/clipboard v0.7.0 // indirect golang.org/x/crypto v0.33.0 // indirect + golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/exp/shiny v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/image v0.14.0 // indirect golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a // indirect - golang.org/x/net v0.34.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.31.0 // indirect diff --git a/internal/app/services.go b/internal/app/services.go index d0beb0b5be6065ab8a048e8cd1a448419f9ef7a2..dcdfe12e00ba6b2745220f062f9886275d6ffc92 100644 --- a/internal/app/services.go +++ b/internal/app/services.go @@ -3,6 +3,7 @@ package app import ( "context" "database/sql" + "log/slog" "github.com/kujtimiihoxha/termai/internal/config" "github.com/kujtimiihoxha/termai/internal/db" @@ -23,16 +24,14 @@ type App struct { LSPClients map[string]*lsp.Client - Logger logging.Interface - ceanups []func() } func New(ctx context.Context, conn *sql.DB) *App { cfg := config.Get() + logging.Info("Debug mode enabled") + q := db.New(conn) - log := logging.Get() - log.SetLevel(cfg.Log.Level) sessions := session.NewService(ctx, q) messages := message.NewService(ctx, q) @@ -41,7 +40,6 @@ func New(ctx context.Context, conn *sql.DB) *App { Sessions: sessions, Messages: messages, Permissions: permission.NewPermissionService(), - Logger: log, LSPClients: make(map[string]*lsp.Client), } @@ -52,13 +50,13 @@ func New(ctx context.Context, conn *sql.DB) *App { }) workspaceWatcher := watcher.NewWorkspaceWatcher(lspClient) if err != nil { - log.Error("Failed to create LSP client for", name, err) + logging.Error("Failed to create LSP client for", name, err) continue } _, err = lspClient.InitializeLSPClient(ctx, config.WorkingDirectory()) if err != nil { - log.Error("Initialize failed", "error", err) + logging.Error("Initialize failed", "error", err) continue } go workspaceWatcher.WatchWorkspace(ctx, config.WorkingDirectory()) @@ -74,5 +72,5 @@ func (a *App) Close() { for _, client := range a.LSPClients { client.Close() } - a.Logger.Info("App closed") + slog.Info("App closed") } diff --git a/internal/db/connect.go b/internal/db/connect.go index aed04b986f3a5eb0187afe278c9eeb558c2d0c70..8bba9cad806c73327ffa016bd5851875403236e4 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -16,8 +16,6 @@ import ( "github.com/kujtimiihoxha/termai/internal/logging" ) -var log = logging.Get() - func Connect() (*sql.DB, error) { dataDir := config.Get().Data.Directory if dataDir == "" { @@ -50,43 +48,43 @@ func Connect() (*sql.DB, error) { for _, pragma := range pragmas { if _, err = db.Exec(pragma); err != nil { - log.Warn("Failed to set pragma", pragma, err) + logging.Warn("Failed to set pragma", pragma, err) } else { - log.Warn("Set pragma", "pragma", pragma) + logging.Warn("Set pragma", "pragma", pragma) } } // Initialize schema from embedded file d, err := iofs.New(FS, "migrations") if err != nil { - log.Error("Failed to open embedded migrations", "error", err) + logging.Error("Failed to open embedded migrations", "error", err) db.Close() return nil, fmt.Errorf("failed to open embedded migrations: %w", err) } driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { - log.Error("Failed to create SQLite driver", "error", err) + logging.Error("Failed to create SQLite driver", "error", err) db.Close() return nil, fmt.Errorf("failed to create SQLite driver: %w", err) } m, err := migrate.NewWithInstance("iofs", d, "ql", driver) if err != nil { - log.Error("Failed to create migration instance", "error", err) + logging.Error("Failed to create migration instance", "error", err) db.Close() return nil, fmt.Errorf("failed to create migration instance: %w", err) } err = m.Up() if err != nil && err != migrate.ErrNoChange { - log.Error("Migration failed", "error", err) + logging.Error("Migration failed", "error", err) db.Close() return nil, fmt.Errorf("failed to apply schema: %w", err) } else if err == migrate.ErrNoChange { - log.Info("No schema changes to apply") + logging.Info("No schema changes to apply") } else { - log.Info("Schema migration applied successfully") + logging.Info("Schema migration applied successfully") } return db, nil diff --git a/internal/llm/agent/agent-tool.go b/internal/llm/agent/agent-tool.go index bf5e31f8f18ae76a98ebaaa37a29922eeda423a9..deb6aed608625816b6f9715f33a823a99c337d47 100644 --- a/internal/llm/agent/agent-tool.go +++ b/internal/llm/agent/agent-tool.go @@ -56,7 +56,7 @@ func (b *agentTool) Run(ctx context.Context, call tools.ToolCall) (tools.ToolRes return tools.NewTextErrorResponse(fmt.Sprintf("error creating session: %s", err)), nil } - err = agent.Generate(session.ID, params.Prompt) + err = agent.Generate(ctx, session.ID, params.Prompt) if err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error generating agent: %s", err)), nil } diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 43ba0fc26c2631dac4fe5b0dc85839b85fd612e7..998dc1551e8adc3bbc5facfb80338803a8f0afb0 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -13,11 +13,12 @@ import ( "github.com/kujtimiihoxha/termai/internal/llm/prompt" "github.com/kujtimiihoxha/termai/internal/llm/provider" "github.com/kujtimiihoxha/termai/internal/llm/tools" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/message" ) type Agent interface { - Generate(sessionID string, content string) error + Generate(ctx context.Context, sessionID string, content string) error } type agent struct { @@ -28,9 +29,9 @@ type agent struct { titleGenerator provider.Provider } -func (c *agent) handleTitleGeneration(sessionID, content string) { +func (c *agent) handleTitleGeneration(ctx context.Context, sessionID, content string) { response, err := c.titleGenerator.SendMessages( - c.Context, + ctx, []message.Message{ { Role: message.User, @@ -91,13 +92,16 @@ func (c *agent) processEvent( assistantMsg.AppendContent(event.Content) return c.Messages.Update(*assistantMsg) case provider.EventError: - c.App.Logger.PersistError(event.Error.Error()) + if errors.Is(event.Error, context.Canceled) { + return nil + } + logging.ErrorPersist(event.Error.Error()) return event.Error case provider.EventWarning: - c.App.Logger.PersistWarn(event.Info) + logging.WarnPersist(event.Info) return nil case provider.EventInfo: - c.App.Logger.PersistInfo(event.Info) + logging.InfoPersist(event.Info) case provider.EventComplete: assistantMsg.SetToolCalls(event.Response.ToolCalls) assistantMsg.AddFinish(event.Response.FinishReason) @@ -115,12 +119,37 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, var wg sync.WaitGroup toolResults := make([]message.ToolResult, len(toolCalls)) mutex := &sync.Mutex{} + errChan := make(chan error, 1) + + // Create a child context that can be canceled + ctx, cancel := context.WithCancel(ctx) + defer cancel() for i, tc := range toolCalls { wg.Add(1) go func(index int, toolCall message.ToolCall) { defer wg.Done() + // Check if context is already canceled + select { + case <-ctx.Done(): + mutex.Lock() + toolResults[index] = message.ToolResult{ + ToolCallID: toolCall.ID, + Content: "Tool execution canceled", + IsError: true, + } + mutex.Unlock() + + // Send cancellation error to error channel if it's empty + select { + case errChan <- ctx.Err(): + default: + } + return + default: + } + response := "" isError := false found := false @@ -133,8 +162,19 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, Name: toolCall.Name, Input: toolCall.Input, }) + if toolErr != nil { - response = fmt.Sprintf("error running tool: %s", toolErr) + if errors.Is(toolErr, context.Canceled) { + response = "Tool execution canceled" + + // Send cancellation error to error channel if it's empty + select { + case errChan <- ctx.Err(): + default: + } + } else { + response = fmt.Sprintf("error running tool: %s", toolErr) + } isError = true } else { response = toolResult.Content @@ -160,7 +200,24 @@ func (c *agent) ExecuteTools(ctx context.Context, toolCalls []message.ToolCall, }(i, tc) } - wg.Wait() + // Wait for all goroutines to finish or context to be canceled + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // All tools completed successfully + case err := <-errChan: + // One of the tools encountered a cancellation + return toolResults, err + case <-ctx.Done(): + // Context was canceled externally + return toolResults, ctx.Err() + } + return toolResults, nil } @@ -188,14 +245,14 @@ func (c *agent) handleToolExecution( return &msg, err } -func (c *agent) generate(sessionID string, content string) error { +func (c *agent) generate(ctx context.Context, sessionID string, content string) error { messages, err := c.Messages.List(sessionID) if err != nil { return err } if len(messages) == 0 { - go c.handleTitleGeneration(sessionID, content) + go c.handleTitleGeneration(ctx, sessionID, content) } userMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ @@ -212,9 +269,36 @@ func (c *agent) generate(sessionID string, content string) error { messages = append(messages, userMsg) for { + select { + case <-ctx.Done(): + assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + }) + if err != nil { + return err + } + assistantMsg.AddFinish("canceled") + c.Messages.Update(assistantMsg) + return context.Canceled + default: + // Continue processing + } - eventChan, err := c.agent.StreamResponse(c.Context, messages, c.tools) + eventChan, err := c.agent.StreamResponse(ctx, messages, c.tools) if err != nil { + if errors.Is(err, context.Canceled) { + assistantMsg, err := c.Messages.Create(sessionID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{}, + }) + if err != nil { + return err + } + assistantMsg.AddFinish("canceled") + c.Messages.Update(assistantMsg) + return context.Canceled + } return err } @@ -228,19 +312,47 @@ func (c *agent) generate(sessionID string, content string) error { for event := range eventChan { err = c.processEvent(sessionID, &assistantMsg, event) if err != nil { + if errors.Is(err, context.Canceled) { + assistantMsg.AddFinish("canceled") + c.Messages.Update(assistantMsg) + return context.Canceled + } assistantMsg.AddFinish("error:" + err.Error()) c.Messages.Update(assistantMsg) return err } + + select { + case <-ctx.Done(): + assistantMsg.AddFinish("canceled") + c.Messages.Update(assistantMsg) + return context.Canceled + default: + } } - msg, err := c.handleToolExecution(c.Context, assistantMsg) + // Check for context cancellation before tool execution + select { + case <-ctx.Done(): + assistantMsg.AddFinish("canceled") + c.Messages.Update(assistantMsg) + return context.Canceled + default: + // Continue processing + } - c.Messages.Update(assistantMsg) + msg, err := c.handleToolExecution(ctx, assistantMsg) if err != nil { + if errors.Is(err, context.Canceled) { + assistantMsg.AddFinish("canceled") + c.Messages.Update(assistantMsg) + return context.Canceled + } return err } + c.Messages.Update(assistantMsg) + if len(assistantMsg.ToolCalls()) == 0 { break } @@ -249,6 +361,16 @@ func (c *agent) generate(sessionID string, content string) error { if msg != nil { messages = append(messages, *msg) } + + // Check for context cancellation after tool execution + select { + case <-ctx.Done(): + assistantMsg.AddFinish("canceled") + c.Messages.Update(assistantMsg) + return context.Canceled + default: + // Continue processing + } } return nil } diff --git a/internal/llm/agent/coder.go b/internal/llm/agent/coder.go index 8ff3c61aa8ba49b3f8f0447049ba23b33a056459..5deff05a8a85e5e8b1d38ece4715695809e19ec6 100644 --- a/internal/llm/agent/coder.go +++ b/internal/llm/agent/coder.go @@ -1,6 +1,7 @@ package agent import ( + "context" "errors" "github.com/kujtimiihoxha/termai/internal/app" @@ -28,9 +29,9 @@ func (c *coderAgent) setAgentTool(sessionID string) { } } -func (c *coderAgent) Generate(sessionID string, content string) error { +func (c *coderAgent) Generate(ctx context.Context, sessionID string, content string) error { c.setAgentTool(sessionID) - return c.generate(sessionID, content) + return c.generate(ctx, sessionID, content) } func NewCoderAgent(app *app.App) (Agent, error) { diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index dcf880fe797620c56af424d038405a7ca9be8685..b1c97b512b04f77d6da37636a474b8ef25ccf78c 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -22,8 +22,6 @@ type mcpTool struct { permissions permission.Service } -var logger = logging.Get() - type MCPClient interface { Initialize( ctx context.Context, @@ -143,13 +141,13 @@ func getTools(ctx context.Context, name string, m config.MCPServer, permissions _, err := c.Initialize(ctx, initRequest) if err != nil { - logger.Error("error initializing mcp client", "error", err) + logging.Error("error initializing mcp client", "error", err) return stdioTools } toolsRequest := mcp.ListToolsRequest{} tools, err := c.ListTools(ctx, toolsRequest) if err != nil { - logger.Error("error listing tools", "error", err) + logging.Error("error listing tools", "error", err) return stdioTools } for _, t := range tools.Tools { @@ -172,7 +170,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba m.Args..., ) if err != nil { - logger.Error("error creating mcp client", "error", err) + logging.Error("error creating mcp client", "error", err) continue } @@ -183,7 +181,7 @@ func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.Ba client.WithHeaders(m.Headers), ) if err != nil { - logger.Error("error creating mcp client", "error", err) + logging.Error("error creating mcp client", "error", err) continue } mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...) diff --git a/internal/llm/agent/task.go b/internal/llm/agent/task.go index 9737d41b84eb4cf94bf3c36b3d1ede2a9862080e..034e9346003c747884afd12098682397cf80dc46 100644 --- a/internal/llm/agent/task.go +++ b/internal/llm/agent/task.go @@ -1,6 +1,7 @@ package agent import ( + "context" "errors" "github.com/kujtimiihoxha/termai/internal/app" @@ -13,8 +14,8 @@ type taskAgent struct { *agent } -func (c *taskAgent) Generate(sessionID string, content string) error { - return c.generate(sessionID, content) +func (c *taskAgent) Generate(ctx context.Context, sessionID string, content string) error { + return c.generate(ctx, sessionID, content) } func NewTaskAgent(app *app.App) (Agent, error) { diff --git a/internal/logging/default.go b/internal/logging/default.go deleted file mode 100644 index 54cfaa49098ff012e52f36a2c574d57ba04af7e3..0000000000000000000000000000000000000000 --- a/internal/logging/default.go +++ /dev/null @@ -1,12 +0,0 @@ -package logging - -var defaultLogger Interface - -func Get() Interface { - if defaultLogger == nil { - defaultLogger = NewLogger(Options{ - Level: "info", - }) - } - return defaultLogger -} diff --git a/internal/logging/logger.go b/internal/logging/logger.go index 1f0e61d002597e999c60f8e5cb75dc10781c5043..b0639147271b50b4029f1cdbe5baf22c9d60333f 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -1,141 +1,39 @@ package logging -import ( - "context" - "io" - "log/slog" - "slices" +import "log/slog" - "github.com/kujtimiihoxha/termai/internal/pubsub" - "golang.org/x/exp/maps" -) - -const DefaultLevel = "info" - -const ( - persistKeyArg = "$persist" - PersistTimeArg = "$persist_time" -) - -var levels = map[string]slog.Level{ - "debug": slog.LevelDebug, - DefaultLevel: slog.LevelInfo, - "warn": slog.LevelWarn, - "error": slog.LevelError, -} - -func ValidLevels() []string { - keys := maps.Keys(levels) - slices.SortFunc(keys, func(a, b string) int { - if a == DefaultLevel { - return -1 - } - if b == DefaultLevel { - return 1 - } - if a < b { - return -1 - } - return 1 - }) - return keys -} - -func NewLogger(opts Options) Interface { - logger := &Logger{} - broker := pubsub.NewBroker[LogMessage]() - writer := &writer{ - messages: []LogMessage{}, - Broker: broker, - } - - handler := slog.NewTextHandler( - io.MultiWriter(writer), - &slog.HandlerOptions{ - Level: slog.Level(levels[opts.Level]), - }, - ) - logger.logger = slog.New(handler) - logger.writer = writer - - return logger +func Info(msg string, args ...any) { + slog.Info(msg, args...) } -type Options struct { - Level string +func Debug(msg string, args ...any) { + slog.Debug(msg, args...) } -type Logger struct { - logger *slog.Logger - writer *writer +func Warn(msg string, args ...any) { + slog.Warn(msg, args...) } -func (l *Logger) SetLevel(level string) { - if _, ok := levels[level]; !ok { - level = DefaultLevel - } - handler := slog.NewTextHandler( - io.MultiWriter(l.writer), - &slog.HandlerOptions{ - Level: levels[level], - }, - ) - l.logger = slog.New(handler) +func Error(msg string, args ...any) { + slog.Error(msg, args...) } -// PersistDebug implements Interface. -func (l *Logger) PersistDebug(msg string, args ...any) { +func InfoPersist(msg string, args ...any) { args = append(args, persistKeyArg, true) - l.Debug(msg, args...) + slog.Info(msg, args...) } -// PersistError implements Interface. -func (l *Logger) PersistError(msg string, args ...any) { +func DebugPersist(msg string, args ...any) { args = append(args, persistKeyArg, true) - l.Error(msg, args...) + slog.Debug(msg, args...) } -// PersistInfo implements Interface. -func (l *Logger) PersistInfo(msg string, args ...any) { +func WarnPersist(msg string, args ...any) { args = append(args, persistKeyArg, true) - l.Info(msg, args...) + slog.Warn(msg, args...) } -// PersistWarn implements Interface. -func (l *Logger) PersistWarn(msg string, args ...any) { +func ErrorPersist(msg string, args ...any) { args = append(args, persistKeyArg, true) - l.Warn(msg, args...) -} - -func (l *Logger) Debug(msg string, args ...any) { - l.logger.Debug(msg, args...) -} - -func (l *Logger) Info(msg string, args ...any) { - l.logger.Info(msg, args...) -} - -func (l *Logger) Warn(msg string, args ...any) { - l.logger.Warn(msg, args...) -} - -func (l *Logger) Error(msg string, args ...any) { - l.logger.Error(msg, args...) -} - -func (l *Logger) List() []LogMessage { - return l.writer.messages -} - -func (l *Logger) Get(id string) (LogMessage, error) { - for _, msg := range l.writer.messages { - if msg.ID == id { - return msg, nil - } - } - return LogMessage{}, io.EOF -} - -func (l *Logger) Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] { - return l.writer.Subscribe(ctx) + slog.Error(msg, args...) } diff --git a/internal/logging/logging.go b/internal/logging/logging.go deleted file mode 100644 index c23cfaff8439d77523cd8ccf59625451b18b99aa..0000000000000000000000000000000000000000 --- a/internal/logging/logging.go +++ /dev/null @@ -1,23 +0,0 @@ -package logging - -import ( - "context" - - "github.com/kujtimiihoxha/termai/internal/pubsub" -) - -type Interface interface { - Debug(msg string, args ...any) - Info(msg string, args ...any) - Warn(msg string, args ...any) - Error(msg string, args ...any) - Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] - - PersistDebug(msg string, args ...any) - PersistInfo(msg string, args ...any) - PersistWarn(msg string, args ...any) - PersistError(msg string, args ...any) - List() []LogMessage - - SetLevel(level string) -} diff --git a/internal/logging/writer.go b/internal/logging/writer.go index 06a5330e314093235e2976d709b5d61e26a223ab..9fe469c5e0953009413bc5171aba95d04936e92e 100644 --- a/internal/logging/writer.go +++ b/internal/logging/writer.go @@ -2,18 +2,47 @@ package logging import ( "bytes" + "context" "fmt" + "strings" + "sync" "time" "github.com/go-logfmt/logfmt" "github.com/kujtimiihoxha/termai/internal/pubsub" ) -type writer struct { +const ( + persistKeyArg = "$_persist" + PersistTimeArg = "$_persist_time" +) + +type LogData struct { messages []LogMessage *pubsub.Broker[LogMessage] + lock sync.Mutex +} + +func (l *LogData) Add(msg LogMessage) { + l.lock.Lock() + defer l.lock.Unlock() + l.messages = append(l.messages, msg) + l.Publish(pubsub.CreatedEvent, msg) +} + +func (l *LogData) List() []LogMessage { + l.lock.Lock() + defer l.lock.Unlock() + return l.messages +} + +var defaultLogData = &LogData{ + messages: make([]LogMessage, 0), + Broker: pubsub.NewBroker[LogMessage](), } +type writer struct{} + func (w *writer) Write(p []byte) (int, error) { d := logfmt.NewDecoder(bytes.NewReader(p)) for d.ScanRecord() { @@ -30,7 +59,7 @@ func (w *writer) Write(p []byte) (int, error) { } msg.Time = parsed case "level": - msg.Level = string(d.Value()) + msg.Level = strings.ToLower(string(d.Value())) case "msg": msg.Message = string(d.Value()) default: @@ -50,11 +79,23 @@ func (w *writer) Write(p []byte) (int, error) { } } } - w.messages = append(w.messages, msg) - w.Publish(pubsub.CreatedEvent, msg) + defaultLogData.Add(msg) } if d.Err() != nil { return 0, d.Err() } return len(p), nil } + +func NewWriter() *writer { + w := &writer{} + return w +} + +func Subscribe(ctx context.Context) <-chan pubsub.Event[LogMessage] { + return defaultLogData.Subscribe(ctx) +} + +func List() []LogMessage { + return defaultLogData.List() +} diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 1b1456614ea24d7929f8555273283300af8af7a1..824a84b5d82d5974d4e378e7766f0892d9d21e87 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -18,8 +18,6 @@ import ( "github.com/kujtimiihoxha/termai/internal/lsp/protocol" ) -var logger = logging.Get() - type Client struct { Cmd *exec.Cmd stdin io.WriteCloser @@ -377,7 +375,7 @@ func (c *Client) CloseFile(ctx context.Context, filepath string) error { } if cnf.Debug { - logger.Debug("Closing file", "file", filepath) + logging.Debug("Closing file", "file", filepath) } if err := c.Notify(ctx, "textDocument/didClose", params); err != nil { return err @@ -416,12 +414,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) { for _, filePath := range filesToClose { err := c.CloseFile(ctx, filePath) if err != nil && cnf.Debug { - logger.Warn("Error closing file", "file", filePath, "error", err) + logging.Warn("Error closing file", "file", filePath, "error", err) } } if cnf.Debug { - logger.Debug("Closed all files", "files", filesToClose) + logging.Debug("Closed all files", "files", filesToClose) } } diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index 5fb9b49a5f20db5b52c26d2aff3ef85676cb57a4..4913c743d97c5e195dc77986ebacd622baf21958 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -4,6 +4,7 @@ import ( "encoding/json" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/logging" "github.com/kujtimiihoxha/termai/internal/lsp/protocol" "github.com/kujtimiihoxha/termai/internal/lsp/util" ) @@ -17,7 +18,7 @@ func HandleWorkspaceConfiguration(params json.RawMessage) (any, error) { func HandleRegisterCapability(params json.RawMessage) (any, error) { var registerParams protocol.RegistrationParams if err := json.Unmarshal(params, ®isterParams); err != nil { - logger.Error("Error unmarshaling registration params", "error", err) + logging.Error("Error unmarshaling registration params", "error", err) return nil, err } @@ -27,13 +28,13 @@ func HandleRegisterCapability(params json.RawMessage) (any, error) { // Parse the registration options optionsJSON, err := json.Marshal(reg.RegisterOptions) if err != nil { - logger.Error("Error marshaling registration options", "error", err) + logging.Error("Error marshaling registration options", "error", err) continue } var options protocol.DidChangeWatchedFilesRegistrationOptions if err := json.Unmarshal(optionsJSON, &options); err != nil { - logger.Error("Error unmarshaling registration options", "error", err) + logging.Error("Error unmarshaling registration options", "error", err) continue } @@ -53,7 +54,7 @@ func HandleApplyEdit(params json.RawMessage) (any, error) { err := util.ApplyWorkspaceEdit(edit.Edit) if err != nil { - logger.Error("Error applying workspace edit", "error", err) + logging.Error("Error applying workspace edit", "error", err) return protocol.ApplyWorkspaceEditResult{Applied: false, FailureReason: err.Error()}, nil } @@ -88,7 +89,7 @@ func HandleServerMessage(params json.RawMessage) { } if err := json.Unmarshal(params, &msg); err == nil { if cnf.Debug { - logger.Debug("Server message", "type", msg.Type, "message", msg.Message) + logging.Debug("Server message", "type", msg.Type, "message", msg.Message) } } } @@ -96,7 +97,7 @@ func HandleServerMessage(params json.RawMessage) { func HandleDiagnostics(client *Client, params json.RawMessage) { var diagParams protocol.PublishDiagnosticsParams if err := json.Unmarshal(params, &diagParams); err != nil { - logger.Error("Error unmarshaling diagnostics params", "error", err) + logging.Error("Error unmarshaling diagnostics params", "error", err) return } diff --git a/internal/lsp/transport.go b/internal/lsp/transport.go index 93253f4d6a1162a2c2e7e4157bd3abc65c5f5bb1..4185966f32d199961f66da25222e06d735ed41a5 100644 --- a/internal/lsp/transport.go +++ b/internal/lsp/transport.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/kujtimiihoxha/termai/internal/config" + "github.com/kujtimiihoxha/termai/internal/logging" ) // Write writes an LSP message to the given writer @@ -20,7 +21,7 @@ func WriteMessage(w io.Writer, msg *Message) error { cnf := config.Get() if cnf.Debug { - logger.Debug("Sending message to server", "method", msg.Method, "id", msg.ID) + logging.Debug("Sending message to server", "method", msg.Method, "id", msg.ID) } _, err = fmt.Fprintf(w, "Content-Length: %d\r\n\r\n", len(data)) @@ -49,7 +50,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { line = strings.TrimSpace(line) if cnf.Debug { - logger.Debug("Received header", "line", line) + logging.Debug("Received header", "line", line) } if line == "" { @@ -65,7 +66,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } if cnf.Debug { - logger.Debug("Content-Length", "length", contentLength) + logging.Debug("Content-Length", "length", contentLength) } // Read content @@ -76,7 +77,7 @@ func ReadMessage(r *bufio.Reader) (*Message, error) { } if cnf.Debug { - logger.Debug("Received content", "content", string(content)) + logging.Debug("Received content", "content", string(content)) } // Parse message @@ -95,7 +96,7 @@ func (c *Client) handleMessages() { msg, err := ReadMessage(c.stdout) if err != nil { if cnf.Debug { - logger.Error("Error reading message", "error", err) + logging.Error("Error reading message", "error", err) } return } @@ -103,7 +104,7 @@ func (c *Client) handleMessages() { // Handle server->client request (has both Method and ID) if msg.Method != "" && msg.ID != 0 { if cnf.Debug { - logger.Debug("Received request from server", "method", msg.Method, "id", msg.ID) + logging.Debug("Received request from server", "method", msg.Method, "id", msg.ID) } response := &Message{ @@ -143,7 +144,7 @@ func (c *Client) handleMessages() { // Send response back to server if err := WriteMessage(c.stdin, response); err != nil { - logger.Error("Error sending response to server", "error", err) + logging.Error("Error sending response to server", "error", err) } continue @@ -157,11 +158,11 @@ func (c *Client) handleMessages() { if ok { if cnf.Debug { - logger.Debug("Handling notification", "method", msg.Method) + logging.Debug("Handling notification", "method", msg.Method) } go handler(msg.Params) } else if cnf.Debug { - logger.Debug("No handler for notification", "method", msg.Method) + logging.Debug("No handler for notification", "method", msg.Method) } continue } @@ -174,12 +175,12 @@ func (c *Client) handleMessages() { if ok { if cnf.Debug { - logger.Debug("Received response for request", "id", msg.ID) + logging.Debug("Received response for request", "id", msg.ID) } ch <- msg close(ch) } else if cnf.Debug { - logger.Debug("No handler for response", "id", msg.ID) + logging.Debug("No handler for response", "id", msg.ID) } } } @@ -191,7 +192,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any id := c.nextID.Add(1) if cnf.Debug { - logger.Debug("Making call", "method", method, "id", id) + logging.Debug("Making call", "method", method, "id", id) } msg, err := NewRequest(id, method, params) @@ -217,14 +218,14 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any } if cnf.Debug { - logger.Debug("Request sent", "method", method, "id", id) + logging.Debug("Request sent", "method", method, "id", id) } // Wait for response resp := <-ch if cnf.Debug { - logger.Debug("Received response", "id", id) + logging.Debug("Received response", "id", id) } if resp.Error != nil { @@ -250,7 +251,7 @@ func (c *Client) Call(ctx context.Context, method string, params any, result any func (c *Client) Notify(ctx context.Context, method string, params any) error { cnf := config.Get() if cnf.Debug { - logger.Debug("Sending notification", "method", method) + logging.Debug("Sending notification", "method", method) } msg, err := NewNotification(method, params) diff --git a/internal/lsp/watcher/watcher.go b/internal/lsp/watcher/watcher.go index cfd389ed147172c3c8c9c6a868ae59be9493bf37..b5ef157109e517c9ba8afc5d26c29ac76a527868 100644 --- a/internal/lsp/watcher/watcher.go +++ b/internal/lsp/watcher/watcher.go @@ -16,8 +16,6 @@ import ( "github.com/kujtimiihoxha/termai/internal/lsp/protocol" ) -var logger = logging.Get() - // WorkspaceWatcher manages LSP file watching type WorkspaceWatcher struct { client *lsp.Client @@ -53,7 +51,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc // Print detailed registration information for debugging if cnf.Debug { - logger.Debug("Adding file watcher registrations", + logging.Debug("Adding file watcher registrations", "id", id, "watchers", len(watchers), "total", len(w.registrations), @@ -61,26 +59,26 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc ) for i, watcher := range watchers { - logger.Debug("Registration", "index", i+1) + logging.Debug("Registration", "index", i+1) // Log the GlobPattern switch v := watcher.GlobPattern.Value.(type) { case string: - logger.Debug("GlobPattern", "pattern", v) + logging.Debug("GlobPattern", "pattern", v) case protocol.RelativePattern: - logger.Debug("GlobPattern", "pattern", v.Pattern) + logging.Debug("GlobPattern", "pattern", v.Pattern) // Log BaseURI details switch u := v.BaseURI.Value.(type) { case string: - logger.Debug("BaseURI", "baseURI", u) + logging.Debug("BaseURI", "baseURI", u) case protocol.DocumentUri: - logger.Debug("BaseURI", "baseURI", u) + logging.Debug("BaseURI", "baseURI", u) default: - logger.Debug("BaseURI", "baseURI", u) + logging.Debug("BaseURI", "baseURI", u) } default: - logger.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v)) + logging.Debug("GlobPattern", "unknown type", fmt.Sprintf("%T", v)) } // Log WatchKind @@ -89,7 +87,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc watchKind = *watcher.Kind } - logger.Debug("WatchKind", "kind", watchKind) + logging.Debug("WatchKind", "kind", watchKind) // Test match against some example paths testPaths := []string{ @@ -99,7 +97,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc for _, testPath := range testPaths { isMatch := w.matchesPattern(testPath, watcher.GlobPattern) - logger.Debug("Test path", "path", testPath, "matches", isMatch) + logging.Debug("Test path", "path", testPath, "matches", isMatch) } } } @@ -119,7 +117,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc if d.IsDir() { if path != w.workspacePath && shouldExcludeDir(path) { if cnf.Debug { - logger.Debug("Skipping excluded directory", "path", path) + logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir } @@ -139,7 +137,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc elapsedTime := time.Since(startTime) if cnf.Debug { - logger.Debug("Workspace scan complete", + logging.Debug("Workspace scan complete", "filesOpened", filesOpened, "elapsedTime", elapsedTime.Seconds(), "workspacePath", w.workspacePath, @@ -147,7 +145,7 @@ func (w *WorkspaceWatcher) AddRegistrations(ctx context.Context, id string, watc } if err != nil && cnf.Debug { - logger.Debug("Error scanning workspace for files to open", "error", err) + logging.Debug("Error scanning workspace for files to open", "error", err) } }() } @@ -164,7 +162,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str watcher, err := fsnotify.NewWatcher() if err != nil { - logger.Error("Error creating watcher", "error", err) + logging.Error("Error creating watcher", "error", err) } defer watcher.Close() @@ -178,7 +176,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str if d.IsDir() && path != workspacePath { if shouldExcludeDir(path) { if cnf.Debug { - logger.Debug("Skipping excluded directory", "path", path) + logging.Debug("Skipping excluded directory", "path", path) } return filepath.SkipDir } @@ -188,14 +186,14 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str if d.IsDir() { err = watcher.Add(path) if err != nil { - logger.Error("Error watching path", "path", path, "error", err) + logging.Error("Error watching path", "path", path, "error", err) } } return nil }) if err != nil { - logger.Error("Error walking workspace", "error", err) + logging.Error("Error walking workspace", "error", err) } // Event loop @@ -217,7 +215,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str // Skip excluded directories if !shouldExcludeDir(event.Name) { if err := watcher.Add(event.Name); err != nil { - logger.Error("Error adding directory to watcher", "path", event.Name, "error", err) + logging.Error("Error adding directory to watcher", "path", event.Name, "error", err) } } } else { @@ -232,7 +230,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str // Debug logging if cnf.Debug { matched, kind := w.isPathWatched(event.Name) - logger.Debug("File event", + logging.Debug("File event", "path", event.Name, "operation", event.Op.String(), "watched", matched, @@ -277,7 +275,7 @@ func (w *WorkspaceWatcher) WatchWorkspace(ctx context.Context, workspacePath str if !ok { return } - logger.Error("Error watching file", "error", err) + logging.Error("Error watching file", "error", err) } } } @@ -402,7 +400,7 @@ func matchesSimpleGlob(pattern, path string) bool { // Fall back to simple matching for simpler patterns matched, err := filepath.Match(pattern, path) if err != nil { - logger.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err) + logging.Error("Error matching pattern", "pattern", pattern, "path", path, "error", err) return false } @@ -413,7 +411,7 @@ func matchesSimpleGlob(pattern, path string) bool { func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPattern) bool { patternInfo, err := pattern.AsPattern() if err != nil { - logger.Error("Error parsing pattern", "pattern", pattern, "error", err) + logging.Error("Error parsing pattern", "pattern", pattern, "error", err) return false } @@ -438,7 +436,7 @@ func (w *WorkspaceWatcher) matchesPattern(path string, pattern protocol.GlobPatt // Make path relative to basePath for matching relPath, err := filepath.Rel(basePath, path) if err != nil { - logger.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err) + logging.Error("Error getting relative path", "path", path, "basePath", basePath, "error", err) return false } relPath = filepath.ToSlash(relPath) @@ -479,14 +477,14 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan if changeType == protocol.FileChangeType(protocol.Changed) && w.client.IsFileOpen(filePath) { err := w.client.NotifyChange(ctx, filePath) if err != nil { - logger.Error("Error notifying change", "error", err) + logging.Error("Error notifying change", "error", err) } return } // Notify LSP server about the file event using didChangeWatchedFiles if err := w.notifyFileEvent(ctx, uri, changeType); err != nil { - logger.Error("Error notifying LSP server about file event", "error", err) + logging.Error("Error notifying LSP server about file event", "error", err) } } @@ -494,7 +492,7 @@ func (w *WorkspaceWatcher) handleFileEvent(ctx context.Context, uri string, chan func (w *WorkspaceWatcher) notifyFileEvent(ctx context.Context, uri string, changeType protocol.FileChangeType) error { cnf := config.Get() if cnf.Debug { - logger.Debug("Notifying file event", + logging.Debug("Notifying file event", "uri", uri, "changeType", changeType, ) @@ -618,7 +616,7 @@ func shouldExcludeFile(filePath string) bool { // Skip large files if info.Size() > maxFileSize { if cnf.Debug { - logger.Debug("Skipping large file", + logging.Debug("Skipping large file", "path", filePath, "size", info.Size(), "maxSize", maxFileSize, @@ -651,7 +649,7 @@ func (w *WorkspaceWatcher) openMatchingFile(ctx context.Context, path string) { if watched, _ := w.isPathWatched(path); watched { // Don't need to check if it's already open - the client.OpenFile handles that if err := w.client.OpenFile(ctx, path); err != nil && cnf.Debug { - logger.Error("Error opening file", "path", path, "error", err) + logging.Error("Error opening file", "path", path, "error", err) } } } diff --git a/internal/tui/components/core/status.go b/internal/tui/components/core/status.go index d67bdb7a6ce62bcf36f37ceeebd6e82e3c17da77..93ba345075e513e691cb0b27a6f5de870c5c2354 100644 --- a/internal/tui/components/core/status.go +++ b/internal/tui/components/core/status.go @@ -13,7 +13,7 @@ import ( ) type statusCmp struct { - info *util.InfoMsg + info util.InfoMsg width int messageTTL time.Duration } @@ -35,14 +35,14 @@ func (m statusCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.width = msg.Width return m, nil case util.InfoMsg: - m.info = &msg + m.info = msg ttl := msg.TTL if ttl == 0 { ttl = m.messageTTL } return m, m.clearMessageCmd(ttl) case util.ClearStatusMsg: - m.info = nil + m.info = util.InfoMsg{} } return m, nil } @@ -54,7 +54,7 @@ var ( func (m statusCmp) View() string { status := styles.Padded.Background(styles.Grey).Foreground(styles.Text).Render("? help") - if m.info != nil { + if m.info.Msg != "" { infoStyle := styles.Padded. Foreground(styles.Base). Width(m.availableFooterMsgWidth()) diff --git a/internal/tui/components/logs/details.go b/internal/tui/components/logs/details.go index e02dfd9cc513942ce3435f2d81dafe4360dc0cf3..dbace5508d253220df39862ae0e1992ed6616d83 100644 --- a/internal/tui/components/logs/details.go +++ b/internal/tui/components/logs/details.go @@ -30,7 +30,7 @@ type detailCmp struct { } func (i *detailCmp) Init() tea.Cmd { - messages := logging.Get().List() + messages := logging.List() if len(messages) == 0 { return nil } diff --git a/internal/tui/components/logs/table.go b/internal/tui/components/logs/table.go index 0e90d2d2a5aad99f0e3ce2d9846b653051dafd7e..9500059b1e6052dd95c9b11bf58be4a21bd7ac25 100644 --- a/internal/tui/components/logs/table.go +++ b/internal/tui/components/logs/table.go @@ -22,8 +22,6 @@ type TableComponent interface { layout.Bordered } -var logger = logging.Get() - type tableCmp struct { table table.Model } @@ -57,7 +55,7 @@ func (i *tableCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if selectedRow != nil { if prevSelectedRow == nil || selectedRow[0] == prevSelectedRow[0] { var log logging.LogMessage - for _, row := range logging.Get().List() { + for _, row := range logging.List() { if row.ID == selectedRow[0] { log = row break @@ -112,7 +110,7 @@ func (i *tableCmp) BindingKeys() []key.Binding { func (i *tableCmp) setRows() { rows := []table.Row{} - logs := logger.List() + logs := logging.List() slices.SortFunc(logs, func(a, b logging.LogMessage) int { if a.Time.Before(b.Time) { return 1 diff --git a/internal/tui/components/repl/editor.go b/internal/tui/components/repl/editor.go index 0228996b816016a3e970760917a072d3ae597445..37ac275e3f11db19fd464ba94394fbc65bcdc054 100644 --- a/internal/tui/components/repl/editor.go +++ b/internal/tui/components/repl/editor.go @@ -12,6 +12,7 @@ import ( "github.com/kujtimiihoxha/termai/internal/tui/styles" "github.com/kujtimiihoxha/termai/internal/tui/util" "github.com/kujtimiihoxha/vimtea" + "golang.org/x/net/context" ) type EditorCmp interface { @@ -23,18 +24,20 @@ type EditorCmp interface { } type editorCmp struct { - app *app.App - editor vimtea.Editor - editorMode vimtea.EditorMode - sessionID string - focused bool - width int - height int + app *app.App + editor vimtea.Editor + editorMode vimtea.EditorMode + sessionID string + focused bool + width int + height int + cancelMessage context.CancelFunc } type editorKeyMap struct { SendMessage key.Binding SendMessageI key.Binding + CancelMessage key.Binding InsertMode key.Binding NormaMode key.Binding VisualMode key.Binding @@ -50,6 +53,10 @@ var editorKeyMapValue = editorKeyMap{ key.WithKeys("ctrl+s"), key.WithHelp("ctrl+s", "send message insert mode"), ), + CancelMessage: key.NewBinding( + key.WithKeys("ctrl+x"), + key.WithHelp("ctrl+x", "cancel current message"), + ), InsertMode: key.NewBinding( key.WithKeys("i"), key.WithHelp("i", "insert mode"), @@ -93,6 +100,8 @@ func (m *editorCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if m.editorMode == vimtea.ModeInsert { return m, m.Send() } + case key.Matches(msg, editorKeyMapValue.CancelMessage): + return m, m.Cancel() } } u, cmd := m.editor.Update(msg) @@ -136,6 +145,16 @@ func (m *editorCmp) SetSize(width int, height int) { m.editor.SetSize(width, height) } +func (m *editorCmp) Cancel() tea.Cmd { + if m.cancelMessage == nil { + return util.ReportWarn("No message to cancel") + } + + m.cancelMessage() + m.cancelMessage = nil + return util.ReportWarn("Message cancelled") +} + func (m *editorCmp) Send() tea.Cmd { return func() tea.Msg { messages, err := m.app.Messages.List(m.sessionID) @@ -151,7 +170,13 @@ func (m *editorCmp) Send() tea.Cmd { } content := strings.Join(m.editor.GetBuffer().Lines(), "\n") - go a.Generate(m.sessionID, content) + ctx, cancel := context.WithCancel(m.app.Context) + m.cancelMessage = cancel + go func() { + defer cancel() + a.Generate(ctx, m.sessionID, content) + m.cancelMessage = nil + }() return m.editor.Reset() } diff --git a/internal/tui/components/repl/messages.go b/internal/tui/components/repl/messages.go index 481783a2e81481f890ca349de0dfae7a01e3a5b8..57a55c579b432edec387e5e8c8f05ec858dc2c5c 100644 --- a/internal/tui/components/repl/messages.go +++ b/internal/tui/components/repl/messages.go @@ -309,7 +309,7 @@ func (m *messagesCmp) renderMessageWithToolCall(content string, tools []message. } for _, msg := range futureMessages { - if msg.Content().String() != "" { + if msg.Content().String() != "" || msg.FinishReason() == "canceled" { break } @@ -345,13 +345,18 @@ func (m *messagesCmp) renderView() { prevMessageWasUser := false for inx, msg := range m.messages { content := msg.Content().String() - if content != "" || prevMessageWasUser { + if content != "" || prevMessageWasUser || msg.FinishReason() == "canceled" { if msg.ReasoningContent().String() != "" && content == "" { content = msg.ReasoningContent().String() } else if content == "" { content = "..." } - content, _ = r.Render(content) + if msg.FinishReason() == "canceled" { + content, _ = r.Render(content) + content += lipgloss.NewStyle().Padding(1, 0, 0, 1).Foreground(styles.Error).Render(styles.ErrorIcon + " Canceled") + } else { + content, _ = r.Render(content) + } isSelected := inx == m.selectedMsgIdx diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 233b189dd1de917aa784dda80f2d05cf746a2017..9e863d2ac844c3f3a034124b0eadbd6700893b01 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -101,7 +101,8 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Status case util.InfoMsg: a.status, cmd = a.status.Update(msg) - return a, cmd + cmds = append(cmds, cmd) + return a, tea.Batch(cmds...) case pubsub.Event[logging.LogMessage]: if msg.Payload.Persist { switch msg.Payload.Level {