Detailed changes
@@ -38,6 +38,7 @@ require (
github.com/invopop/jsonschema v0.13.0
github.com/joho/godotenv v1.5.1
github.com/lucasb-eyer/go-colorful v1.3.0
+ github.com/maniartech/signals v1.3.1
github.com/modelcontextprotocol/go-sdk v1.2.0
github.com/muesli/termenv v0.16.0
github.com/ncruces/go-sqlite3 v0.30.4
@@ -243,6 +243,8 @@ github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQ
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
+github.com/maniartech/signals v1.3.1 h1:pT3dK6x5Un+B6L3ZLAKygEe+L49TClPreyT08vOoHXY=
+github.com/maniartech/signals v1.3.1/go.mod h1:AbE8Yy9ZjKCWNU/VhQ+0Ea9KOaTWHp6aOfdLBe5m1iM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
@@ -92,9 +92,9 @@ type ClientInfo struct {
ConnectedAt time.Time
}
-// SubscribeEvents returns a channel for MCP events
-func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
- return broker.Subscribe(ctx)
+// AddEventListener registers a callback for MCP events.
+func AddEventListener(key string, fn func(pubsub.Event[Event])) {
+ broker.AddListener(key, fn)
}
// GetStates returns the current state of all MCP clients
@@ -139,13 +139,13 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
// Initialize states for all configured MCPs
for name, m := range cfg.MCP {
if m.Disabled {
- updateState(name, StateDisabled, nil, nil, Counts{})
+ updateState(ctx, name, StateDisabled, nil, nil, Counts{})
slog.Debug("skipping disabled mcp", "name", name)
continue
}
// Set initial starting state
- updateState(name, StateStarting, nil, nil, Counts{})
+ updateState(ctx, name, StateStarting, nil, nil, Counts{})
wg.Add(1)
go func(name string, m config.MCPConfig) {
@@ -161,7 +161,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
default:
err = fmt.Errorf("panic: %v", v)
}
- updateState(name, StateError, err, nil, Counts{})
+ updateState(ctx, name, StateError, err, nil, Counts{})
slog.Error("panic in mcp client initialization", "error", err, "name", name)
}
}()
@@ -175,7 +175,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
tools, err := getTools(ctx, session)
if err != nil {
slog.Error("error listing tools", "error", err)
- updateState(name, StateError, err, nil, Counts{})
+ updateState(ctx, name, StateError, err, nil, Counts{})
session.Close()
return
}
@@ -183,7 +183,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
prompts, err := getPrompts(ctx, session)
if err != nil {
slog.Error("error listing prompts", "error", err)
- updateState(name, StateError, err, nil, Counts{})
+ updateState(ctx, name, StateError, err, nil, Counts{})
session.Close()
return
}
@@ -192,7 +192,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
updatePrompts(name, prompts)
sessions.Set(name, session)
- updateState(name, StateConnected, nil, session, Counts{
+ updateState(ctx, name, StateConnected, nil, session, Counts{
Tools: toolCount,
Prompts: len(prompts),
})
@@ -230,20 +230,20 @@ func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, err
if err == nil {
return sess, nil
}
- updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
+ updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
sess, err = createSession(ctx, name, m, cfg.Resolver())
if err != nil {
return nil, err
}
- updateState(name, StateConnected, nil, sess, state.Counts)
+ updateState(ctx, name, StateConnected, nil, sess, state.Counts)
sessions.Set(name, sess)
return sess, nil
}
// updateState updates the state of an MCP client and publishes an event
-func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
+func updateState(ctx context.Context, name string, state State, err error, client *mcp.ClientSession, counts Counts) {
info := ClientInfo{
Name: name,
State: state,
@@ -260,7 +260,7 @@ func updateState(name string, state State, err error, client *mcp.ClientSession,
states.Set(name, info)
// Publish state change event
- broker.Publish(pubsub.UpdatedEvent, Event{
+ broker.Publish(ctx, pubsub.UpdatedEvent, Event{
Type: EventStateChanged,
Name: name,
State: state,
@@ -276,7 +276,7 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve
transport, err := createTransport(mcpCtx, m, resolver)
if err != nil {
- updateState(name, StateError, err, nil, Counts{})
+ updateState(ctx, name, StateError, err, nil, Counts{})
slog.Error("error creating mcp client", "error", err, "name", name)
cancel()
cancelTimer.Stop()
@@ -290,14 +290,14 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve
Title: "Crush",
},
&mcp.ClientOptions{
- ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
- broker.Publish(pubsub.UpdatedEvent, Event{
+ ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) {
+ broker.Publish(ctx, pubsub.UpdatedEvent, Event{
Type: EventToolsListChanged,
Name: name,
})
},
- PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
- broker.Publish(pubsub.UpdatedEvent, Event{
+ PromptListChangedHandler: func(ctx context.Context, _ *mcp.PromptListChangedRequest) {
+ broker.Publish(ctx, pubsub.UpdatedEvent, Event{
Type: EventPromptsListChanged,
Name: name,
})
@@ -311,7 +311,7 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve
session, err := client.Connect(mcpCtx, transport, nil)
if err != nil {
err = maybeStdioErr(err, transport)
- updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
+ updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
slog.Error("MCP client failed to initialize", "error", err, "name", name)
cancel()
cancelTimer.Stop()
@@ -55,7 +55,7 @@ func RefreshPrompts(ctx context.Context, name string) {
prompts, err := getPrompts(ctx, session)
if err != nil {
- updateState(name, StateError, err, nil, Counts{})
+ updateState(ctx, name, StateError, err, nil, Counts{})
return
}
@@ -63,7 +63,7 @@ func RefreshPrompts(ctx context.Context, name string) {
prev, _ := states.Get(name)
prev.Counts.Prompts = len(prompts)
- updateState(name, StateConnected, nil, session, prev.Counts)
+ updateState(ctx, name, StateConnected, nil, session, prev.Counts)
}
func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
@@ -117,7 +117,7 @@ func RefreshTools(ctx context.Context, name string) {
tools, err := getTools(ctx, session)
if err != nil {
- updateState(name, StateError, err, nil, Counts{})
+ updateState(ctx, name, StateError, err, nil, Counts{})
return
}
@@ -125,7 +125,7 @@ func RefreshTools(ctx context.Context, name string) {
prev, _ := states.Get(name)
prev.Counts.Tools = toolCount
- updateState(name, StateConnected, nil, session, prev.Counts)
+ updateState(ctx, name, StateConnected, nil, session, prev.Counts)
}
func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
@@ -23,11 +23,12 @@ func (m *mockPermissionService) Request(ctx context.Context, req permission.Crea
return true, nil
}
-func (m *mockPermissionService) Grant(req permission.PermissionRequest) {}
+func (m *mockPermissionService) Grant(ctx context.Context, req permission.PermissionRequest) {}
-func (m *mockPermissionService) Deny(req permission.PermissionRequest) {}
+func (m *mockPermissionService) Deny(ctx context.Context, req permission.PermissionRequest) {}
-func (m *mockPermissionService) GrantPersistent(req permission.PermissionRequest) {}
+func (m *mockPermissionService) GrantPersistent(ctx context.Context, req permission.PermissionRequest) {
+}
func (m *mockPermissionService) AutoApproveSession(sessionID string) {}
@@ -37,8 +38,7 @@ func (m *mockPermissionService) SkipRequests() bool {
return false
}
-func (m *mockPermissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[permission.PermissionNotification] {
- return make(<-chan pubsub.Event[permission.PermissionNotification])
+func (m *mockPermissionService) AddNotificationListener(key string, fn func(pubsub.Event[permission.PermissionNotification])) {
}
type mockHistoryService struct {
@@ -52,10 +52,8 @@ type App struct {
config *config.Config
- serviceEventsWG *sync.WaitGroup
- eventsCtx context.Context
- events chan tea.Msg
- tuiWG *sync.WaitGroup
+ // program holds reference to the TUI program for sending events.
+ program *tea.Program
// global context and cleanup functions
globalCtx context.Context
@@ -84,14 +82,8 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
globalCtx: ctx,
config: cfg,
-
- events: make(chan tea.Msg, 100),
- serviceEventsWG: &sync.WaitGroup{},
- tuiWG: &sync.WaitGroup{},
}
- app.setupEvents()
-
// Initialize LSP clients in the background.
app.initLSPClients(ctx)
@@ -223,8 +215,35 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt
}
}(ctx, sess.ID, prompt)
- messageEvents := app.Messages.Subscribe(ctx)
- messageReadBytes := make(map[string]int)
+ var (
+ messageReadBytes = make(map[string]int)
+ messageReadBytesMu sync.Mutex
+ )
+
+ // Register callback to process messages directly.
+ app.Messages.AddListener("non-interactive", func(event pubsub.Event[message.Message]) {
+ msg := event.Payload
+ if msg.SessionID != sess.ID || msg.Role != message.Assistant || len(msg.Parts) == 0 {
+ return
+ }
+ stopSpinner()
+
+ content := msg.Content().String()
+
+ messageReadBytesMu.Lock()
+ readBytes := messageReadBytes[msg.ID]
+ if len(content) >= readBytes {
+ part := content[readBytes:]
+ // Trim leading whitespace. Sometimes the LLM includes leading
+ // formatting and indentation, which we don't want here.
+ if readBytes == 0 {
+ part = strings.TrimLeft(part, " \t")
+ }
+ fmt.Fprint(output, part)
+ messageReadBytes[msg.ID] = len(content)
+ }
+ messageReadBytesMu.Unlock()
+ })
defer func() {
if stderrTTY {
@@ -236,6 +255,7 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt
_, _ = fmt.Fprintln(output)
}()
+ // Wait for agent completion or cancellation.
for {
if stderrTTY {
// HACK: Reinitialize the terminal progress bar on every iteration
@@ -255,29 +275,6 @@ func (app *App) RunNonInteractive(ctx context.Context, output io.Writer, prompt
}
return nil
- case event := <-messageEvents:
- msg := event.Payload
- if msg.SessionID == sess.ID && msg.Role == message.Assistant && len(msg.Parts) > 0 {
- stopSpinner()
-
- content := msg.Content().String()
- readBytes := messageReadBytes[msg.ID]
-
- if len(content) < readBytes {
- slog.Error("Non-interactive: message content is shorter than read bytes", "message_length", len(content), "read_bytes", readBytes)
- return fmt.Errorf("message content is shorter than read bytes: %d < %d", len(content), readBytes)
- }
-
- part := content[readBytes:]
- // Trim leading whitespace. Sometimes the LLM includes leading
- // formatting and intentation, which we don't want here.
- if readBytes == 0 {
- part = strings.TrimLeft(part, " \t")
- }
- fmt.Fprint(output, part)
- messageReadBytes[msg.ID] = len(content)
- }
-
case <-ctx.Done():
stopSpinner()
return ctx.Err()
@@ -292,57 +289,6 @@ func (app *App) UpdateAgentModel(ctx context.Context) error {
return app.AgentCoordinator.UpdateModels(ctx)
}
-func (app *App) setupEvents() {
- ctx, cancel := context.WithCancel(app.globalCtx)
- app.eventsCtx = ctx
- setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events)
- setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events)
- setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events)
- setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events)
- setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events)
- setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events)
- setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events)
- cleanupFunc := func() error {
- cancel()
- app.serviceEventsWG.Wait()
- return nil
- }
- app.cleanupFuncs = append(app.cleanupFuncs, cleanupFunc)
-}
-
-func setupSubscriber[T any](
- ctx context.Context,
- wg *sync.WaitGroup,
- name string,
- subscriber func(context.Context) <-chan pubsub.Event[T],
- outputCh chan<- tea.Msg,
-) {
- wg.Go(func() {
- subCh := subscriber(ctx)
- for {
- select {
- case event, ok := <-subCh:
- if !ok {
- slog.Debug("subscription channel closed", "name", name)
- return
- }
- var msg tea.Msg = event
- select {
- case outputCh <- msg:
- case <-time.After(2 * time.Second):
- slog.Warn("message dropped due to slow consumer", "name", name)
- case <-ctx.Done():
- slog.Debug("subscription cancelled", "name", name)
- return
- }
- case <-ctx.Done():
- slog.Debug("subscription cancelled", "name", name)
- return
- }
- }
- })
-}
-
func (app *App) InitCoderAgent(ctx context.Context) error {
coderAgentCfg := app.config.Agents[config.AgentCoder]
if coderAgentCfg.ID == "" {
@@ -365,36 +311,37 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
return nil
}
-// Subscribe sends events to the TUI as tea.Msgs.
+// Subscribe registers event listeners that send events to the TUI program.
func (app *App) Subscribe(program *tea.Program) {
defer log.RecoverPanic("app.Subscribe", func() {
slog.Info("TUI subscription panic: attempting graceful shutdown")
program.Quit()
})
- app.tuiWG.Add(1)
- tuiCtx, tuiCancel := context.WithCancel(app.globalCtx)
- app.cleanupFuncs = append(app.cleanupFuncs, func() error {
- slog.Debug("Cancelling TUI message handler")
- tuiCancel()
- app.tuiWG.Wait()
- return nil
- })
- defer app.tuiWG.Done()
+ app.program = program
- for {
- select {
- case <-tuiCtx.Done():
- slog.Debug("TUI message handler shutting down")
- return
- case msg, ok := <-app.events:
- if !ok {
- slog.Debug("TUI message channel closed")
- return
- }
- program.Send(msg)
- }
- }
+ // Register listeners that send directly to the program.
+ app.Sessions.AddListener("tui-sessions", func(event pubsub.Event[session.Session]) {
+ program.Send(event)
+ })
+ app.Messages.AddListener("tui-messages", func(event pubsub.Event[message.Message]) {
+ program.Send(event)
+ })
+ app.Permissions.AddListener("tui-permissions", func(event pubsub.Event[permission.PermissionRequest]) {
+ program.Send(event)
+ })
+ app.Permissions.AddNotificationListener("tui-permissions-notifications", func(event pubsub.Event[permission.PermissionNotification]) {
+ program.Send(event)
+ })
+ app.History.AddListener("tui-history", func(event pubsub.Event[history.File]) {
+ program.Send(event)
+ })
+ mcp.AddEventListener("tui-mcp", func(event pubsub.Event[mcp.Event]) {
+ program.Send(event)
+ })
+ AddLSPEventListener("tui-lsp", func(event pubsub.Event[LSPEvent]) {
+ program.Send(event)
+ })
}
// Shutdown performs a graceful shutdown of the application.
@@ -452,9 +399,11 @@ func (app *App) checkForUpdates(ctx context.Context) {
if err != nil || !info.Available() {
return
}
- app.events <- pubsub.UpdateAvailableMsg{
- CurrentVersion: info.Current,
- LatestVersion: info.Latest,
- IsDevelopment: info.IsDevelopment(),
+ if app.program != nil {
+ app.program.Send(pubsub.UpdateAvailableMsg{
+ CurrentVersion: info.Current,
+ LatestVersion: info.Latest,
+ IsDevelopment: info.IsDevelopment(),
+ })
}
}
@@ -41,9 +41,9 @@ var (
lspBroker = pubsub.NewBroker[LSPEvent]()
)
-// SubscribeLSPEvents returns a channel for LSP events
-func SubscribeLSPEvents(ctx context.Context) <-chan pubsub.Event[LSPEvent] {
- return lspBroker.Subscribe(ctx)
+// AddLSPEventListener registers a callback for LSP events.
+func AddLSPEventListener(key string, fn func(pubsub.Event[LSPEvent])) {
+ lspBroker.AddListener(key, fn)
}
// GetLSPStates returns the current state of all LSP clients
@@ -71,7 +71,7 @@ func updateLSPState(name string, state lsp.ServerState, err error, client *lsp.C
lspStates.Set(name, info)
// Publish state change event
- lspBroker.Publish(pubsub.UpdatedEvent, LSPEvent{
+ lspBroker.Publish(context.Background(), pubsub.UpdatedEvent, LSPEvent{
Type: LSPEventStateChanged,
Name: name,
State: state,
@@ -87,7 +87,7 @@ func updateLSPDiagnostics(name string, diagnosticCount int) {
lspStates.Set(name, info)
// Publish diagnostics change event
- lspBroker.Publish(pubsub.UpdatedEvent, LSPEvent{
+ lspBroker.Publish(context.Background(), pubsub.UpdatedEvent, LSPEvent{
Type: LSPEventDiagnosticsChanged,
Name: name,
State: info.State,
@@ -120,7 +120,7 @@ func (s *service) createWithVersion(ctx context.Context, sessionID, path, conten
}
file = s.fromDBItem(dbFile)
- s.Publish(pubsub.CreatedEvent, file)
+ s.Publish(ctx, pubsub.CreatedEvent, file)
return file, nil
}
@@ -179,7 +179,7 @@ func (s *service) Delete(ctx context.Context, id string) error {
if err != nil {
return err
}
- s.Publish(pubsub.DeletedEvent, file)
+ s.Publish(ctx, pubsub.DeletedEvent, file)
return nil
}
@@ -53,7 +53,7 @@ func (s *service) Delete(ctx context.Context, id string) error {
}
// Clone the message before publishing to avoid race conditions with
// concurrent modifications to the Parts slice.
- s.Publish(pubsub.DeletedEvent, message.Clone())
+ s.Publish(ctx, pubsub.DeletedEvent, message.Clone())
return nil
}
@@ -89,7 +89,7 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
}
// Clone the message before publishing to avoid race conditions with
// concurrent modifications to the Parts slice.
- s.Publish(pubsub.CreatedEvent, message.Clone())
+ s.Publish(ctx, pubsub.CreatedEvent, message.Clone())
return message, nil
}
@@ -130,7 +130,7 @@ func (s *service) Update(ctx context.Context, message Message) error {
message.UpdatedAt = time.Now().Unix()
// Clone the message before publishing to avoid race conditions with
// concurrent modifications to the Parts slice.
- s.Publish(pubsub.UpdatedEvent, message.Clone())
+ s.Publish(ctx, pubsub.UpdatedEvent, message.Clone())
return nil
}
@@ -44,14 +44,14 @@ type PermissionRequest struct {
type Service interface {
pubsub.Subscriber[PermissionRequest]
- GrantPersistent(permission PermissionRequest)
- Grant(permission PermissionRequest)
- Deny(permission PermissionRequest)
+ GrantPersistent(ctx context.Context, permission PermissionRequest)
+ Grant(ctx context.Context, permission PermissionRequest)
+ Deny(ctx context.Context, permission PermissionRequest)
Request(ctx context.Context, opts CreatePermissionRequest) (bool, error)
AutoApproveSession(sessionID string)
SetSkipRequests(skip bool)
SkipRequests() bool
- SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification]
+ AddNotificationListener(key string, fn func(pubsub.Event[PermissionNotification]))
}
type permissionService struct {
@@ -73,8 +73,8 @@ type permissionService struct {
activeRequestMu sync.Mutex
}
-func (s *permissionService) GrantPersistent(permission PermissionRequest) {
- s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
+func (s *permissionService) GrantPersistent(ctx context.Context, permission PermissionRequest) {
+ s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{
ToolCallID: permission.ToolCallID,
Granted: true,
})
@@ -94,8 +94,8 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
s.activeRequestMu.Unlock()
}
-func (s *permissionService) Grant(permission PermissionRequest) {
- s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
+func (s *permissionService) Grant(ctx context.Context, permission PermissionRequest) {
+ s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{
ToolCallID: permission.ToolCallID,
Granted: true,
})
@@ -111,8 +111,8 @@ func (s *permissionService) Grant(permission PermissionRequest) {
s.activeRequestMu.Unlock()
}
-func (s *permissionService) Deny(permission PermissionRequest) {
- s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
+func (s *permissionService) Deny(ctx context.Context, permission PermissionRequest) {
+ s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{
ToolCallID: permission.ToolCallID,
Granted: false,
Denied: true,
@@ -135,7 +135,7 @@ func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRe
}
// tell the UI that a permission was requested
- s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{
+ s.notificationBroker.Publish(ctx, pubsub.CreatedEvent, PermissionNotification{
ToolCallID: opts.ToolCallID,
})
s.requestMu.Lock()
@@ -206,7 +206,7 @@ func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRe
defer s.pendingRequests.Del(permission.ID)
// Publish the request
- s.Publish(pubsub.CreatedEvent, permission)
+ s.Publish(ctx, pubsub.CreatedEvent, permission)
select {
case <-ctx.Done():
@@ -222,8 +222,8 @@ func (s *permissionService) AutoApproveSession(sessionID string) {
s.autoApproveSessionsMu.Unlock()
}
-func (s *permissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification] {
- return s.notificationBroker.Subscribe(ctx)
+func (s *permissionService) AddNotificationListener(key string, fn func(pubsub.Event[PermissionNotification])) {
+ s.notificationBroker.AddListener(key, fn)
}
func (s *permissionService) SetSkipRequests(skip bool) {
@@ -4,6 +4,7 @@ import (
"sync"
"testing"
+ "github.com/charmbracelet/crush/internal/pubsub"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -114,7 +115,10 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
- events := service.Subscribe(t.Context())
+ events := make(chan pubsub.Event[PermissionRequest], 10)
+ service.AddListener("test", func(event pubsub.Event[PermissionRequest]) {
+ events <- event
+ })
go func() {
defer wg.Done()
@@ -125,7 +129,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
event := <-events
permissionReq = event.Payload
- service.GrantPersistent(permissionReq)
+ service.GrantPersistent(t.Context(), permissionReq)
wg.Wait()
assert.True(t, result1, "First request should be granted")
@@ -155,7 +159,10 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
Path: "/tmp/test.txt",
}
- events := service.Subscribe(t.Context())
+ events := make(chan pubsub.Event[PermissionRequest], 10)
+ service.AddListener("test", func(event pubsub.Event[PermissionRequest]) {
+ events <- event
+ })
var result1 bool
var wg sync.WaitGroup
@@ -167,7 +174,7 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
event := <-events
permissionReq = event.Payload
- service.Grant(permissionReq)
+ service.Grant(t.Context(), permissionReq)
wg.Wait()
assert.True(t, result1, "First request should be granted")
@@ -179,14 +186,17 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
event = <-events
permissionReq = event.Payload
- service.Deny(permissionReq)
+ service.Deny(t.Context(), permissionReq)
wg.Wait()
assert.False(t, result2, "Second request should be denied")
})
t.Run("Concurrent requests with different outcomes", func(t *testing.T) {
service := NewPermissionService("/tmp", false, []string{})
- events := service.Subscribe(t.Context())
+ events := make(chan pubsub.Event[PermissionRequest], 10)
+ service.AddListener("test", func(event pubsub.Event[PermissionRequest]) {
+ events <- event
+ })
var wg sync.WaitGroup
results := make([]bool, 3)
@@ -228,11 +238,11 @@ func TestPermissionService_SequentialProperties(t *testing.T) {
event := <-events
switch event.Payload.ToolName {
case "tool1":
- service.Grant(event.Payload)
+ service.Grant(t.Context(), event.Payload)
case "tool2":
- service.GrantPersistent(event.Payload)
+ service.GrantPersistent(t.Context(), event.Payload)
case "tool3":
- service.Deny(event.Payload)
+ service.Deny(t.Context(), event.Payload)
}
}
wg.Wait()
@@ -2,112 +2,51 @@ package pubsub
import (
"context"
- "sync"
-)
-const bufferSize = 64
+ "github.com/maniartech/signals"
+)
+// Broker is a generic pub/sub broker backed by maniartech/signals.
type Broker[T any] struct {
- subs map[chan Event[T]]struct{}
- mu sync.RWMutex
- done chan struct{}
- subCount int
- maxEvents int
+ signal *signals.AsyncSignal[Event[T]]
}
+// NewBroker creates a new broker.
func NewBroker[T any]() *Broker[T] {
- return NewBrokerWithOptions[T](bufferSize, 1000)
+ return &Broker[T]{
+ signal: signals.New[Event[T]](),
+ }
}
-func NewBrokerWithOptions[T any](channelBufferSize, maxEvents int) *Broker[T] {
- b := &Broker[T]{
- subs: make(map[chan Event[T]]struct{}),
- done: make(chan struct{}),
- subCount: 0,
- maxEvents: maxEvents,
- }
- return b
+// NewBrokerWithOptions creates a new broker (options ignored for compatibility).
+func NewBrokerWithOptions[T any](_, _ int) *Broker[T] {
+ return NewBroker[T]()
}
+// Shutdown removes all listeners.
func (b *Broker[T]) Shutdown() {
- select {
- case <-b.done: // Already closed
- return
- default:
- close(b.done)
- }
-
- b.mu.Lock()
- defer b.mu.Unlock()
-
- for ch := range b.subs {
- delete(b.subs, ch)
- close(ch)
- }
-
- b.subCount = 0
+ b.signal.Reset()
}
-func (b *Broker[T]) Subscribe(ctx context.Context) <-chan Event[T] {
- b.mu.Lock()
- defer b.mu.Unlock()
-
- select {
- case <-b.done:
- ch := make(chan Event[T])
- close(ch)
- return ch
- default:
- }
-
- sub := make(chan Event[T], bufferSize)
- b.subs[sub] = struct{}{}
- b.subCount++
-
- go func() {
- <-ctx.Done()
-
- b.mu.Lock()
- defer b.mu.Unlock()
-
- select {
- case <-b.done:
- return
- default:
- }
-
- delete(b.subs, sub)
- close(sub)
- b.subCount--
- }()
-
- return sub
+// AddListener registers a callback for events.
+func (b *Broker[T]) AddListener(key string, fn func(Event[T])) {
+ b.signal.AddListener(func(_ context.Context, event Event[T]) {
+ fn(event)
+ }, key)
}
-func (b *Broker[T]) GetSubscriberCount() int {
- b.mu.RLock()
- defer b.mu.RUnlock()
- return b.subCount
+// RemoveListener removes a listener by key.
+func (b *Broker[T]) RemoveListener(key string) {
+ b.signal.RemoveListener(key)
}
-func (b *Broker[T]) Publish(t EventType, payload T) {
- b.mu.RLock()
- defer b.mu.RUnlock()
-
- select {
- case <-b.done:
- return
- default:
- }
-
+// Publish emits an event to all listeners without blocking.
+func (b *Broker[T]) Publish(ctx context.Context, t EventType, payload T) {
event := Event[T]{Type: t, Payload: payload}
+ go b.signal.Emit(ctx, event)
+}
- for sub := range b.subs {
- select {
- case sub <- event:
- default:
- // Channel is full, subscriber is slow - skip this event
- // This prevents blocking the publisher
- }
- }
+// Len returns the number of listeners.
+func (b *Broker[T]) Len() int {
+ return b.signal.Len()
}
@@ -9,7 +9,7 @@ const (
)
type Subscriber[T any] interface {
- Subscribe(context.Context) <-chan Event[T]
+ AddListener(key string, fn func(Event[T]))
}
type (
@@ -23,7 +23,7 @@ type (
}
Publisher[T any] interface {
- Publish(EventType, T)
+ Publish(context.Context, EventType, T)
}
)
@@ -73,7 +73,7 @@ func (s *service) Create(ctx context.Context, title string) (Session, error) {
return Session{}, err
}
session := s.fromDBItem(dbSession)
- s.Publish(pubsub.CreatedEvent, session)
+ s.Publish(ctx, pubsub.CreatedEvent, session)
event.SessionCreated()
return session, nil
}
@@ -88,7 +88,7 @@ func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessi
return Session{}, err
}
session := s.fromDBItem(dbSession)
- s.Publish(pubsub.CreatedEvent, session)
+ s.Publish(ctx, pubsub.CreatedEvent, session)
return session, nil
}
@@ -102,7 +102,7 @@ func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string
return Session{}, err
}
session := s.fromDBItem(dbSession)
- s.Publish(pubsub.CreatedEvent, session)
+ s.Publish(ctx, pubsub.CreatedEvent, session)
return session, nil
}
@@ -115,7 +115,7 @@ func (s *service) Delete(ctx context.Context, id string) error {
if err != nil {
return err
}
- s.Publish(pubsub.DeletedEvent, session)
+ s.Publish(ctx, pubsub.DeletedEvent, session)
event.SessionDeleted()
return nil
}
@@ -153,7 +153,7 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
return Session{}, err
}
session = s.fromDBItem(dbSession)
- s.Publish(pubsub.UpdatedEvent, session)
+ s.Publish(ctx, pubsub.UpdatedEvent, session)
return session, nil
}
@@ -3,6 +3,7 @@ package tui
import (
"context"
"fmt"
+ "log/slog"
"math/rand"
"regexp"
"slices"
@@ -112,6 +113,7 @@ func (a appModel) Init() tea.Cmd {
// Update handles incoming messages and updates the application state.
func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ slog.Info("app update", "msg", fmt.Sprintf("%T", msg))
var cmds []tea.Cmd
var cmd tea.Cmd
a.isConfigured = config.HasInitialDataConfig()
@@ -324,11 +326,14 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
case permissions.PermissionResponseMsg:
switch msg.Action {
case permissions.PermissionAllow:
- a.app.Permissions.Grant(msg.Permission)
+ // TODO: get the context from somewhere
+ a.app.Permissions.Grant(context.Background(), msg.Permission)
case permissions.PermissionAllowForSession:
- a.app.Permissions.GrantPersistent(msg.Permission)
+ // TODO: get the context from somewhere
+ a.app.Permissions.GrantPersistent(context.Background(), msg.Permission)
case permissions.PermissionDeny:
- a.app.Permissions.Deny(msg.Permission)
+ // TODO: get the context from somewhere
+ a.app.Permissions.Deny(context.Background(), msg.Permission)
}
return a, nil
case splash.OnboardingCompleteMsg: