diff --git a/go.mod b/go.mod index 770fdd7d04c36909edcfabff1777a13b1823c519..89763cb3f3664493e8d729fbd1423e75e46a03fe 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/invopop/jsonschema v0.13.0 github.com/joho/godotenv v1.5.1 github.com/lucasb-eyer/go-colorful v1.3.0 + github.com/maniartech/signals v1.3.1 github.com/modelcontextprotocol/go-sdk v1.2.0 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.30.4 diff --git a/go.sum b/go.sum index 5e1cb24cf95c53384a2dea077c3633c267c2d11d..7e7f21eb1f7a8eb9dfbcf6c7fdd011e858c4512a 100644 --- a/go.sum +++ b/go.sum @@ -243,6 +243,8 @@ github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQ github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/maniartech/signals v1.3.1 h1:pT3dK6x5Un+B6L3ZLAKygEe+L49TClPreyT08vOoHXY= +github.com/maniartech/signals v1.3.1/go.mod h1:AbE8Yy9ZjKCWNU/VhQ+0Ea9KOaTWHp6aOfdLBe5m1iM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index e1e7d609efc86d0dcb510fa5963552f7d487a134..32d337156650a2b3a1b122be181429b89df63b29 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -92,9 +92,9 @@ type ClientInfo struct { ConnectedAt time.Time } -// SubscribeEvents returns a channel for MCP events -func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] { - return broker.Subscribe(ctx) +// AddEventListener registers a callback for MCP events. +func AddEventListener(key string, fn func(pubsub.Event[Event])) { + broker.AddListener(key, fn) } // GetStates returns the current state of all MCP clients @@ -139,13 +139,13 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config // Initialize states for all configured MCPs for name, m := range cfg.MCP { if m.Disabled { - updateState(name, StateDisabled, nil, nil, Counts{}) + updateState(ctx, name, StateDisabled, nil, nil, Counts{}) slog.Debug("skipping disabled mcp", "name", name) continue } // Set initial starting state - updateState(name, StateStarting, nil, nil, Counts{}) + updateState(ctx, name, StateStarting, nil, nil, Counts{}) wg.Add(1) go func(name string, m config.MCPConfig) { @@ -161,7 +161,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config default: err = fmt.Errorf("panic: %v", v) } - updateState(name, StateError, err, nil, Counts{}) + updateState(ctx, name, StateError, err, nil, Counts{}) slog.Error("panic in mcp client initialization", "error", err, "name", name) } }() @@ -175,7 +175,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config tools, err := getTools(ctx, session) if err != nil { slog.Error("error listing tools", "error", err) - updateState(name, StateError, err, nil, Counts{}) + updateState(ctx, name, StateError, err, nil, Counts{}) session.Close() return } @@ -183,7 +183,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config prompts, err := getPrompts(ctx, session) if err != nil { slog.Error("error listing prompts", "error", err) - updateState(name, StateError, err, nil, Counts{}) + updateState(ctx, name, StateError, err, nil, Counts{}) session.Close() return } @@ -192,7 +192,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config updatePrompts(name, prompts) sessions.Set(name, session) - updateState(name, StateConnected, nil, session, Counts{ + updateState(ctx, name, StateConnected, nil, session, Counts{ Tools: toolCount, Prompts: len(prompts), }) @@ -230,20 +230,20 @@ func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, err if err == nil { return sess, nil } - updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts) + updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts) sess, err = createSession(ctx, name, m, cfg.Resolver()) if err != nil { return nil, err } - updateState(name, StateConnected, nil, sess, state.Counts) + updateState(ctx, name, StateConnected, nil, sess, state.Counts) sessions.Set(name, sess) return sess, nil } // updateState updates the state of an MCP client and publishes an event -func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) { +func updateState(ctx context.Context, name string, state State, err error, client *mcp.ClientSession, counts Counts) { info := ClientInfo{ Name: name, State: state, @@ -260,7 +260,7 @@ func updateState(name string, state State, err error, client *mcp.ClientSession, states.Set(name, info) // Publish state change event - broker.Publish(pubsub.UpdatedEvent, Event{ + broker.Publish(ctx, pubsub.UpdatedEvent, Event{ Type: EventStateChanged, Name: name, State: state, @@ -276,7 +276,7 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve transport, err := createTransport(mcpCtx, m, resolver) if err != nil { - updateState(name, StateError, err, nil, Counts{}) + updateState(ctx, name, StateError, err, nil, Counts{}) slog.Error("error creating mcp client", "error", err, "name", name) cancel() cancelTimer.Stop() @@ -290,14 +290,14 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve Title: "Crush", }, &mcp.ClientOptions{ - ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) { - broker.Publish(pubsub.UpdatedEvent, Event{ + ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) { + broker.Publish(ctx, pubsub.UpdatedEvent, Event{ Type: EventToolsListChanged, Name: name, }) }, - PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) { - broker.Publish(pubsub.UpdatedEvent, Event{ + PromptListChangedHandler: func(ctx context.Context, _ *mcp.PromptListChangedRequest) { + broker.Publish(ctx, pubsub.UpdatedEvent, Event{ Type: EventPromptsListChanged, Name: name, }) @@ -311,7 +311,7 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve session, err := client.Connect(mcpCtx, transport, nil) if err != nil { err = maybeStdioErr(err, transport) - updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{}) + updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{}) slog.Error("MCP client failed to initialize", "error", err, "name", name) cancel() cancelTimer.Stop() diff --git a/internal/agent/tools/mcp/prompts.go b/internal/agent/tools/mcp/prompts.go index 0bd6e665dd80dad90c844d7d31c61c506ea83803..0306a18732e335c6f4bd30380ed8e18b718adf0e 100644 --- a/internal/agent/tools/mcp/prompts.go +++ b/internal/agent/tools/mcp/prompts.go @@ -55,7 +55,7 @@ func RefreshPrompts(ctx context.Context, name string) { prompts, err := getPrompts(ctx, session) if err != nil { - updateState(name, StateError, err, nil, Counts{}) + updateState(ctx, name, StateError, err, nil, Counts{}) return } @@ -63,7 +63,7 @@ func RefreshPrompts(ctx context.Context, name string) { prev, _ := states.Get(name) prev.Counts.Prompts = len(prompts) - updateState(name, StateConnected, nil, session, prev.Counts) + updateState(ctx, name, StateConnected, nil, session, prev.Counts) } func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) { diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index 779baa55d93bc54523bac81c5094bacee7fc68fb..55830822b242fb88d3919618cbe139faed2f7c56 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -117,7 +117,7 @@ func RefreshTools(ctx context.Context, name string) { tools, err := getTools(ctx, session) if err != nil { - updateState(name, StateError, err, nil, Counts{}) + updateState(ctx, name, StateError, err, nil, Counts{}) return } @@ -125,7 +125,7 @@ func RefreshTools(ctx context.Context, name string) { prev, _ := states.Get(name) prev.Counts.Tools = toolCount - updateState(name, StateConnected, nil, session, prev.Counts) + updateState(ctx, name, StateConnected, nil, session, prev.Counts) } func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) { diff --git a/internal/agent/tools/multiedit_test.go b/internal/agent/tools/multiedit_test.go index b6d575435e63dcd62a4dc9a7efb76cf13c14ad05..221d84a0afaa624b91a093de3935260c5006f400 100644 --- a/internal/agent/tools/multiedit_test.go +++ b/internal/agent/tools/multiedit_test.go @@ -23,11 +23,12 @@ func (m *mockPermissionService) Request(ctx context.Context, req permission.Crea return true, nil } -func (m *mockPermissionService) Grant(req permission.PermissionRequest) {} +func (m *mockPermissionService) Grant(ctx context.Context, req permission.PermissionRequest) {} -func (m *mockPermissionService) Deny(req permission.PermissionRequest) {} +func (m *mockPermissionService) Deny(ctx context.Context, req permission.PermissionRequest) {} -func (m *mockPermissionService) GrantPersistent(req permission.PermissionRequest) {} +func (m *mockPermissionService) GrantPersistent(ctx context.Context, req permission.PermissionRequest) { +} func (m *mockPermissionService) AutoApproveSession(sessionID string) {} @@ -37,8 +38,7 @@ func (m *mockPermissionService) SkipRequests() bool { return false } -func (m *mockPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] { - return make(<-chan pubsub.Event[permission.PermissionNotification]) +func (m *mockPermissionService) AddNotificationListener(key string, fn func(pubsub.Event[permission.PermissionNotification])) { } type mockHistoryService struct { diff --git a/internal/app/app.go b/internal/app/app.go index 0f98a8383124274d8aaae12b40146411ed969c8d..c86bf3dd0266fe65c37b33768480bc0d5340cd54 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -52,10 +52,8 @@ type App struct { config *config.Config - serviceEventsWG *sync.WaitGroup - eventsCtx context.Context - events chan tea.Msg - tuiWG *sync.WaitGroup + // program holds reference to the TUI program for sending events. + program *tea.Program // global context and cleanup functions globalCtx context.Context @@ -84,14 +82,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { globalCtx: ctx, config: cfg, - - events: make(chan tea.Msg, 100), - serviceEventsWG: &sync.WaitGroup{}, - tuiWG: &sync.WaitGroup{}, } - app.setupEvents() - // Initialize LSP clients in the background. app.initLSPClients(ctx) @@ -223,8 +215,35 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt } }(ctx, sess.ID, prompt) - messageEvents := app.Messages.Subscribe(ctx) - messageReadBytes := make(map[string]int) + var ( + messageReadBytes = make(map[string]int) + messageReadBytesMu sync.Mutex + ) + + // Register callback to process messages directly. + app.Messages.AddListener("non-interactive", func(event pubsub.Event[message.Message]) { + msg := event.Payload + if msg.SessionID != sess.ID || msg.Role != message.Assistant || len(msg.Parts) == 0 { + return + } + stopSpinner() + + content := msg.Content().String() + + messageReadBytesMu.Lock() + readBytes := messageReadBytes[msg.ID] + if len(content) >= readBytes { + part := content[readBytes:] + // Trim leading whitespace. Sometimes the LLM includes leading + // formatting and indentation, which we don't want here. + if readBytes == 0 { + part = strings.TrimLeft(part, " \t") + } + fmt.Fprint(output, part) + messageReadBytes[msg.ID] = len(content) + } + messageReadBytesMu.Unlock() + }) defer func() { if stderrTTY { @@ -236,6 +255,7 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt _, _ = fmt.Fprintln(output) }() + // Wait for agent completion or cancellation. for { if stderrTTY { // HACK: Reinitialize the terminal progress bar on every iteration @@ -255,29 +275,6 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt } return nil - case event := <-messageEvents: - msg := event.Payload - if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 { - stopSpinner() - - content := msg.Content().String() - readBytes := messageReadBytes[msg.ID] - - if len(content) < readBytes { - slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(content), "read_bytes", readBytes) - return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes) - } - - part := content[readBytes:] - // Trim leading whitespace. Sometimes the LLM includes leading - // formatting and intentation, which we don't want here. - if readBytes == 0 { - part = strings.TrimLeft(part, " \t") - } - fmt.Fprint(output, part) - messageReadBytes[msg.ID] = len(content) - } - case <-ctx.Done(): stopSpinner() return ctx.Err() @@ -292,57 +289,6 @@ func (app *App) UpdateAgentModel(ctx context.Context) error { return app.AgentCoordinator.UpdateModels(ctx) } -func (app *App) setupEvents() { - ctx, cancel := context.WithCancel(app.globalCtx) - app.eventsCtx = ctx - setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events) - cleanupFunc := func() error { - cancel() - app.serviceEventsWG.Wait() - return nil - } - app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc) -} - -func setupSubscriber[T any]( - ctx context.Context, - wg *sync.WaitGroup, - name string, - subscriber func(context.Context) <-chan pubsub.Event[T], - outputCh chan<- tea.Msg, -) { - wg.Go(func() { - subCh := subscriber(ctx) - for { - select { - case event, ok := <-subCh: - if !ok { - slog.Debug("subscription channel closed", "name", name) - return - } - var msg tea.Msg = event - select { - case outputCh <- msg: - case <-time.After(2 * time.Second): - slog.Warn("message dropped due to slow consumer", "name", name) - case <-ctx.Done(): - slog.Debug("subscription cancelled", "name", name) - return - } - case <-ctx.Done(): - slog.Debug("subscription cancelled", "name", name) - return - } - } - }) -} - func (app *App) InitCoderAgent(ctx context.Context) error { coderAgentCfg := app.config.Agents[config.AgentCoder] if coderAgentCfg.ID == "" { @@ -365,36 +311,37 @@ func (app *App) InitCoderAgent(ctx context.Context) error { return nil } -// Subscribe sends events to the TUI as tea.Msgs. +// Subscribe registers event listeners that send events to the TUI program. func (app *App) Subscribe(program *tea.Program) { defer log.RecoverPanic("app.Subscribe", func() { slog.Info("TUI subscription panic: attempting graceful shutdown") program.Quit() }) - app.tuiWG.Add(1) - tuiCtx, tuiCancel := context.WithCancel(app.globalCtx) - app.cleanupFuncs = append(app.cleanupFuncs, func() error { - slog.Debug("Cancelling TUI message handler") - tuiCancel() - app.tuiWG.Wait() - return nil - }) - defer app.tuiWG.Done() + app.program = program - for { - select { - case <-tuiCtx.Done(): - slog.Debug("TUI message handler shutting down") - return - case msg, ok := <-app.events: - if !ok { - slog.Debug("TUI message channel closed") - return - } - program.Send(msg) - } - } + // Register listeners that send directly to the program. + app.Sessions.AddListener("tui-sessions", func(event pubsub.Event[session.Session]) { + program.Send(event) + }) + app.Messages.AddListener("tui-messages", func(event pubsub.Event[message.Message]) { + program.Send(event) + }) + app.Permissions.AddListener("tui-permissions", func(event pubsub.Event[permission.PermissionRequest]) { + program.Send(event) + }) + app.Permissions.AddNotificationListener("tui-permissions-notifications", func(event pubsub.Event[permission.PermissionNotification]) { + program.Send(event) + }) + app.History.AddListener("tui-history", func(event pubsub.Event[history.File]) { + program.Send(event) + }) + mcp.AddEventListener("tui-mcp", func(event pubsub.Event[mcp.Event]) { + program.Send(event) + }) + AddLSPEventListener("tui-lsp", func(event pubsub.Event[LSPEvent]) { + program.Send(event) + }) } // Shutdown performs a graceful shutdown of the application. @@ -452,9 +399,11 @@ func (app *App) checkForUpdates(ctx context.Context) { if err != nil || !info.Available() { return } - app.events <- pubsub.UpdateAvailableMsg{ - CurrentVersion: info.Current, - LatestVersion: info.Latest, - IsDevelopment: info.IsDevelopment(), + if app.program != nil { + app.program.Send(pubsub.UpdateAvailableMsg{ + CurrentVersion: info.Current, + LatestVersion: info.Latest, + IsDevelopment: info.IsDevelopment(), + }) } } diff --git a/internal/app/lsp_events.go b/internal/app/lsp_events.go index 5292983d46cf867b9380ad45f7831007da54f0d7..971c7a27766ce314ff8628293bd2d8a622f01fd4 100644 --- a/internal/app/lsp_events.go +++ b/internal/app/lsp_events.go @@ -41,9 +41,9 @@ var ( lspBroker = pubsub.NewBroker[LSPEvent]() ) -// SubscribeLSPEvents returns a channel for LSP events -func SubscribeLSPEvents(ctx context.Context) <-chan pubsub.Event[LSPEvent] { - return lspBroker.Subscribe(ctx) +// AddLSPEventListener registers a callback for LSP events. +func AddLSPEventListener(key string, fn func(pubsub.Event[LSPEvent])) { + lspBroker.AddListener(key, fn) } // GetLSPStates returns the current state of all LSP clients @@ -71,7 +71,7 @@ func updateLSPState(name string, state lsp.ServerState, err error, client *lsp.C lspStates.Set(name, info) // Publish state change event - lspBroker.Publish(pubsub.UpdatedEvent, LSPEvent{ + lspBroker.Publish(context.Background(), pubsub.UpdatedEvent, LSPEvent{ Type: LSPEventStateChanged, Name: name, State: state, @@ -87,7 +87,7 @@ func updateLSPDiagnostics(name string, diagnosticCount int) { lspStates.Set(name, info) // Publish diagnostics change event - lspBroker.Publish(pubsub.UpdatedEvent, LSPEvent{ + lspBroker.Publish(context.Background(), pubsub.UpdatedEvent, LSPEvent{ Type: LSPEventDiagnosticsChanged, Name: name, State: info.State, diff --git a/internal/history/file.go b/internal/history/file.go index 438b4116bddfc8e6ce54e8226c4210c7b3a7f7be..e573e04f013545676ad02a4111689179f4468e58 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -120,7 +120,7 @@ func (s *service) createWithVersion(ctx context.Context, sessionID, path, conten } file = s.fromDBItem(dbFile) - s.Publish(pubsub.CreatedEvent, file) + s.Publish(ctx, pubsub.CreatedEvent, file) return file, nil } @@ -179,7 +179,7 @@ func (s *service) Delete(ctx context.Context, id string) error { if err != nil { return err } - s.Publish(pubsub.DeletedEvent, file) + s.Publish(ctx, pubsub.DeletedEvent, file) return nil } diff --git a/internal/message/message.go b/internal/message/message.go index a09d0acbf590e840541a7d5e057fb89513cc0618..e29d1ba177436589313cabe03825bef71bca7b6f 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -53,7 +53,7 @@ func (s *service) Delete(ctx context.Context, id string) error { } // Clone the message before publishing to avoid race conditions with // concurrent modifications to the Parts slice. - s.Publish(pubsub.DeletedEvent, message.Clone()) + s.Publish(ctx, pubsub.DeletedEvent, message.Clone()) return nil } @@ -89,7 +89,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes } // Clone the message before publishing to avoid race conditions with // concurrent modifications to the Parts slice. - s.Publish(pubsub.CreatedEvent, message.Clone()) + s.Publish(ctx, pubsub.CreatedEvent, message.Clone()) return message, nil } @@ -130,7 +130,7 @@ func (s *service) Update(ctx context.Context, message Message) error { message.UpdatedAt = time.Now().Unix() // Clone the message before publishing to avoid race conditions with // concurrent modifications to the Parts slice. - s.Publish(pubsub.UpdatedEvent, message.Clone()) + s.Publish(ctx, pubsub.UpdatedEvent, message.Clone()) return nil } diff --git a/internal/permission/permission.go b/internal/permission/permission.go index e1bf1bae14b8473989b1c0890c58188591123d71..4dcf8a3b900dbdccce1301017a1e43aea5cc8fc9 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -44,14 +44,14 @@ type PermissionRequest struct { type Service interface { pubsub.Subscriber[PermissionRequest] - GrantPersistent(permission PermissionRequest) - Grant(permission PermissionRequest) - Deny(permission PermissionRequest) + GrantPersistent(ctx context.Context, permission PermissionRequest) + Grant(ctx context.Context, permission PermissionRequest) + Deny(ctx context.Context, permission PermissionRequest) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) AutoApproveSession(sessionID string) SetSkipRequests(skip bool) SkipRequests() bool - SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification] + AddNotificationListener(key string, fn func(pubsub.Event[PermissionNotification])) } type permissionService struct { @@ -73,8 +73,8 @@ type permissionService struct { activeRequestMu sync.Mutex } -func (s *permissionService) GrantPersistent(permission PermissionRequest) { - s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ +func (s *permissionService) GrantPersistent(ctx context.Context, permission PermissionRequest) { + s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{ ToolCallID: permission.ToolCallID, Granted: true, }) @@ -94,8 +94,8 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { s.activeRequestMu.Unlock() } -func (s *permissionService) Grant(permission PermissionRequest) { - s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ +func (s *permissionService) Grant(ctx context.Context, permission PermissionRequest) { + s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{ ToolCallID: permission.ToolCallID, Granted: true, }) @@ -111,8 +111,8 @@ func (s *permissionService) Grant(permission PermissionRequest) { s.activeRequestMu.Unlock() } -func (s *permissionService) Deny(permission PermissionRequest) { - s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ +func (s *permissionService) Deny(ctx context.Context, permission PermissionRequest) { + s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{ ToolCallID: permission.ToolCallID, Granted: false, Denied: true, @@ -135,7 +135,7 @@ func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRe } // tell the UI that a permission was requested - s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ + s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{ ToolCallID: opts.ToolCallID, }) s.requestMu.Lock() @@ -206,7 +206,7 @@ func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRe defer s.pendingRequests.Del(permission.ID) // Publish the request - s.Publish(pubsub.CreatedEvent, permission) + s.Publish(ctx, pubsub.CreatedEvent, permission) select { case <-ctx.Done(): @@ -222,8 +222,8 @@ func (s *permissionService) AutoApproveSession(sessionID string) { s.autoApproveSessionsMu.Unlock() } -func (s *permissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification] { - return s.notificationBroker.Subscribe(ctx) +func (s *permissionService) AddNotificationListener(key string, fn func(pubsub.Event[PermissionNotification])) { + s.notificationBroker.AddListener(key, fn) } func (s *permissionService) SetSkipRequests(skip bool) { diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go index 79930f3ae1e2ef15257f09724fef64d3ea28dada..410ae97e0d7c6d57978aa2ae045f59b1b5e64270 100644 --- a/internal/permission/permission_test.go +++ b/internal/permission/permission_test.go @@ -4,6 +4,7 @@ import ( "sync" "testing" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -114,7 +115,10 @@ func TestPermissionService_SequentialProperties(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - events := service.Subscribe(t.Context()) + events := make(chan pubsub.Event[PermissionRequest], 10) + service.AddListener("test", func(event pubsub.Event[PermissionRequest]) { + events <- event + }) go func() { defer wg.Done() @@ -125,7 +129,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { event := <-events permissionReq = event.Payload - service.GrantPersistent(permissionReq) + service.GrantPersistent(t.Context(), permissionReq) wg.Wait() assert.True(t, result1, "First request should be granted") @@ -155,7 +159,10 @@ func TestPermissionService_SequentialProperties(t *testing.T) { Path: "/tmp/test.txt", } - events := service.Subscribe(t.Context()) + events := make(chan pubsub.Event[PermissionRequest], 10) + service.AddListener("test", func(event pubsub.Event[PermissionRequest]) { + events <- event + }) var result1 bool var wg sync.WaitGroup @@ -167,7 +174,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) { event := <-events permissionReq = event.Payload - service.Grant(permissionReq) + service.Grant(t.Context(), permissionReq) wg.Wait() assert.True(t, result1, "First request should be granted") @@ -179,14 +186,17 @@ func TestPermissionService_SequentialProperties(t *testing.T) { event = <-events permissionReq = event.Payload - service.Deny(permissionReq) + service.Deny(t.Context(), permissionReq) wg.Wait() assert.False(t, result2, "Second request should be denied") }) t.Run("Concurrent requests with different outcomes", func(t *testing.T) { service := NewPermissionService("/tmp", false, []string{}) - events := service.Subscribe(t.Context()) + events := make(chan pubsub.Event[PermissionRequest], 10) + service.AddListener("test", func(event pubsub.Event[PermissionRequest]) { + events <- event + }) var wg sync.WaitGroup results := make([]bool, 3) @@ -228,11 +238,11 @@ func TestPermissionService_SequentialProperties(t *testing.T) { event := <-events switch event.Payload.ToolName { case "tool1": - service.Grant(event.Payload) + service.Grant(t.Context(), event.Payload) case "tool2": - service.GrantPersistent(event.Payload) + service.GrantPersistent(t.Context(), event.Payload) case "tool3": - service.Deny(event.Payload) + service.Deny(t.Context(), event.Payload) } } wg.Wait() diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index ed14cbfed6c8fd44355501e16457e0dd92a494bc..6986b1b7bb234e004a97cc8db7f381d422ebada6 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -2,112 +2,51 @@ package pubsub import ( "context" - "sync" -) -const bufferSize = 64 + "github.com/maniartech/signals" +) +// Broker is a generic pub/sub broker backed by maniartech/signals. type Broker[T any] struct { - subs map[chan Event[T]]struct{} - mu sync.RWMutex - done chan struct{} - subCount int - maxEvents int + signal *signals.AsyncSignal[Event[T]] } +// NewBroker creates a new broker. func NewBroker[T any]() *Broker[T] { - return NewBrokerWithOptions[T](bufferSize, 1000) + return &Broker[T]{ + signal: signals.New[Event[T]](), + } } -func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] { - b := &Broker[T]{ - subs: make(map[chan Event[T]]struct{}), - done: make(chan struct{}), - subCount: 0, - maxEvents: maxEvents, - } - return b +// NewBrokerWithOptions creates a new broker (options ignored for compatibility). +func NewBrokerWithOptions[T any](_, _ int) *Broker[T] { + return NewBroker[T]() } +// Shutdown removes all listeners. func (b *Broker[T]) Shutdown() { - select { - case <-b.done: // Already closed - return - default: - close(b.done) - } - - b.mu.Lock() - defer b.mu.Unlock() - - for ch := range b.subs { - delete(b.subs, ch) - close(ch) - } - - b.subCount = 0 + b.signal.Reset() } -func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] { - b.mu.Lock() - defer b.mu.Unlock() - - select { - case <-b.done: - ch := make(chan Event[T]) - close(ch) - return ch - default: - } - - sub := make(chan Event[T], bufferSize) - b.subs[sub] = struct{}{} - b.subCount++ - - go func() { - <-ctx.Done() - - b.mu.Lock() - defer b.mu.Unlock() - - select { - case <-b.done: - return - default: - } - - delete(b.subs, sub) - close(sub) - b.subCount-- - }() - - return sub +// AddListener registers a callback for events. +func (b *Broker[T]) AddListener(key string, fn func(Event[T])) { + b.signal.AddListener(func(_ context.Context, event Event[T]) { + fn(event) + }, key) } -func (b *Broker[T]) GetSubscriberCount() int { - b.mu.RLock() - defer b.mu.RUnlock() - return b.subCount +// RemoveListener removes a listener by key. +func (b *Broker[T]) RemoveListener(key string) { + b.signal.RemoveListener(key) } -func (b *Broker[T]) Publish(t EventType, payload T) { - b.mu.RLock() - defer b.mu.RUnlock() - - select { - case <-b.done: - return - default: - } - +// Publish emits an event to all listeners without blocking. +func (b *Broker[T]) Publish(ctx context.Context, t EventType, payload T) { event := Event[T]{Type: t, Payload: payload} + go b.signal.Emit(ctx, event) +} - for sub := range b.subs { - select { - case sub <- event: - default: - // Channel is full, subscriber is slow - skip this event - // This prevents blocking the publisher - } - } +// Len returns the number of listeners. +func (b *Broker[T]) Len() int { + return b.signal.Len() } diff --git a/internal/pubsub/events.go b/internal/pubsub/events.go index 016cc10c9f8a51039ce9eeda6210f5f59bdc1e6c..ec0826fdcdb46c9ff249659f8d9206f1fa82d5f6 100644 --- a/internal/pubsub/events.go +++ b/internal/pubsub/events.go @@ -9,7 +9,7 @@ const ( ) type Subscriber[T any] interface { - Subscribe(context.Context) <-chan Event[T] + AddListener(key string, fn func(Event[T])) } type ( @@ -23,7 +23,7 @@ type ( } Publisher[T any] interface { - Publish(EventType, T) + Publish(context.Context, EventType, T) } ) diff --git a/internal/session/session.go b/internal/session/session.go index 3792cc1d576cdd7ebd0dbf0b64670c746718da9c..6e218e3400af2fab6dd614c5d824ebfe821044e3 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -73,7 +73,7 @@ func (s *service) Create(ctx context.Context, title string) (Session, error) { return Session{}, err } session := s.fromDBItem(dbSession) - s.Publish(pubsub.CreatedEvent, session) + s.Publish(ctx, pubsub.CreatedEvent, session) event.SessionCreated() return session, nil } @@ -88,7 +88,7 @@ func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessi return Session{}, err } session := s.fromDBItem(dbSession) - s.Publish(pubsub.CreatedEvent, session) + s.Publish(ctx, pubsub.CreatedEvent, session) return session, nil } @@ -102,7 +102,7 @@ func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string return Session{}, err } session := s.fromDBItem(dbSession) - s.Publish(pubsub.CreatedEvent, session) + s.Publish(ctx, pubsub.CreatedEvent, session) return session, nil } @@ -115,7 +115,7 @@ func (s *service) Delete(ctx context.Context, id string) error { if err != nil { return err } - s.Publish(pubsub.DeletedEvent, session) + s.Publish(ctx, pubsub.DeletedEvent, session) event.SessionDeleted() return nil } @@ -153,7 +153,7 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) { return Session{}, err } session = s.fromDBItem(dbSession) - s.Publish(pubsub.UpdatedEvent, session) + s.Publish(ctx, pubsub.UpdatedEvent, session) return session, nil } diff --git a/internal/tui/tui.go b/internal/tui/tui.go index e91fae5592b8d51963e524d0662d868cbfed6869..9881e58d0c71c23208980d721d796905cc493076 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -3,6 +3,7 @@ package tui import ( "context" "fmt" + "log/slog" "math/rand" "regexp" "slices" @@ -112,6 +113,7 @@ func (a appModel) Init() tea.Cmd { // Update handles incoming messages and updates the application state. func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + slog.Info("app update", "msg", fmt.Sprintf("%T", msg)) var cmds []tea.Cmd var cmd tea.Cmd a.isConfigured = config.HasInitialDataConfig() @@ -324,11 +326,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case permissions.PermissionResponseMsg: switch msg.Action { case permissions.PermissionAllow: - a.app.Permissions.Grant(msg.Permission) + // TODO: get the context from somewhere + a.app.Permissions.Grant(context.Background(), msg.Permission) case permissions.PermissionAllowForSession: - a.app.Permissions.GrantPersistent(msg.Permission) + // TODO: get the context from somewhere + a.app.Permissions.GrantPersistent(context.Background(), msg.Permission) case permissions.PermissionDeny: - a.app.Permissions.Deny(msg.Permission) + // TODO: get the context from somewhere + a.app.Permissions.Deny(context.Background(), msg.Permission) } return a, nil case splash.OnboardingCompleteMsg: