diff --git a/README.md b/README.md index 3c1e78e60d4c4a614129560251274292202ce850..f3ddaf7afa23014f3b0bb75297e0cac3da81318b 100644 --- a/README.md +++ b/README.md @@ -368,6 +368,40 @@ which do expand. Crush has preliminary support for hooks. For details, see [the hook guide](./docs/hooks/). +### Sharing a workspace across clients + +When Crush is run against a shared backend (for example two TUIs talking to +the same `crush serve`), clients are grouped into **workspaces** keyed by +their resolved `--cwd`. Two clients with the same `--cwd` join the same +underlying workspace, so they share the session list, message history, +permission queue, LSP, and MCP state. + +Joining is implicit: pointing a second client at the same working directory +attaches it to the existing workspace. Each new invocation, however, starts +in its own fresh session by default. To pick up the conversation another +client already has open, use the session manager (the session picker) and +select it. Sessions surface two signals there: + +- `IsBusy` is set while an agent turn is in flight for that session. +- `AttachedClients` reports how many clients are currently viewing it. + +A non-zero `AttachedClients` (often combined with `IsBusy`) is the cue that a +session is "in progress" on another client and joining it will mirror that +view live. + +The first client to create a workspace fixes its process-wide flags. In +particular, `--yolo` and `--debug` follow a **first-wins** rule: later +clients that arrive at the same `--cwd` with different values for those +flags do not change the running workspace. A debug log line is emitted +recording the mismatch, and the workspace keeps the flags it was created +with. + +A workspace lives as long as at least one client has an SSE event stream +open against it. When the last stream disconnects, the workspace is torn +down. There is a short grace window right after `POST /v1/workspaces` so a +client that has created the workspace but not yet opened its event stream +does not get reaped before it can attach. + ### Ignoring Files Crush respects `.gitignore` files by default, but you can also create a diff --git a/internal/agent/tools/bash_test.go b/internal/agent/tools/bash_test.go index b9c4a13adbb1f948c9fb85f5cb762bd79906bd68..40169e84e6c691e0ee8272cfcab71dde8ac86762 100644 --- a/internal/agent/tools/bash_test.go +++ b/internal/agent/tools/bash_test.go @@ -21,11 +21,13 @@ func (m *mockBashPermissionService) Request(ctx context.Context, req permission. return true, nil } -func (m *mockBashPermissionService) Grant(req permission.PermissionRequest) {} +func (m *mockBashPermissionService) Grant(req permission.PermissionRequest) bool { return true } -func (m *mockBashPermissionService) Deny(req permission.PermissionRequest) {} +func (m *mockBashPermissionService) Deny(req permission.PermissionRequest) bool { return true } -func (m *mockBashPermissionService) GrantPersistent(req permission.PermissionRequest) {} +func (m *mockBashPermissionService) GrantPersistent(req permission.PermissionRequest) bool { + return true +} func (m *mockBashPermissionService) AutoApproveSession(sessionID string) {} @@ -90,11 +92,13 @@ func (m *recordingPermissionService) Request(ctx context.Context, req permission return m.allow, nil } -func (m *recordingPermissionService) Grant(req permission.PermissionRequest) {} +func (m *recordingPermissionService) Grant(req permission.PermissionRequest) bool { return true } -func (m *recordingPermissionService) Deny(req permission.PermissionRequest) {} +func (m *recordingPermissionService) Deny(req permission.PermissionRequest) bool { return true } -func (m *recordingPermissionService) GrantPersistent(req permission.PermissionRequest) {} +func (m *recordingPermissionService) GrantPersistent(req permission.PermissionRequest) bool { + return true +} func (m *recordingPermissionService) AutoApproveSession(sessionID string) {} diff --git a/internal/agent/tools/multiedit_test.go b/internal/agent/tools/multiedit_test.go index 1ca2a6f7689e345ac944889f1f92284de0652f90..fe56ad6859e896c7a39cd487f7b55e8f59dcbd2f 100644 --- a/internal/agent/tools/multiedit_test.go +++ b/internal/agent/tools/multiedit_test.go @@ -20,11 +20,13 @@ func (m *mockPermissionService) Request(ctx context.Context, req permission.Crea return true, nil } -func (m *mockPermissionService) Grant(req permission.PermissionRequest) {} +func (m *mockPermissionService) Grant(req permission.PermissionRequest) bool { return true } -func (m *mockPermissionService) Deny(req permission.PermissionRequest) {} +func (m *mockPermissionService) Deny(req permission.PermissionRequest) bool { return true } -func (m *mockPermissionService) GrantPersistent(req permission.PermissionRequest) {} +func (m *mockPermissionService) GrantPersistent(req permission.PermissionRequest) bool { + return true +} func (m *mockPermissionService) AutoApproveSession(sessionID string) {} diff --git a/internal/agent/tools/view_test.go b/internal/agent/tools/view_test.go index de853f6cc3f1a0a5b72808983f0fe628f5145f59..43c793a39e94064a43dd27a954a5ed9cbfb572f8 100644 --- a/internal/agent/tools/view_test.go +++ b/internal/agent/tools/view_test.go @@ -216,11 +216,13 @@ func (m *mockViewPermissionService) Request(ctx context.Context, req permission. return true, nil } -func (m *mockViewPermissionService) Grant(req permission.PermissionRequest) {} +func (m *mockViewPermissionService) Grant(req permission.PermissionRequest) bool { return true } -func (m *mockViewPermissionService) Deny(req permission.PermissionRequest) {} +func (m *mockViewPermissionService) Deny(req permission.PermissionRequest) bool { return true } -func (m *mockViewPermissionService) GrantPersistent(req permission.PermissionRequest) {} +func (m *mockViewPermissionService) GrantPersistent(req permission.PermissionRequest) bool { + return true +} func (m *mockViewPermissionService) AutoApproveSession(sessionID string) {} diff --git a/internal/app/testing.go b/internal/app/testing.go new file mode 100644 index 0000000000000000000000000000000000000000..f17e94cfa99411b4594fce72bd894cc5fba4c4fd --- /dev/null +++ b/internal/app/testing.go @@ -0,0 +1,69 @@ +package app + +import ( + "context" + "sync" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent/notify" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/pubsub" +) + +// NewForTest constructs a minimal [App] suitable for in-process tests +// that need a working event broker and permission service without +// booting a real config, database, LSP, MCP, or agent coordinator. +// +// The returned App has: +// +// - A live `events` broker that [App.SendEvent] publishes to and +// [App.Events] subscribes from. +// - A real [permission.Service] whose request and notification +// brokers are fanned into the events broker, so subscribers to +// [App.Events] observe the same permission events the production +// wiring would deliver to SSE clients. +// - An [App.agentNotifications] broker. +// +// The caller owns lifetime: cancel ctx (or call [App.Shutdown]) to +// tear down the fan-in goroutines and the events broker. +func NewForTest(ctx context.Context) *App { + app := &App{ + Permissions: permission.NewPermissionService("", false, nil), + globalCtx: ctx, + events: pubsub.NewBroker[tea.Msg](), + serviceEventsWG: &sync.WaitGroup{}, + tuiWG: &sync.WaitGroup{}, + agentNotifications: pubsub.NewBroker[notify.Notification](), + } + + eventsCtx, cancel := context.WithCancel(ctx) + app.eventsCtx = eventsCtx + setupSubscriber(eventsCtx, app.serviceEventsWG, "permissions", + app.Permissions.Subscribe, app.events) + setupSubscriber(eventsCtx, app.serviceEventsWG, "permissions-notifications", + app.Permissions.SubscribeNotifications, app.events) + setupSubscriber(eventsCtx, app.serviceEventsWG, "agent-notifications", + app.agentNotifications.Subscribe, app.events) + app.cleanupFuncs = append(app.cleanupFuncs, func(context.Context) error { + cancel() + app.serviceEventsWG.Wait() + app.events.Shutdown() + return nil + }) + return app +} + +// ShutdownForTest tears down the App's event broker and fan-in +// goroutines. It is safe to call multiple times. +// +// Use this in tests instead of [App.Shutdown], which drives a full +// production shutdown path (database release, LSP teardown, MCP +// shutdown) that synthetic test apps cannot satisfy. +func (app *App) ShutdownForTest() { + for _, cleanup := range app.cleanupFuncs { + if cleanup != nil { + _ = cleanup(context.Background()) + } + } + app.cleanupFuncs = nil +} diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 5a593ee30b014848e982e6075a5ae64c1d17eab7..3dfeb9bf9b79bba7650ebbf8d30f58f3b7d9a41a 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -8,7 +8,10 @@ import ( "errors" "fmt" "log/slog" + "path/filepath" "runtime" + "sync" + "time" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" @@ -29,19 +32,65 @@ var ( ErrPathRequired = errors.New("path is required") ErrInvalidPermissionAction = errors.New("invalid permission action") ErrUnknownCommand = errors.New("unknown command") + ErrInvalidClientID = errors.New("invalid client_id") + ErrClientNotAttached = errors.New("client not attached") ) +// DefaultCreateGrace is the window in which a client must open an SSE +// stream after creating a workspace before its creation hold is +// released. Exposed as a package variable so tests can shorten it. +var DefaultCreateGrace = 30 * time.Second + // ShutdownFunc is called when the backend needs to trigger a server // shutdown (e.g. when the last workspace is removed). type ShutdownFunc func() // Backend provides transport-agnostic business logic for the Crush // server. It manages workspaces and delegates to [app.App] services. +// +// Locking order: when both [Backend.mu] and [Workspace.clientsMu] are +// held at once, [Backend.mu] is acquired first. Detach paths +// ([detachStream], [releaseHoldLocked], [expireHold]) only hold +// [Workspace.clientsMu] briefly, drop it, then call [teardown] which +// takes [Backend.mu] (and then re-takes [Workspace.clientsMu] to +// re-check that the workspace has not been re-claimed). This avoids +// the AB/BA hazard with [CreateWorkspace], which holds [Backend.mu] +// while calling [registerClient] so that a workspace cannot be torn +// down beneath it. type Backend struct { workspaces *csync.Map[string, *Workspace] - cfg *config.ConfigStore - ctx context.Context - shutdownFn ShutdownFunc + // pathIndex maps a resolved absolute workspace path to its + // workspace ID. Reads and writes are serialised via mu so + // concurrent CreateWorkspace calls at the same path deduplicate + // deterministically. + pathIndex map[string]string + mu sync.Mutex + + cfg *config.ConfigStore + ctx context.Context + shutdownFn ShutdownFunc + createGrace time.Duration +} + +// clientState tracks one client's claim on a workspace. +// +// - streams counts the number of live SSE event streams the client +// currently has open against the workspace. +// - holdTimer is non-nil iff the client created the workspace but has +// not yet attached an SSE stream; it fires after createGrace and +// releases the hold. +// - currentSessionID records which session this client is currently +// viewing. Empty string means the client has no session selected +// (e.g. the landing screen). Cleared automatically when the +// clientState entry is removed. +// +// streams and holdTimer are mutually exclusive in practice (the hold +// timer is stopped the moment an SSE stream attaches), but both being +// zero/nil means the entry has been released and should be removed. +type clientState struct { + streams int + holdTimer *time.Timer + currentSessionID string } // Workspace represents a running [app.App] workspace with its @@ -53,18 +102,57 @@ type Workspace struct { Cfg *config.ConfigStore Env []string Skills *skills.Manager + + // resolvedPath is the path used as the dedup key in + // Backend.pathIndex. It is filepath.EvalSymlinks(filepath.Abs(Path)) + // with fallback to the cleaned absolute path. + resolvedPath string + + // clientsMu guards clients. It is held only briefly (no IO). + clientsMu sync.Mutex + // clients tracks each client's claim on this workspace. Refcount + // is a derived value: len(clients). + clients map[string]*clientState + + // shutdownFn is the function invoked by [Backend.teardown] to + // release the workspace's underlying resources. It defaults to the + // embedded [app.App.Shutdown]; tests may override it to avoid + // driving a full [app.App] through shutdown. + shutdownFn func() +} + +// invokeShutdown calls the workspace shutdown hook if set, falling +// back to the embedded [app.App.Shutdown] when not. +func (w *Workspace) invokeShutdown() { + if w.shutdownFn != nil { + w.shutdownFn() + return + } + if w.App != nil { + w.Shutdown() + } } // New creates a new [Backend]. func New(ctx context.Context, cfg *config.ConfigStore, shutdownFn ShutdownFunc) *Backend { return &Backend{ - workspaces: csync.NewMap[string, *Workspace](), - cfg: cfg, - ctx: ctx, - shutdownFn: shutdownFn, + workspaces: csync.NewMap[string, *Workspace](), + pathIndex: make(map[string]string), + cfg: cfg, + ctx: ctx, + shutdownFn: shutdownFn, + createGrace: DefaultCreateGrace, } } +// SetCreateGrace overrides the create-grace window. Intended for tests +// that need short timeouts. +func (b *Backend) SetCreateGrace(d time.Duration) { + b.mu.Lock() + defer b.mu.Unlock() + b.createGrace = d +} + // GetWorkspace retrieves a workspace by ID. func (b *Backend) GetWorkspace(id string) (*Workspace, error) { ws, ok := b.workspaces.Get(id) @@ -84,12 +172,46 @@ func (b *Backend) ListWorkspaces() []proto.Workspace { } // CreateWorkspace initializes a new workspace from the given -// parameters. It creates the config, database connection, and -// [app.App] instance. +// parameters, or returns an existing workspace if one already exists at +// the same resolved path (first-wins semantics). +// +// args.ClientID must be a valid UUID identifying the calling client; +// the resulting workspace registers a creation hold on behalf of that +// client which is released either by the first SSE attach (which +// converts it into a stream claim) or by the grace window expiring. func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Workspace, error) { if args.Path == "" { return nil, proto.Workspace{}, ErrPathRequired } + clientID, err := validateClientID(args.ClientID) + if err != nil { + return nil, proto.Workspace{}, err + } + + key, err := resolveWorkspaceKey(args.Path) + if err != nil { + return nil, proto.Workspace{}, fmt.Errorf("failed to resolve workspace path: %w", err) + } + + b.mu.Lock() + if existingID, ok := b.pathIndex[key]; ok { + if ws, found := b.workspaces.Get(existingID); found { + // Hold b.mu while registering: teardown also + // acquires b.mu before tearing the workspace + // down, so this guarantees the workspace we + // return cannot be torn out from under us + // between lookup and registerClient. Lock order + // here is b.mu -> ws.clientsMu. + logFirstWinsMismatch(ws, args) + b.registerClient(ws, clientID) + b.mu.Unlock() + return ws, workspaceToProto(ws), nil + } + // pathIndex referenced a workspace that has since been + // removed; clean the stale entry and fall through. + delete(b.pathIndex, key) + } + b.mu.Unlock() id := uuid.New().String() cfg, err := config.Init(args.Path, args.DataDir, args.Debug) @@ -103,7 +225,7 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works return nil, proto.Workspace{}, fmt.Errorf("failed to create data directory: %w", err) } - conn, err := db.Connect(b.ctx, cfg.Config().Options.DataDirectory) + conn, err := db.Connect(b.ctx, cfg.Config().Options.DataDirectory, db.WithDataDirLock(true)) if err != nil { return nil, proto.Workspace{}, fmt.Errorf("failed to connect to database: %w", err) } @@ -125,15 +247,39 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works } ws := &Workspace{ - App: appWorkspace, - ID: id, - Path: args.Path, - Cfg: cfg, - Env: args.Env, - Skills: skillsMgr, + App: appWorkspace, + ID: id, + Path: args.Path, + Cfg: cfg, + Env: args.Env, + Skills: skillsMgr, + resolvedPath: key, + clients: make(map[string]*clientState), } + b.mu.Lock() + // Re-check the index under the lock: a concurrent caller may have + // won the race between the initial unlock and here. + if existingID, ok := b.pathIndex[key]; ok { + if existing, found := b.workspaces.Get(existingID); found { + // Register under b.mu so teardown cannot run + // between lookup and registerClient. Lock order + // is b.mu -> ws.clientsMu. + logFirstWinsMismatch(existing, args) + b.registerClient(existing, clientID) + b.mu.Unlock() + ws.invokeShutdown() + return existing, workspaceToProto(existing), nil + } + delete(b.pathIndex, key) + } b.workspaces.Set(id, ws) + b.pathIndex[key] = id + // Register the originating client's hold while still holding + // b.mu so the workspace is observable with its claim from the + // moment it appears in the index. + b.registerClient(ws, clientID) + b.mu.Unlock() if args.Version != "" && args.Version != version.Version { slog.Warn( @@ -147,18 +293,7 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works ))) } - result := proto.Workspace{ - ID: id, - Path: args.Path, - DataDir: cfg.Config().Options.DataDirectory, - Debug: cfg.Config().Options.Debug, - YOLO: cfg.Overrides().SkipPermissionRequests, - Config: cfg.Config(), - Env: args.Env, - Skills: skillStatesToProto(skillStates), - } - - return ws, result, nil + return ws, workspaceToProto(ws), nil } // skillsDiscoveryConfig adapts a *config.ConfigStore to the @@ -203,21 +338,261 @@ func skillStatesToProto(states []*skills.SkillState) []proto.SkillState { return out } -// DeleteWorkspace shuts down and removes a workspace. If it was the -// last workspace, the shutdown callback is invoked. -func (b *Backend) DeleteWorkspace(id string) { - ws, ok := b.workspaces.Get(id) - if ok { - ws.Shutdown() +// AttachClient registers a new SSE stream for the given client on the +// workspace. The stream's deferred cleanup must call DetachClient with +// the same arguments to release the claim. +// +// The lookup and the clients-map mutation are performed under +// [Backend.mu] so that AttachClient cannot race with [Backend.teardown]: +// teardown also holds [Backend.mu] while removing the workspace from +// b.workspaces, so once AttachClient observes the workspace and takes +// ws.clientsMu (under b.mu), no concurrent teardown can succeed without +// re-checking the (now non-empty) clients map. Lock order is the +// canonical b.mu -> ws.clientsMu. +func (b *Backend) AttachClient(workspaceID, clientID string) error { + if _, err := validateClientID(clientID); err != nil { + return err + } + + b.mu.Lock() + defer b.mu.Unlock() + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return ErrWorkspaceNotFound + } + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + cs, ok := ws.clients[clientID] + if !ok { + // Defensive: SSE attach without a prior CreateWorkspace by + // this client still installs a stream claim so the stream + // stays alive for its duration. + ws.clients[clientID] = &clientState{streams: 1} + return nil + } + if cs.holdTimer != nil { + cs.holdTimer.Stop() + cs.holdTimer = nil + } + cs.streams++ + return nil +} + +// DetachClient releases one SSE stream's hold on the workspace. If the +// client has no other streams and no pending creation hold, its claim +// is removed and the workspace is torn down once refcount hits zero. +func (b *Backend) DetachClient(workspaceID, clientID string) { + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return + } + b.detachStream(ws, clientID) +} + +// releaseHold releases the creation hold for a client, if any. Active +// stream claims are unaffected. Idempotent: returns nil if the +// workspace or the client's hold no longer exist. +func (b *Backend) releaseHold(workspaceID, clientID string) error { + if _, err := validateClientID(clientID); err != nil { + return err + } + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return nil + } + b.releaseHoldLocked(ws, clientID) + return nil +} + +// registerClient installs (idempotently) the given client's claim on +// the workspace and starts a grace timer if the entry is fresh. +func (b *Backend) registerClient(ws *Workspace, clientID string) { + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + if _, ok := ws.clients[clientID]; ok { + // Idempotent: a duplicate CreateWorkspace from the same + // client does not add a second claim. + return + } + cs := &clientState{} + cs.holdTimer = time.AfterFunc(b.createGrace, func() { + b.expireHold(ws, clientID, cs) + }) + ws.clients[clientID] = cs +} + +// expireHold is the body of the grace timer. It runs in its own +// goroutine and races against AttachClient/releaseHold; the timer +// stays valid only while the entry's holdTimer still points at it. +func (b *Backend) expireHold(ws *Workspace, clientID string, timer *clientState) { + ws.clientsMu.Lock() + cs, ok := ws.clients[clientID] + if !ok || cs != timer || cs.holdTimer == nil || cs.streams > 0 { + ws.clientsMu.Unlock() + return + } + cs.holdTimer = nil + delete(ws.clients, clientID) + teardown := len(ws.clients) == 0 + ws.clientsMu.Unlock() + if teardown { + b.teardown(ws) + } +} + +func (b *Backend) releaseHoldLocked(ws *Workspace, clientID string) { + ws.clientsMu.Lock() + cs, ok := ws.clients[clientID] + if !ok { + ws.clientsMu.Unlock() + return + } + if cs.holdTimer != nil { + cs.holdTimer.Stop() + cs.holdTimer = nil + } + teardown := false + if cs.streams == 0 { + delete(ws.clients, clientID) + teardown = len(ws.clients) == 0 + } + ws.clientsMu.Unlock() + if teardown { + b.teardown(ws) + } +} + +func (b *Backend) detachStream(ws *Workspace, clientID string) { + ws.clientsMu.Lock() + cs, ok := ws.clients[clientID] + if !ok { + ws.clientsMu.Unlock() + return + } + if cs.streams > 0 { + cs.streams-- + } + teardown := false + if cs.streams == 0 && cs.holdTimer == nil { + delete(ws.clients, clientID) + teardown = len(ws.clients) == 0 + } + ws.clientsMu.Unlock() + if teardown { + b.teardown(ws) } - b.workspaces.Del(id) +} + +// teardown removes the workspace from the index, shuts down its +// underlying [app.App], and triggers a server shutdown if it was the +// last workspace alive. +// +// Callers reach teardown after observing len(ws.clients) == 0 while +// holding ws.clientsMu and then releasing it. Between that release +// and the b.mu.Lock below, a concurrent CreateWorkspace may have +// re-registered a client (CreateWorkspace holds b.mu while doing so, +// so it is mutually exclusive with this critical section). teardown +// re-checks under both locks (in the canonical b.mu -> ws.clientsMu +// order) and aborts if the workspace has been re-claimed. +func (b *Backend) teardown(ws *Workspace) { + b.mu.Lock() + ws.clientsMu.Lock() + if len(ws.clients) > 0 { + // Race: a CreateWorkspace re-registered a client + // between the detach path dropping ws.clientsMu and us + // taking b.mu. Abort: the workspace is still alive. + ws.clientsMu.Unlock() + b.mu.Unlock() + return + } + ws.clientsMu.Unlock() + if existing, ok := b.pathIndex[ws.resolvedPath]; ok && existing == ws.ID { + delete(b.pathIndex, ws.resolvedPath) + } + b.workspaces.Del(ws.ID) + remaining := b.workspaces.Len() + b.mu.Unlock() - if b.workspaces.Len() == 0 && b.shutdownFn != nil { + ws.invokeShutdown() + + if remaining == 0 && b.shutdownFn != nil { slog.Info("Last workspace removed, shutting down server...") b.shutdownFn() } } +// DeleteWorkspace is the public entry point used by the HTTP DELETE +// handler. It releases the named client's creation hold; live streams +// from the same client remain attached and continue holding the +// workspace open until their own deferred DetachClient runs. +func (b *Backend) DeleteWorkspace(id, clientID string) error { + return b.releaseHold(id, clientID) +} + +// SetCurrentSession records which session the given client is +// currently viewing within the workspace. Passing an empty sessionID +// clears the client's current-session entry (e.g. the client has +// returned to the landing screen). +// +// The client must be actually attached — i.e. its [clientState] entry +// must exist and have at least one live stream. A bare creation hold +// (streams == 0) is rejected with [ErrClientNotAttached]. This +// guards against zombie writes from a client that has detached and +// against ghost presence from a hold-only client that never opened an +// SSE stream. +func (b *Backend) SetCurrentSession(workspaceID, clientID, sessionID string) error { + if _, err := validateClientID(clientID); err != nil { + return err + } + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return ErrWorkspaceNotFound + } + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + cs, ok := ws.clients[clientID] + if !ok || cs.streams == 0 { + // No entry, or hold-only (no live stream): refuse the + // write. The presence record this is meant to feed + // should only reflect clients that can actually observe + // session events. + return ErrClientNotAttached + } + cs.currentSessionID = sessionID + return nil +} + +// AttachedClients returns the number of clients currently viewing +// sessionID in the given workspace. Only clients with at least one live +// SSE stream (streams > 0) AND a matching currentSessionID are counted; +// pure creation holds do not contribute. Returns [ErrWorkspaceNotFound] +// if the workspace is unknown. +func (b *Backend) AttachedClients(workspaceID, sessionID string) (int, error) { + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return 0, ErrWorkspaceNotFound + } + return ws.AttachedClientsForSession(sessionID), nil +} + +// AttachedClientsForSession returns the number of clients in this +// workspace whose currentSessionID equals sessionID and which have at +// least one live SSE stream. Hold-only clients (streams == 0) do not +// contribute. Acquires the workspace's [clientsMu] briefly; the +// returned count is a point-in-time snapshot. +func (w *Workspace) AttachedClientsForSession(sessionID string) int { + w.clientsMu.Lock() + defer w.clientsMu.Unlock() + n := 0 + for _, cs := range w.clients { + if cs.streams > 0 && cs.currentSessionID == sessionID { + n++ + } + } + return n +} + // GetWorkspaceProto returns the proto representation of a workspace. func (b *Backend) GetWorkspaceProto(id string) (proto.Workspace, error) { ws, err := b.GetWorkspace(id) @@ -250,6 +625,33 @@ func (b *Backend) Shutdown() { } } +// resolveWorkspaceKey returns a stable canonical form of path suitable +// for use as a dedup key. It applies filepath.Abs, then attempts +// filepath.EvalSymlinks; because EvalSymlinks errors on non-existent +// paths, it falls back to the cleaned absolute path in that case. +func resolveWorkspaceKey(path string) (string, error) { + abs, err := filepath.Abs(path) + if err != nil { + return "", err + } + if resolved, err := filepath.EvalSymlinks(abs); err == nil { + return resolved, nil + } + return abs, nil +} + +// validateClientID returns the trimmed UUID string or an error if the +// input is empty or not a valid UUID. +func validateClientID(id string) (string, error) { + if id == "" { + return "", ErrInvalidClientID + } + if _, err := uuid.Parse(id); err != nil { + return "", fmt.Errorf("%w: %v", ErrInvalidClientID, err) + } + return id, nil +} + func workspaceToProto(ws *Workspace) proto.Workspace { cfg := ws.Cfg.Config() out := proto.Workspace{ @@ -259,9 +661,58 @@ func workspaceToProto(ws *Workspace) proto.Workspace { DataDir: cfg.Options.DataDirectory, Debug: cfg.Options.Debug, Config: cfg, + Env: ws.Env, + Version: version.Version, } if ws.Skills != nil { out.Skills = skillStatesToProto(ws.Skills.States()) } return out } + +// logFirstWinsMismatch emits a debug line whenever a second +// CreateWorkspace at the same resolved path arrives with flags that +// differ from the originating workspace. The existing workspace wins; +// the incoming flags are silently ignored. +// +// The comparison is done against the incoming args as the caller sent +// them — including empty/zero values — rather than after defaulting. +// This means that, for example, a second caller who omits DataDir +// while the first set one will still log the mismatch. +func logFirstWinsMismatch(existing *Workspace, args proto.Workspace) { + existingCfg := existing.Cfg.Config() + existingYOLO := existing.Cfg.Overrides().SkipPermissionRequests + if existingYOLO == args.YOLO && + existingCfg.Options.Debug == args.Debug && + existingCfg.Options.DataDirectory == args.DataDir && + stringSlicesEqual(existing.Env, args.Env) { + return + } + slog.Debug( + "Workspace flag mismatch on duplicate create; first wins", + "workspace_id", existing.ID, + "path", existing.Path, + "existing_yolo", existingYOLO, + "requested_yolo", args.YOLO, + "existing_debug", existingCfg.Options.Debug, + "requested_debug", args.Debug, + "existing_data_dir", existingCfg.Options.DataDirectory, + "requested_data_dir", args.DataDir, + "existing_env", existing.Env, + "requested_env", args.Env, + ) +} + +// stringSlicesEqual reports whether a and b contain the same strings +// in the same order. nil and empty are treated as equal. +func stringSlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/backend/backend_skills_test.go b/internal/backend/backend_skills_test.go new file mode 100644 index 0000000000000000000000000000000000000000..4f7dfa8f30e45976b9edba771576911a3205812a --- /dev/null +++ b/internal/backend/backend_skills_test.go @@ -0,0 +1,167 @@ +package backend_test + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/skills" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// TestBackend_WorkspaceSkillsIsolation verifies that skill discovery +// state and SSE events are per-workspace, not process-global. Two +// workspaces in the same backend process must not see each other's +// discoveries (either in their initial snapshot or in subsequent +// PublishStates events). +func TestBackend_WorkspaceSkillsIsolation(t *testing.T) { + // Isolate all of config.Init's filesystem reads from the host. The + // project-local .agents/skills//SKILL.md per working dir is + // what we actually want each workspace to see; everything else + // (global skills, XDG dirs, etc.) must be empty/deterministic. + hostHome := t.TempDir() + t.Setenv("HOME", hostHome) + t.Setenv("XDG_CONFIG_HOME", filepath.Join(hostHome, ".config")) + t.Setenv("XDG_DATA_HOME", filepath.Join(hostHome, ".local", "share")) + t.Setenv("XDG_CACHE_HOME", filepath.Join(hostHome, ".cache")) + t.Setenv("CRUSH_SKILLS_DIR", t.TempDir()) + + // Each workspace gets its own working directory containing a + // distinct project-local skill so the discovery output differs. + wdA := t.TempDir() + wdB := t.TempDir() + writeSkill(t, wdA, "wsa-only-skill", "Workspace A only skill.") + writeSkill(t, wdB, "wsb-only-skill", "Workspace B only skill.") + + srvCfg, err := config.Init(wdA, "", false) + require.NoError(t, err) + b := backend.New(t.Context(), srvCfg, nil) + + cidA := uuid.New().String() + cidB := uuid.New().String() + + wsA, _, err := b.CreateWorkspace(proto.Workspace{ + ClientID: cidA, + Path: wdA, + DataDir: filepath.Join(wdA, ".crush"), + }) + require.NoError(t, err) + t.Cleanup(func() { _ = b.DeleteWorkspace(wsA.ID, cidA) }) + + wsB, _, err := b.CreateWorkspace(proto.Workspace{ + ClientID: cidB, + Path: wdB, + DataDir: filepath.Join(wdB, ".crush"), + }) + require.NoError(t, err) + t.Cleanup(func() { _ = b.DeleteWorkspace(wsB.ID, cidB) }) + + require.NotNil(t, wsA.Skills, "workspace A must have its own skills.Manager") + require.NotNil(t, wsB.Skills, "workspace B must have its own skills.Manager") + require.NotSame(t, wsA.Skills, wsB.Skills, "managers must be distinct instances per workspace") + + // Initial snapshots see each workspace's own filesystem skill, and + // neither sees the other's. + statesA := wsA.Skills.States() + statesB := wsB.Skills.States() + require.True(t, containsSkillName(statesA, "wsa-only-skill"), + "workspace A snapshot missing its own skill") + require.False(t, containsSkillName(statesA, "wsb-only-skill"), + "workspace A snapshot leaked workspace B's skill") + require.True(t, containsSkillName(statesB, "wsb-only-skill"), + "workspace B snapshot missing its own skill") + require.False(t, containsSkillName(statesB, "wsa-only-skill"), + "workspace B snapshot leaked workspace A's skill") + + // Subscribe to each workspace's SSE event stream. + ctxA, cancelA := context.WithCancel(t.Context()) + t.Cleanup(cancelA) + chA, err := b.SubscribeEvents(ctxA, wsA.ID) + require.NoError(t, err) + + ctxB, cancelB := context.WithCancel(t.Context()) + t.Cleanup(cancelB) + chB, err := b.SubscribeEvents(ctxB, wsB.ID) + require.NoError(t, err) + + // Trigger a republish on workspace A only. The marker name lets us + // distinguish this event from any incidental backend activity. + const marker = "wsa-republish-marker" + wsA.Skills.PublishStates([]*skills.SkillState{ + {Name: marker, State: skills.StateNormal}, + }) + + // Workspace A must receive its own event. + require.True(t, + waitForSkillsEvent(t, chA, marker, 2*time.Second), + "workspace A never received its own skills event") + + // Workspace B must NOT receive workspace A's event. + require.False(t, + waitForSkillsEvent(t, chB, marker, 250*time.Millisecond), + "workspace B leaked workspace A's skills event") + + // And A's published states are now visible on its manager's + // snapshot (verifies PublishStates updates the cache, not just the + // broker). + updatedA := wsA.Skills.States() + require.True(t, containsSkillName(updatedA, marker), + "PublishStates must update Manager.States()") + + // B's manager snapshot is untouched. + require.False(t, containsSkillName(wsB.Skills.States(), marker), + "workspace B's Manager.States() leaked workspace A's republish") +} + +func writeSkill(t *testing.T, workingDir, name, desc string) { + t.Helper() + skillDir := filepath.Join(workingDir, ".agents", "skills", name) + require.NoError(t, os.MkdirAll(skillDir, 0o755)) + content := fmt.Sprintf("---\nname: %s\ndescription: %s\n---\n%s\n", name, desc, desc) + require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte(content), 0o644)) +} + +func containsSkillName(states []*skills.SkillState, name string) bool { + for _, s := range states { + if s.Name == name { + return true + } + } + return false +} + +// waitForSkillsEvent drains the given event channel until either a +// pubsub.Event[skills.Event] containing a state named marker arrives or +// the timeout elapses. Non-skills events on the channel are silently +// skipped — the backend fans in many event types and we only care +// about skills here. +func waitForSkillsEvent(t *testing.T, ch <-chan pubsub.Event[tea.Msg], marker string, timeout time.Duration) bool { + t.Helper() + deadline := time.After(timeout) + for { + select { + case ev, ok := <-ch: + if !ok { + return false + } + se, ok := ev.Payload.(pubsub.Event[skills.Event]) + if !ok { + continue + } + if containsSkillName(se.Payload.States, marker) { + return true + } + case <-deadline: + return false + } + } +} diff --git a/internal/backend/backend_test.go b/internal/backend/backend_test.go index b1dabc58540ab0b165378a93b9fad4617a10928b..ee1165dfba5c27e2f021ba2e834a99b4d3a769e5 100644 --- a/internal/backend/backend_test.go +++ b/internal/backend/backend_test.go @@ -1,170 +1,1241 @@ -package backend_test +package backend import ( + "bytes" "context" - "fmt" + "errors" + "log/slog" "os" "path/filepath" + "runtime" + "strings" + "sync" + "sync/atomic" "testing" "time" - tea "charm.land/bubbletea/v2" - "github.com/charmbracelet/crush/internal/backend" - "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/proto" - "github.com/charmbracelet/crush/internal/pubsub" - "github.com/charmbracelet/crush/internal/skills" + "github.com/google/uuid" "github.com/stretchr/testify/require" ) -// TestBackend_WorkspaceSkillsIsolation verifies that skill discovery -// state and SSE events are per-workspace, not process-global. This -// guards the structural change in §5 of the plan: two workspaces in the -// same backend process must not see each other's discoveries (either in -// their initial snapshot or in subsequent PublishStates events). -// -// Before that change landed, the package-level latestStates cache and -// the package-level broker meant that: -// - workspace A's PublishStates would arrive on workspace B's SSE -// stream, and -// - the most recent SetLatestStates would silently overwrite the -// other workspace's cache. -// -// This test fails on either scenario. -func TestBackend_WorkspaceSkillsIsolation(t *testing.T) { - // Isolate all of config.Init's filesystem reads from the host. The - // project-local .agents/skills//SKILL.md per working dir is - // what we actually want each workspace to see; everything else - // (global skills, XDG dirs, etc.) must be empty/deterministic. - hostHome := t.TempDir() - t.Setenv("HOME", hostHome) - t.Setenv("XDG_CONFIG_HOME", filepath.Join(hostHome, ".config")) - t.Setenv("XDG_DATA_HOME", filepath.Join(hostHome, ".local", "share")) - t.Setenv("XDG_CACHE_HOME", filepath.Join(hostHome, ".cache")) - t.Setenv("CRUSH_SKILLS_DIR", t.TempDir()) - - // Each workspace gets its own working directory containing a - // distinct project-local skill so the discovery output differs. - wdA := t.TempDir() - wdB := t.TempDir() - writeSkill(t, wdA, "wsa-only-skill", "Workspace A only skill.") - writeSkill(t, wdB, "wsb-only-skill", "Workspace B only skill.") - - srvCfg, err := config.Init(wdA, "", false) - require.NoError(t, err) - b := backend.New(t.Context(), srvCfg, nil) - - wsA, _, err := b.CreateWorkspace(proto.Workspace{ - Path: wdA, - DataDir: filepath.Join(wdA, ".crush"), - }) - require.NoError(t, err) - t.Cleanup(func() { b.DeleteWorkspace(wsA.ID) }) - - wsB, _, err := b.CreateWorkspace(proto.Workspace{ - Path: wdB, - DataDir: filepath.Join(wdB, ".crush"), - }) - require.NoError(t, err) - t.Cleanup(func() { b.DeleteWorkspace(wsB.ID) }) - - require.NotNil(t, wsA.Skills, "workspace A must have its own skills.Manager") - require.NotNil(t, wsB.Skills, "workspace B must have its own skills.Manager") - require.NotSame(t, wsA.Skills, wsB.Skills, "managers must be distinct instances per workspace") - - // Initial snapshots see each workspace's own filesystem skill, and - // neither sees the other's. - statesA := wsA.Skills.States() - statesB := wsB.Skills.States() - require.True(t, containsSkillName(statesA, "wsa-only-skill"), - "workspace A snapshot missing its own skill") - require.False(t, containsSkillName(statesA, "wsb-only-skill"), - "workspace A snapshot leaked workspace B's skill") - require.True(t, containsSkillName(statesB, "wsb-only-skill"), - "workspace B snapshot missing its own skill") - require.False(t, containsSkillName(statesB, "wsa-only-skill"), - "workspace B snapshot leaked workspace A's skill") - - // Subscribe to each workspace's SSE event stream. - ctxA, cancelA := context.WithCancel(t.Context()) - t.Cleanup(cancelA) - chA, err := b.SubscribeEvents(ctxA, wsA.ID) - require.NoError(t, err) - - ctxB, cancelB := context.WithCancel(t.Context()) - t.Cleanup(cancelB) - chB, err := b.SubscribeEvents(ctxB, wsB.ID) - require.NoError(t, err) - - // Trigger a republish on workspace A only. The marker name lets us - // distinguish this event from any incidental backend activity. - const marker = "wsa-republish-marker" - wsA.Skills.PublishStates([]*skills.SkillState{ - {Name: marker, State: skills.StateNormal}, - }) - - // Workspace A must receive its own event. - require.True(t, - waitForSkillsEvent(t, chA, marker, 2*time.Second), - "workspace A never received its own skills event") - - // Workspace B must NOT receive workspace A's event. - require.False(t, - waitForSkillsEvent(t, chB, marker, 250*time.Millisecond), - "workspace B leaked workspace A's skills event") - - // And A's published states are now visible on its manager's - // snapshot (verifies PublishStates updates the cache, not just the - // broker). - updatedA := wsA.Skills.States() - require.True(t, containsSkillName(updatedA, marker), - "PublishStates must update Manager.States()") +// newTestBackend returns a Backend whose teardown path skips any +// real [app.App] shutdown work. Useful for state-machine tests that +// install synthetic workspaces directly via insertTestWorkspace. +func newTestBackend(t *testing.T) (*Backend, *atomic.Int32) { + t.Helper() + var shutdownCount atomic.Int32 + b := &Backend{ + workspaces: csync.NewMap[string, *Workspace](), + pathIndex: make(map[string]string), + ctx: context.Background(), + createGrace: 50 * time.Millisecond, + shutdownFn: func() { shutdownCount.Add(1) }, + } + return b, &shutdownCount +} - // B's manager snapshot is untouched. - require.False(t, containsSkillName(wsB.Skills.States(), marker), - "workspace B's Manager.States() leaked workspace A's republish") +// insertTestWorkspace installs a synthetic workspace into b at the +// given resolved path. Its shutdownFn is recorded in the returned +// counter so tests can assert it ran exactly once. +func insertTestWorkspace(t *testing.T, b *Backend, key string) (*Workspace, *atomic.Int32) { + t.Helper() + var shutdowns atomic.Int32 + ws := &Workspace{ + ID: uuid.New().String(), + Path: key, + resolvedPath: key, + clients: make(map[string]*clientState), + shutdownFn: func() { shutdowns.Add(1) }, + } + b.mu.Lock() + b.workspaces.Set(ws.ID, ws) + b.pathIndex[key] = ws.ID + b.mu.Unlock() + return ws, &shutdowns } -func writeSkill(t *testing.T, workingDir, name, desc string) { +func newClientID(t *testing.T) string { t.Helper() - skillDir := filepath.Join(workingDir, ".agents", "skills", name) - require.NoError(t, os.MkdirAll(skillDir, 0o755)) - content := fmt.Sprintf("---\nname: %s\ndescription: %s\n---\n%s\n", name, desc, desc) - require.NoError(t, os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte(content), 0o644)) + return uuid.New().String() +} + +func TestResolveWorkspaceKey_AbsoluteAndSymlink(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + real, err := filepath.EvalSymlinks(tmp) + require.NoError(t, err) + + got, err := resolveWorkspaceKey(tmp) + require.NoError(t, err) + require.Equal(t, real, got) +} + +func TestResolveWorkspaceKey_NonExistentFallback(t *testing.T) { + t.Parallel() + + missing := filepath.Join(t.TempDir(), "does", "not", "exist") + got, err := resolveWorkspaceKey(missing) + require.NoError(t, err) + abs, err := filepath.Abs(missing) + require.NoError(t, err) + require.Equal(t, abs, got) +} + +func TestValidateClientID(t *testing.T) { + t.Parallel() + + _, err := validateClientID("") + require.ErrorIs(t, err, ErrInvalidClientID) + _, err = validateClientID("not-a-uuid") + require.ErrorIs(t, err, ErrInvalidClientID) + + id := uuid.New().String() + got, err := validateClientID(id) + require.NoError(t, err) + require.Equal(t, id, got) +} + +func TestRegisterClient_Idempotent(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + b.registerClient(ws, cid) + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.Len(t, ws.clients, 1) + require.NotNil(t, ws.clients[cid].holdTimer) + require.Equal(t, 0, ws.clients[cid].streams) +} + +func TestAttachClient_ConsumesHold(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) + + ws.clientsMu.Lock() + require.Len(t, ws.clients, 1) + require.Nil(t, ws.clients[cid].holdTimer, "attach must stop the grace timer") + require.Equal(t, 1, ws.clients[cid].streams) + ws.clientsMu.Unlock() + + // Wait past the grace window: a stopped timer must not fire. + time.Sleep(150 * time.Millisecond) + require.Equal(t, int32(0), shutdowns.Load(), "workspace must not be torn down while attached") +} + +func TestAttachClient_WithoutPriorCreate(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cid)) + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.Len(t, ws.clients, 1) + require.Equal(t, 1, ws.clients[cid].streams) + require.Nil(t, ws.clients[cid].holdTimer) +} + +func TestAttachClient_DuplicateStreams(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cid)) + require.NoError(t, b.AttachClient(ws.ID, cid)) + + ws.clientsMu.Lock() + require.Equal(t, 2, ws.clients[cid].streams) + ws.clientsMu.Unlock() + + b.DetachClient(ws.ID, cid) + ws.clientsMu.Lock() + require.Equal(t, 1, ws.clients[cid].streams) + ws.clientsMu.Unlock() + require.Equal(t, int32(0), shutdowns.Load()) + + b.DetachClient(ws.ID, cid) + require.Equal(t, int32(1), shutdowns.Load(), "second detach tears down the workspace") } -func containsSkillName(states []*skills.SkillState, name string) bool { - for _, s := range states { - if s.Name == name { - return true +func TestDetachClient_LastStreamTearsDown(t *testing.T) { + t.Parallel() + + b, srvShutdowns := newTestBackend(t) + ws, wsShutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) + b.DetachClient(ws.ID, cid) + + require.Equal(t, int32(1), wsShutdowns.Load()) + require.Equal(t, int32(1), srvShutdowns.Load(), "last workspace shut down must trigger server shutdown") + _, err := b.GetWorkspace(ws.ID) + require.ErrorIs(t, err, ErrWorkspaceNotFound) +} + +func TestHoldExpiry_TearsDown(t *testing.T) { + t.Parallel() + + b, srvShutdowns := newTestBackend(t) + ws, wsShutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + + require.Eventually(t, func() bool { + return wsShutdowns.Load() == 1 && srvShutdowns.Load() == 1 + }, 1*time.Second, 5*time.Millisecond) +} + +func TestReleaseHold_NoStreams(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.releaseHold(ws.ID, cid)) + + require.Equal(t, int32(1), shutdowns.Load()) + // Idempotent. + require.NoError(t, b.releaseHold(ws.ID, cid)) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestReleaseHold_WithActiveStream(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) + require.NoError(t, b.releaseHold(ws.ID, cid)) + + ws.clientsMu.Lock() + require.Equal(t, 1, ws.clients[cid].streams) + require.Nil(t, ws.clients[cid].holdTimer) + ws.clientsMu.Unlock() + require.Equal(t, int32(0), shutdowns.Load()) + + b.DetachClient(ws.ID, cid) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestReleaseHoldThenAttach(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + require.NoError(t, b.releaseHold(ws.ID, cid)) // no entry yet — no-op. + require.NoError(t, b.AttachClient(ws.ID, cid)) + ws.clientsMu.Lock() + require.Equal(t, 1, ws.clients[cid].streams) + ws.clientsMu.Unlock() + require.NoError(t, b.releaseHold(ws.ID, cid)) // hold-only no-op (no hold timer). + require.Equal(t, int32(0), shutdowns.Load()) + b.DetachClient(ws.ID, cid) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestRefcountWithSecondClient(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cidA := newClientID(t) + cidB := newClientID(t) + b.registerClient(ws, cidA) + require.NoError(t, b.AttachClient(ws.ID, cidA)) + b.registerClient(ws, cidB) + require.NoError(t, b.AttachClient(ws.ID, cidB)) + + b.DetachClient(ws.ID, cidA) + ws.clientsMu.Lock() + require.Contains(t, ws.clients, cidB) + require.NotContains(t, ws.clients, cidA) + ws.clientsMu.Unlock() + require.Equal(t, int32(0), shutdowns.Load(), "workspace survives while second client attached") + + b.DetachClient(ws.ID, cidB) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestAttachClient_InvalidID(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + require.ErrorIs(t, b.AttachClient(ws.ID, ""), ErrInvalidClientID) + require.ErrorIs(t, b.AttachClient(ws.ID, "not-a-uuid"), ErrInvalidClientID) +} + +func TestDeleteWorkspace_RejectsBadClientID(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + require.ErrorIs(t, b.DeleteWorkspace(ws.ID, ""), ErrInvalidClientID) + require.ErrorIs(t, b.DeleteWorkspace(ws.ID, "not-a-uuid"), ErrInvalidClientID) +} + +// TestHoldExpiry_RaceWithAttach checks that, when the grace timer fires +// while a concurrent AttachClient call is in flight, the workspace ends +// up either fully attached or fully torn down — never in a half-state. +func TestHoldExpiry_RaceWithAttach(t *testing.T) { + t.Parallel() + + for i := range 50 { + b, _ := newTestBackend(t) + // Tighten the grace window further to force the race. + b.createGrace = 1 * time.Millisecond + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/race") + + cid := newClientID(t) + b.registerClient(ws, cid) + // Attach concurrently with the very short grace timer. + errCh := make(chan error, 1) + go func() { errCh <- b.AttachClient(ws.ID, cid) }() + <-errCh + + // Wait for any pending timer to settle. + time.Sleep(10 * time.Millisecond) + + ws.clientsMu.Lock() + gotShutdown := shutdowns.Load() == 1 + cs, present := ws.clients[cid] + var ( + gotStreams int + gotHoldTimer *time.Timer + ) + if present { + gotStreams = cs.streams + gotHoldTimer = cs.holdTimer + } + ws.clientsMu.Unlock() + // Either the workspace was torn down OR the client is + // attached with streams==1 and the hold timer cleared. + // The state must be consistent: if shutdown, client is + // gone; if attached, no teardown and streams==1. + if gotShutdown { + require.False(t, present, "iter %d: shutdown but client still present", i) + } else { + require.True(t, present, "iter %d: not shutdown but client missing", i) + require.Equal(t, 1, gotStreams, "iter %d: attach winner must leave streams=1", i) + require.Nil(t, gotHoldTimer, "iter %d: attach winner must clear holdTimer", i) } } - return false } -// waitForSkillsEvent drains the given event channel until either a -// pubsub.Event[skills.Event] containing a state named marker arrives or -// the timeout elapses. Non-skills events on the channel are silently -// skipped — the backend fans in many event types and we only care -// about skills here. -func waitForSkillsEvent(t *testing.T, ch <-chan pubsub.Event[tea.Msg], marker string, timeout time.Duration) bool { +// TestConcurrentAttachDetach exercises the state machine under +// parallel attach/detach pressure with the race detector. +func TestConcurrentAttachDetach(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) // ensure refcount stays > 0. + + const n = 50 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + cid2 := newClientID(t) + _ = b.AttachClient(ws.ID, cid2) + b.DetachClient(ws.ID, cid2) + }() + } + wg.Wait() + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.Len(t, ws.clients, 1) + require.Contains(t, ws.clients, cid) +} + +// TestPathDedupe_FullCreate exercises CreateWorkspace end-to-end +// (config init, real app.App). Two CreateWorkspace calls at the same +// path return the same workspace ID and share the clients map. +func TestPathDedupe_FullCreate(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + cwd := t.TempDir() + dataDir := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + cidA := uuid.New().String() + cidB := uuid.New().String() + + wsA, protoA, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidA)) + require.NoError(t, err) + require.NotEmpty(t, protoA.ID) + require.Equal(t, protoA.DataDir, wsA.Cfg.Config().Options.DataDirectory) + + wsB, protoB, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidB)) + require.NoError(t, err) + require.Equal(t, wsA.ID, wsB.ID, "second create at same path must return existing workspace") + require.Equal(t, protoA.ID, protoB.ID) + + wsA.clientsMu.Lock() + require.Contains(t, wsA.clients, cidA) + require.Contains(t, wsA.clients, cidB) + wsA.clientsMu.Unlock() +} + +// TestPathDedupe_DifferentPaths_DifferentWorkspaces confirms that two +// CreateWorkspace calls at distinct paths produce distinct workspaces. +func TestPathDedupe_DifferentPaths_DifferentWorkspaces(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + cwdA := t.TempDir() + cwdB := t.TempDir() + dataA := t.TempDir() + dataB := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + wsA, _, err := b.CreateWorkspace(protoWS(cwdA, dataA, uuid.New().String())) + require.NoError(t, err) + wsB, _, err := b.CreateWorkspace(protoWS(cwdB, dataB, uuid.New().String())) + require.NoError(t, err) + require.NotEqual(t, wsA.ID, wsB.ID) +} + +// TestPathDedupe_FirstWinsKeepsOriginalEnv verifies that the second +// create at the same path returns the *originating* client's Env in +// its proto and does not mutate the existing workspace's YOLO/Debug +// flags. +func TestPathDedupe_FirstWinsKeepsOriginalEnv(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + cwd := t.TempDir() + dataDir := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + originalEnv := []string{"FOO=bar"} + argsA := protoWS(cwd, dataDir, uuid.New().String()) + argsA.YOLO = true + argsA.Env = originalEnv + wsA, protoA, err := b.CreateWorkspace(argsA) + require.NoError(t, err) + require.True(t, protoA.YOLO) + require.Equal(t, originalEnv, protoA.Env) + + argsB := protoWS(cwd, dataDir, uuid.New().String()) + argsB.YOLO = false + argsB.Debug = true + argsB.Env = []string{"BAZ=qux"} + _, protoB, err := b.CreateWorkspace(argsB) + require.NoError(t, err) + require.Equal(t, protoA.ID, protoB.ID) + require.True(t, protoB.YOLO, "first wins: YOLO must remain true") + require.Equal(t, originalEnv, protoB.Env, "proto must carry the originating client's Env") + require.Equal(t, wsA.Cfg.Overrides().SkipPermissionRequests, true) +} + +// TestPathDedupe_Symlink confirms two paths that resolve to the same +// target share a workspace. +func TestPathDedupe_Symlink(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink semantics differ on windows") + } + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + real := t.TempDir() + link := filepath.Join(t.TempDir(), "link") + require.NoError(t, os.Symlink(real, link)) + dataDir := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + wsA, _, err := b.CreateWorkspace(protoWS(real, dataDir, uuid.New().String())) + require.NoError(t, err) + wsB, _, err := b.CreateWorkspace(protoWS(link, dataDir, uuid.New().String())) + require.NoError(t, err) + require.Equal(t, wsA.ID, wsB.ID) +} + +// TestPathDedupe_NonExistentPath ensures CreateWorkspace tolerates a +// path that does not yet exist (EvalSymlinks falls back to Abs). +func TestPathDedupe_NonExistentPath(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + parent := t.TempDir() + missing := filepath.Join(parent, "does-not-exist") + dataDir := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + _, p, err := b.CreateWorkspace(protoWS(missing, dataDir, uuid.New().String())) + require.NoError(t, err) + require.NotEmpty(t, p.ID) +} + +// TestCreateWorkspace_IdempotentSameClient checks that a duplicate +// create from the same client at the same path does not produce a +// second claim. +func TestCreateWorkspace_IdempotentSameClient(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + cwd := t.TempDir() + dataDir := t.TempDir() + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + cid := uuid.New().String() + ws1, _, err := b.CreateWorkspace(protoWS(cwd, dataDir, cid)) + require.NoError(t, err) + ws2, _, err := b.CreateWorkspace(protoWS(cwd, dataDir, cid)) + require.NoError(t, err) + require.Equal(t, ws1.ID, ws2.ID) + + ws1.clientsMu.Lock() + require.Len(t, ws1.clients, 1, "duplicate create from same client must not double the claim") + ws1.clientsMu.Unlock() +} + +// TestPathDedupe_ParallelCreates ensures two simultaneous CreateWorkspace +// calls at the same path produce the same workspace and the clients map +// contains both client IDs. +func TestPathDedupe_ParallelCreates(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + cwd := t.TempDir() + dataDir := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + cidA := uuid.New().String() + cidB := uuid.New().String() + + type result struct { + ws *Workspace + proto proto.Workspace + err error + } + ch := make(chan result, 2) + start := make(chan struct{}) + go func() { + <-start + ws, p, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidA)) + ch <- result{ws, p, err} + }() + go func() { + <-start + ws, p, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidB)) + ch <- result{ws, p, err} + }() + close(start) + r1 := <-ch + r2 := <-ch + require.NoError(t, r1.err) + require.NoError(t, r2.err) + require.Equal(t, r1.ws.ID, r2.ws.ID, "both creates must converge on one workspace ID") + + ws := r1.ws + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.Contains(t, ws.clients, cidA) + require.Contains(t, ws.clients, cidB) +} + +// TestCreateWorkspace_RejectsBadClientID covers the 400 path from the +// backend side. +func TestCreateWorkspace_RejectsBadClientID(t *testing.T) { + t.Parallel() + + b := New(context.Background(), nil, func() {}) + + _, _, err := b.CreateWorkspace(protoWS("/tmp/x", t.TempDir(), "")) + require.ErrorIs(t, err, ErrInvalidClientID) + _, _, err = b.CreateWorkspace(protoWS("/tmp/x", t.TempDir(), "not-a-uuid")) + require.ErrorIs(t, err, ErrInvalidClientID) +} + +// drainBackend tears the backend down at the end of a test by deleting +// every remaining workspace. Necessary so the test process doesn't +// leak goroutines or DB handles from the embedded [app.App] instances. +func drainBackend(t *testing.T, b *Backend) { t.Helper() - deadline := time.After(timeout) - for { - select { - case ev, ok := <-ch: - if !ok { - return false - } - se, ok := ev.Payload.(pubsub.Event[skills.Event]) - if !ok { - continue - } - if containsSkillName(se.Payload.States, marker) { - return true - } - case <-deadline: - return false + for _, ws := range b.workspaces.Seq2() { + ws.clientsMu.Lock() + ids := make([]string, 0, len(ws.clients)) + for id := range ws.clients { + ids = append(ids, id) + } + ws.clientsMu.Unlock() + for _, cid := range ids { + _ = b.releaseHold(ws.ID, cid) + } + } +} + +func protoWS(path, dataDir, clientID string) proto.Workspace { + return proto.Workspace{Path: path, DataDir: dataDir, ClientID: clientID} +} + +// syncBuffer is a thread-safe buffer that can be safely read and written +// from multiple goroutines. +type syncBuffer struct { + mu sync.Mutex + buf bytes.Buffer +} + +func (sb *syncBuffer) Write(p []byte) (n int, err error) { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.Write(p) +} + +func (sb *syncBuffer) String() string { + sb.mu.Lock() + defer sb.mu.Unlock() + return sb.buf.String() +} + +// captureDebugLogs installs a buffer-backed slog handler at Debug +// level for the duration of the test, returning the buffer. The +// previous default handler is restored via t.Cleanup. +func captureDebugLogs(t *testing.T) *syncBuffer { + t.Helper() + var sb syncBuffer + prev := slog.Default() + handler := slog.NewTextHandler(&sb, &slog.HandlerOptions{Level: slog.LevelDebug}) + slog.SetDefault(slog.New(handler)) + t.Cleanup(func() { slog.SetDefault(prev) }) + return &sb +} + +// xdgIsolated points HOME and XDG_* variables at fresh tempdirs so +// CreateWorkspace's config loading does not interfere with the host +// machine's real config. +func xdgIsolated(t *testing.T) { + t.Helper() + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) +} + +// TestFirstWinsMismatch_LogsOnFlagDifferences verifies that the +// debug mismatch line is emitted when any of YOLO, Debug, DataDir, +// or Env differs between the first and second CreateWorkspace at +// the same path, and that the existing workspace's Debug flag is +// not overwritten. +func TestFirstWinsMismatch_LogsOnFlagDifferences(t *testing.T) { + tests := []struct { + name string + mutate func(*proto.Workspace) + }{ + { + name: "yolo", + mutate: func(p *proto.Workspace) { p.YOLO = true }, + }, + { + name: "debug", + mutate: func(p *proto.Workspace) { p.Debug = true }, + }, + { + name: "datadir", + mutate: func(p *proto.Workspace) { p.DataDir = "" }, + }, + { + name: "env", + mutate: func(p *proto.Workspace) { p.Env = []string{"NEW=val"} }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + xdgIsolated(t) + cwd := t.TempDir() + dataDir := t.TempDir() + + buf := captureDebugLogs(t) + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + argsA := protoWS(cwd, dataDir, uuid.New().String()) + argsA.Env = []string{"FOO=bar"} + wsA, _, err := b.CreateWorkspace(argsA) + require.NoError(t, err) + originalDebug := wsA.Cfg.Config().Options.Debug + originalYOLO := wsA.Cfg.Overrides().SkipPermissionRequests + + argsB := protoWS(cwd, dataDir, uuid.New().String()) + argsB.Env = []string{"FOO=bar"} // identical by default + tc.mutate(&argsB) + _, _, err = b.CreateWorkspace(argsB) + require.NoError(t, err) + + require.Contains( + t, buf.String(), + "Workspace flag mismatch on duplicate create", + "expected debug log for mismatching %s", tc.name, + ) + // Existing workspace's YOLO and Debug must not change. + require.Equal(t, originalYOLO, wsA.Cfg.Overrides().SkipPermissionRequests, "YOLO must be immutable on first-wins") + require.Equal(t, originalDebug, wsA.Cfg.Config().Options.Debug, "Debug must be immutable on first-wins") + }) + } +} + +// TestFirstWinsMismatch_NoLogWhenIdentical confirms identical args +// do not emit the mismatch log line. +func TestFirstWinsMismatch_NoLogWhenIdentical(t *testing.T) { + xdgIsolated(t) + cwd := t.TempDir() + dataDir := t.TempDir() + + buf := captureDebugLogs(t) + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + argsA := protoWS(cwd, dataDir, uuid.New().String()) + argsA.Env = []string{"FOO=bar"} + _, _, err := b.CreateWorkspace(argsA) + require.NoError(t, err) + + argsB := protoWS(cwd, dataDir, uuid.New().String()) + argsB.Env = []string{"FOO=bar"} + _, _, err = b.CreateWorkspace(argsB) + require.NoError(t, err) + + require.False(t, + strings.Contains(buf.String(), "Workspace flag mismatch on duplicate create"), + "identical args must not log a mismatch: %s", buf.String()) +} + +// TestRaceTwoClientsAttachOneDetaches exercises the PLAN-required +// race scenario: two clients attach concurrently, then one detaches. +// The workspace must remain alive with refcount==1 and the clients +// map must reflect the remaining client only. +func TestRaceTwoClientsAttachOneDetaches(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/race-two") + + cidA := newClientID(t) + cidB := newClientID(t) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + require.NoError(t, b.AttachClient(ws.ID, cidA)) + }() + go func() { + defer wg.Done() + require.NoError(t, b.AttachClient(ws.ID, cidB)) + }() + wg.Wait() + + ws.clientsMu.Lock() + require.Len(t, ws.clients, 2, "both clients must be attached") + ws.clientsMu.Unlock() + + b.DetachClient(ws.ID, cidA) + + ws.clientsMu.Lock() + require.Len(t, ws.clients, 1, "refcount must be 1 after one detach") + require.Contains(t, ws.clients, cidB, "remaining client must be cidB") + require.NotContains(t, ws.clients, cidA, "detached client must be removed") + ws.clientsMu.Unlock() + require.Equal(t, int32(0), shutdowns.Load(), "workspace must remain alive") + + // Drain. + b.DetachClient(ws.ID, cidB) + require.Equal(t, int32(1), shutdowns.Load()) +} + +// TestExplicitDeleteThenAttach reproduces the PLAN scenario: start +// with a real hold, releaseHold consumes it, AttachClient from the +// same clientID creates a fresh entry with streams==1, and calling +// releaseHold again is a no-op. A second client keeps the workspace +// alive so AttachClient can still resolve the workspace ID after the +// first client's hold is released. +func TestExplicitDeleteThenAttach(t *testing.T) { + t.Parallel() + + // Large grace window so timers cannot fire during the test + // — we want to exercise the explicit releaseHold path. + b, _ := newTestBackend(t) + b.createGrace = time.Hour + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/delete-then-attach") + + // Anchor client keeps the workspace registered in + // b.workspaces across the cid's releaseHold below. + anchor := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, anchor)) + + cid := newClientID(t) + // Real hold via registerClient (mirrors CreateWorkspace). + b.registerClient(ws, cid) + ws.clientsMu.Lock() + require.Contains(t, ws.clients, cid) + require.NotNil(t, ws.clients[cid].holdTimer, "hold must be live") + require.Equal(t, 0, ws.clients[cid].streams) + ws.clientsMu.Unlock() + + // releaseHold: consumes the hold and removes the entry + // (streams == 0). The anchor client keeps the workspace + // alive. + require.NoError(t, b.releaseHold(ws.ID, cid)) + require.Equal(t, int32(0), shutdowns.Load(), "anchor must keep workspace alive") + ws.clientsMu.Lock() + require.NotContains(t, ws.clients, cid, "entry must be removed by releaseHold") + ws.clientsMu.Unlock() + + // AttachClient creates a fresh entry with streams==1 and no + // hold timer. + require.NoError(t, b.AttachClient(ws.ID, cid)) + ws.clientsMu.Lock() + require.Contains(t, ws.clients, cid, "fresh entry must be created") + require.Equal(t, 1, ws.clients[cid].streams, "fresh attach must start at streams=1") + require.Nil(t, ws.clients[cid].holdTimer, "fresh attach must have no hold timer") + ws.clientsMu.Unlock() + + // Calling releaseHold again is a no-op (no hold timer to + // stop, streams > 0 so the entry stays). + require.NoError(t, b.releaseHold(ws.ID, cid)) + ws.clientsMu.Lock() + require.Contains(t, ws.clients, cid, "releaseHold must not touch a stream-only entry") + require.Equal(t, 1, ws.clients[cid].streams) + require.Nil(t, ws.clients[cid].holdTimer) + ws.clientsMu.Unlock() + + // Drain. + b.DetachClient(ws.ID, cid) + b.DetachClient(ws.ID, anchor) + require.Equal(t, int32(1), shutdowns.Load()) +} + +// TestAttachClient_RacesWithTeardown forces AttachClient to compete +// with the teardown path triggered by DetachClient. Before the fix, +// AttachClient could observe a workspace after teardown had already +// decided to remove it (because AttachClient did not synchronize with +// Backend.mu), leaving a live stream claim attached to a workspace +// that was then removed and shut down. With the fix, the outcome must +// be deterministic: either AttachClient won and the workspace is +// alive with the client registered, or teardown won and AttachClient +// returns ErrWorkspaceNotFound — never a half-state where the +// workspace is gone but ws.clients still contains the new client. +func TestAttachClient_RacesWithTeardown(t *testing.T) { + t.Parallel() + + for i := range 200 { + b, _ := newTestBackend(t) + // Keep the grace window long so it can't fire during the + // test and confuse the bookkeeping. + b.createGrace = time.Hour + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/race-teardown") + + // Seed: cidA holds the workspace open via a stream. The + // imminent DetachClient(cidA) will be the *only* claim + // drop, so teardown will run. + cidA := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cidA)) + + // cidB attempts to attach concurrently with the detach + // that will tear the workspace down. + cidB := newClientID(t) + start := make(chan struct{}) + errCh := make(chan error, 1) + detachDone := make(chan struct{}) + go func() { + <-start + errCh <- b.AttachClient(ws.ID, cidB) + }() + go func() { + <-start + b.DetachClient(ws.ID, cidA) + close(detachDone) + }() + close(start) + + // Wait for both goroutines so teardown (including + // shutdownFn) has fully run before we read state. + attachErr := <-errCh + <-detachDone + + _, wsStillRegistered := b.workspaces.Get(ws.ID) + ws.clientsMu.Lock() + _, hasA := ws.clients[cidA] + _, hasB := ws.clients[cidB] + clientCount := len(ws.clients) + ws.clientsMu.Unlock() + shutdownCount := shutdowns.Load() + + switch { + case attachErr == nil: + // AttachClient won. The workspace must be alive + // (registered) with cidB in its clients map. cidA + // may or may not still be there depending on who + // took clientsMu first, but the workspace must + // not have been torn down. + require.True(t, wsStillRegistered, + "iter %d: attach succeeded but workspace was removed", i) + require.True(t, hasB, + "iter %d: attach succeeded but cidB missing from clients", i) + require.Equal(t, int32(0), shutdownCount, + "iter %d: attach succeeded but workspace was shut down", i) + case errors.Is(attachErr, ErrWorkspaceNotFound): + // Teardown won. The workspace must be removed, + // shut down exactly once, and ws.clients must be + // empty (no half-state with cidB inserted into a + // dead workspace's clients map). + require.False(t, wsStillRegistered, + "iter %d: ErrWorkspaceNotFound but workspace still registered", i) + require.Equal(t, int32(1), shutdownCount, + "iter %d: ErrWorkspaceNotFound but shutdown count = %d", i, shutdownCount) + require.False(t, hasA, + "iter %d: teardown won but cidA still in clients", i) + require.False(t, hasB, + "iter %d: teardown won but cidB still in clients (would be the leaked attach)", i) + require.Zero(t, clientCount, + "iter %d: teardown won but clients map is non-empty", i) + default: + t.Fatalf("iter %d: unexpected AttachClient error: %v", i, attachErr) } } } + +// TestSetCurrentSession_BasicAttachAndSwitch verifies the happy path: +// an attached client can set its current session, a second attached +// client can target the same session, and one of them can switch to a +// different session without disturbing the other's record. +func TestSetCurrentSession_BasicAttachAndSwitch(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/current-session-basic") + + cidA := newClientID(t) + cidB := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cidA)) + require.NoError(t, b.AttachClient(ws.ID, cidB)) + + require.NoError(t, b.SetCurrentSession(ws.ID, cidA, "S1")) + ws.clientsMu.Lock() + require.Equal(t, "S1", ws.clients[cidA].currentSessionID) + ws.clientsMu.Unlock() + + require.NoError(t, b.SetCurrentSession(ws.ID, cidB, "S1")) + ws.clientsMu.Lock() + require.Equal(t, "S1", ws.clients[cidA].currentSessionID) + require.Equal(t, "S1", ws.clients[cidB].currentSessionID) + ws.clientsMu.Unlock() + + // B switches to S2; counts redistribute. + require.NoError(t, b.SetCurrentSession(ws.ID, cidB, "S2")) + ws.clientsMu.Lock() + require.Equal(t, "S1", ws.clients[cidA].currentSessionID) + require.Equal(t, "S2", ws.clients[cidB].currentSessionID) + ws.clientsMu.Unlock() + + // A clears its selection. + require.NoError(t, b.SetCurrentSession(ws.ID, cidA, "")) + ws.clientsMu.Lock() + require.Empty(t, ws.clients[cidA].currentSessionID) + require.Equal(t, "S2", ws.clients[cidB].currentSessionID) + ws.clientsMu.Unlock() + + // Drain to release the workspace. + b.DetachClient(ws.ID, cidA) + b.DetachClient(ws.ID, cidB) +} + +// TestSetCurrentSession_DetachClearsEntry verifies the implicit +// cleanup: once a client's [clientState] entry is removed (last +// stream closed), its currentSessionID is gone with it. +func TestSetCurrentSession_DetachClearsEntry(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/current-session-detach") + + // Anchor client so the workspace is not torn down when cid + // detaches. + anchor := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, anchor)) + + cid := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cid)) + require.NoError(t, b.SetCurrentSession(ws.ID, cid, "S2")) + + b.DetachClient(ws.ID, cid) + + ws.clientsMu.Lock() + _, present := ws.clients[cid] + ws.clientsMu.Unlock() + require.False(t, present, "detach must remove the clientState entry along with its currentSessionID") + + // A follow-up SetCurrentSession on the gone client must be + // rejected with ErrClientNotAttached. + require.ErrorIs(t, b.SetCurrentSession(ws.ID, cid, "S3"), ErrClientNotAttached) + + b.DetachClient(ws.ID, anchor) +} + +// TestSetCurrentSession_RejectsHoldOnly verifies that a registered +// client whose only claim is a creation hold (streams == 0) cannot +// influence presence: SetCurrentSession returns ErrClientNotAttached +// and the entry's currentSessionID stays empty. +func TestSetCurrentSession_RejectsHoldOnly(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + // Keep the grace window large so the hold survives the test. + b.createGrace = time.Hour + ws, _ := insertTestWorkspace(t, b, "/tmp/current-session-hold") + + cid := newClientID(t) + b.registerClient(ws, cid) + + require.ErrorIs(t, b.SetCurrentSession(ws.ID, cid, "S1"), ErrClientNotAttached) + + ws.clientsMu.Lock() + require.Empty(t, ws.clients[cid].currentSessionID, "hold-only client must not write a session id") + ws.clientsMu.Unlock() + + // Drain. + require.NoError(t, b.releaseHold(ws.ID, cid)) +} + +// TestSetCurrentSession_UnknownClient verifies that a client with no +// entry at all is rejected with ErrClientNotAttached. +func TestSetCurrentSession_UnknownClient(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/current-session-unknown") + + require.ErrorIs(t, b.SetCurrentSession(ws.ID, newClientID(t), "S1"), ErrClientNotAttached) +} + +// TestSetCurrentSession_RejectsBadInputs covers the validation +// branches: empty/malformed client_id and unknown workspace. +func TestSetCurrentSession_RejectsBadInputs(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/current-session-bad") + + require.ErrorIs(t, b.SetCurrentSession(ws.ID, "", "S1"), ErrInvalidClientID) + require.ErrorIs(t, b.SetCurrentSession(ws.ID, "not-a-uuid", "S1"), ErrInvalidClientID) + + require.ErrorIs( + t, + b.SetCurrentSession("00000000-0000-0000-0000-000000000000", newClientID(t), "S1"), + ErrWorkspaceNotFound, + ) +} + +// TestSetCurrentSession_RaceWithDetach exercises concurrent +// SetCurrentSession updates from one client racing against detach +// on a second client. The final state must be self-consistent: any +// remaining clientState entries reflect a coherent +// (streams, currentSessionID) pair. +func TestSetCurrentSession_RaceWithDetach(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/current-session-race") + + cidA := newClientID(t) + cidB := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cidA)) + require.NoError(t, b.AttachClient(ws.ID, cidB)) + + var wg sync.WaitGroup + const updates = 200 + wg.Add(3) + go func() { + defer wg.Done() + for i := range updates { + // Errors are tolerated: once cidA detaches, + // further updates against cidA must return + // ErrClientNotAttached but never panic. + _ = b.SetCurrentSession(ws.ID, cidA, "SA") + _ = i + } + }() + go func() { + defer wg.Done() + for i := range updates { + _ = b.SetCurrentSession(ws.ID, cidB, "SB") + _ = i + } + }() + go func() { + defer wg.Done() + // Single concurrent detach of cidA partway through. + b.DetachClient(ws.ID, cidA) + }() + wg.Wait() + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.NotContains(t, ws.clients, cidA, "detached client must be gone") + require.Contains(t, ws.clients, cidB, "remaining client must still be present") + require.Equal(t, "SB", ws.clients[cidB].currentSessionID, "remaining client must keep its last set session") +} + +// TestAttachedClients_BasicLifecycle walks one session's count through +// attach -> set -> second client joins -> switch -> detach. It also +// confirms hold-only and unselected clients do not contribute. +func TestAttachedClients_BasicLifecycle(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + // Keep the grace window long so the hold-only client survives. + b.createGrace = time.Hour + ws, _ := insertTestWorkspace(t, b, "/tmp/attached-clients-basic") + + // No clients yet. + n, err := b.AttachedClients(ws.ID, "S1") + require.NoError(t, err) + require.Zero(t, n) + + // Attach A, set to S1. Count for S1 is 1; count for S2 is 0. + cidA := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cidA)) + require.NoError(t, b.SetCurrentSession(ws.ID, cidA, "S1")) + + n, err = b.AttachedClients(ws.ID, "S1") + require.NoError(t, err) + require.Equal(t, 1, n) + n, err = b.AttachedClients(ws.ID, "S2") + require.NoError(t, err) + require.Zero(t, n) + + // Attach B, set to S1. Count for S1 is 2. + cidB := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cidB)) + require.NoError(t, b.SetCurrentSession(ws.ID, cidB, "S1")) + + n, _ = b.AttachedClients(ws.ID, "S1") + require.Equal(t, 2, n) + + // B switches to S2; counts redistribute. + require.NoError(t, b.SetCurrentSession(ws.ID, cidB, "S2")) + n, _ = b.AttachedClients(ws.ID, "S1") + require.Equal(t, 1, n) + n, _ = b.AttachedClients(ws.ID, "S2") + require.Equal(t, 1, n) + + // A hold-only client must NOT be counted, even if we were able to + // imagine a currentSessionID on it. registerClient leaves + // currentSessionID empty by construction, and SetCurrentSession + // rejects hold-only writers — so the contract holds two ways. + cidHold := newClientID(t) + b.registerClient(ws, cidHold) + t.Cleanup(func() { _ = b.releaseHold(ws.ID, cidHold) }) + n, _ = b.AttachedClients(ws.ID, "S1") + require.Equal(t, 1, n, "hold-only client must not contribute") + n, _ = b.AttachedClients(ws.ID, "") + require.Equal(t, 0, n, + "empty sessionID must not match the hold-only entry (streams==0)") + + // A client with streams > 0 but currentSessionID == "" is NOT + // counted toward any non-empty session, and is matched only + // against the empty session id (which represents the landing + // screen). + cidC := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cidC)) + n, _ = b.AttachedClients(ws.ID, "S1") + require.Equal(t, 1, n, "stream-only client with empty currentSessionID must not be counted toward S1") + n, _ = b.AttachedClients(ws.ID, "") + require.Equal(t, 1, n, "stream-only client with empty currentSessionID matches the empty session id") + + // B detaches: count for S2 drops to 0. + b.DetachClient(ws.ID, cidB) + n, _ = b.AttachedClients(ws.ID, "S2") + require.Zero(t, n) + n, _ = b.AttachedClients(ws.ID, "S1") + require.Equal(t, 1, n, "A still on S1") + + // Final cleanup. + b.DetachClient(ws.ID, cidA) + b.DetachClient(ws.ID, cidC) +} + +// TestAttachedClients_UnknownWorkspace verifies the error surface. +func TestAttachedClients_UnknownWorkspace(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + _, err := b.AttachedClients("00000000-0000-0000-0000-000000000000", "S1") + require.ErrorIs(t, err, ErrWorkspaceNotFound) +} diff --git a/internal/backend/config.go b/internal/backend/config.go index 553b0c2e18225a1ccff3460dfe1a7e8a32610aa4..4e7ce27dd11db51758fc564a458a0527ca21c499 100644 --- a/internal/backend/config.go +++ b/internal/backend/config.go @@ -11,9 +11,23 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/oauth" "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/skills" ) +// publishConfigChanged publishes a ConfigChanged event on the workspace's +// event broker so all subscribers (e.g. remote clients) refresh their +// cached config snapshot. +func publishConfigChanged(ws *Workspace) { + if ws == nil || ws.App == nil { + return + } + ws.SendEvent(pubsub.Event[proto.ConfigChanged]{ + Type: pubsub.UpdatedEvent, + Payload: proto.ConfigChanged{WorkspaceID: ws.ID}, + }) +} + // MCPResourceContents holds the contents of an MCP resource returned // by the backend. type MCPResourceContents struct { @@ -30,7 +44,11 @@ func (b *Backend) SetConfigField(workspaceID string, scope config.Scope, key str if err != nil { return err } - return ws.Cfg.SetConfigField(scope, key, value) + if err := ws.Cfg.SetConfigField(scope, key, value); err != nil { + return err + } + publishConfigChanged(ws) + return nil } // RemoveConfigField removes a key from the config file for the given @@ -40,7 +58,11 @@ func (b *Backend) RemoveConfigField(workspaceID string, scope config.Scope, key if err != nil { return err } - return ws.Cfg.RemoveConfigField(scope, key) + if err := ws.Cfg.RemoveConfigField(scope, key); err != nil { + return err + } + publishConfigChanged(ws) + return nil } // UpdatePreferredModel updates the preferred model for the given type @@ -50,7 +72,11 @@ func (b *Backend) UpdatePreferredModel(workspaceID string, scope config.Scope, m if err != nil { return err } - return ws.Cfg.UpdatePreferredModel(scope, modelType, model) + if err := ws.Cfg.UpdatePreferredModel(scope, modelType, model); err != nil { + return err + } + publishConfigChanged(ws) + return nil } // SetCompactMode sets the compact mode setting and persists it. @@ -59,7 +85,11 @@ func (b *Backend) SetCompactMode(workspaceID string, scope config.Scope, enabled if err != nil { return err } - return ws.Cfg.SetCompactMode(scope, enabled) + if err := ws.Cfg.SetCompactMode(scope, enabled); err != nil { + return err + } + publishConfigChanged(ws) + return nil } // SetProviderAPIKey sets the API key for a provider and persists it. @@ -68,7 +98,11 @@ func (b *Backend) SetProviderAPIKey(workspaceID string, scope config.Scope, prov if err != nil { return err } - return ws.Cfg.SetProviderAPIKey(scope, providerID, apiKey) + if err := ws.Cfg.SetProviderAPIKey(scope, providerID, apiKey); err != nil { + return err + } + publishConfigChanged(ws) + return nil } // ImportCopilot attempts to import a GitHub Copilot token from disk. @@ -78,6 +112,9 @@ func (b *Backend) ImportCopilot(workspaceID string) (*oauth.Token, bool, error) return nil, false, err } token, ok := ws.Cfg.ImportCopilot() + if ok { + publishConfigChanged(ws) + } return token, ok, nil } @@ -87,7 +124,11 @@ func (b *Backend) RefreshOAuthToken(ctx context.Context, workspaceID string, sco if err != nil { return err } - return ws.Cfg.RefreshOAuthToken(ctx, scope, providerID) + if err := ws.Cfg.RefreshOAuthToken(ctx, scope, providerID); err != nil { + return err + } + publishConfigChanged(ws) + return nil } // ProjectNeedsInitialization checks whether the project in this @@ -106,7 +147,11 @@ func (b *Backend) MarkProjectInitialized(workspaceID string) error { if err != nil { return err } - return config.MarkProjectInitialized(ws.Cfg) + if err := config.MarkProjectInitialized(ws.Cfg); err != nil { + return err + } + publishConfigChanged(ws) + return nil } // InitializePrompt builds the initialization prompt for the workspace. @@ -186,6 +231,7 @@ func (b *Backend) EnableDockerMCP(ctx context.Context, workspaceID string) error return fmt.Errorf("docker MCP started but failed to persist configuration: %w", errors.Join(err, disableErr)) } + publishConfigChanged(ws) return nil } @@ -205,6 +251,7 @@ func (b *Backend) DisableDockerMCP(workspaceID string) error { return err } + publishConfigChanged(ws) return nil } diff --git a/internal/backend/config_test.go b/internal/backend/config_test.go new file mode 100644 index 0000000000000000000000000000000000000000..858df6dabbe6d318dfa76e4593de314da9c779ce --- /dev/null +++ b/internal/backend/config_test.go @@ -0,0 +1,207 @@ +package backend + +import ( + "context" + "testing" + "time" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// awaitConfigChanged drains events until a ConfigChanged is received +// for the given workspace ID, or fails the test on timeout. Other +// event types are ignored. +func awaitConfigChanged(t *testing.T, evc <-chan pubsub.Event[tea.Msg], workspaceID string) { + t.Helper() + deadline := time.After(2 * time.Second) + for { + select { + case ev, ok := <-evc: + if !ok { + t.Fatal("event channel closed before ConfigChanged arrived") + } + cc, ok := ev.Payload.(pubsub.Event[proto.ConfigChanged]) + if !ok { + continue + } + require.Equal(t, workspaceID, cc.Payload.WorkspaceID) + return + case <-deadline: + t.Fatal("timed out waiting for ConfigChanged event") + } + } +} + +// newPublishingWorkspace creates a real workspace through the backend +// so its embedded *app.App is wired up and SendEvent works. It returns +// the backend, the workspace, and a fresh event subscription. +func newPublishingWorkspace(t *testing.T) (*Backend, *Workspace, <-chan pubsub.Event[tea.Msg]) { + t.Helper() + xdgIsolated(t) + + cwd := t.TempDir() + dataDir := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + cid := uuid.New().String() + ws, _, err := b.CreateWorkspace(protoWS(cwd, dataDir, cid)) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return b, ws, ws.Events(ctx) +} + +func TestSetConfigField_PublishesConfigChanged(t *testing.T) { + b, ws, evc := newPublishingWorkspace(t) + + require.NoError(t, b.SetConfigField(ws.ID, config.ScopeGlobal, "options.debug", true)) + awaitConfigChanged(t, evc, ws.ID) +} + +func TestRemoveConfigField_PublishesConfigChanged(t *testing.T) { + b, ws, evc := newPublishingWorkspace(t) + + // Seed a field we can then remove. Setting also publishes, so + // drain the resulting event before testing remove. + require.NoError(t, b.SetConfigField(ws.ID, config.ScopeGlobal, "options.debug", true)) + awaitConfigChanged(t, evc, ws.ID) + + require.NoError(t, b.RemoveConfigField(ws.ID, config.ScopeGlobal, "options.debug")) + awaitConfigChanged(t, evc, ws.ID) +} + +func TestUpdatePreferredModel_PublishesConfigChanged(t *testing.T) { + if raceEnabled { + // UpdatePreferredModel writes config.Models concurrently + // with the agent coordinator's async sub-agent builder + // that reads it via buildAgentModels. That race is + // pre-existing in the codebase and unrelated to this + // item; ConfigStore mutations are not currently + // synchronized against background readers in [app.App]. + // The mutator → publish wiring is unit-tested via + // publishConfigChanged regardless. + t.Skip("skipped under -race: pre-existing race between ConfigStore writes and agent coordinator startup") + } + b, ws, evc := newPublishingWorkspace(t) + + model := config.SelectedModel{Provider: "openai", Model: "gpt-4"} + require.NoError(t, b.UpdatePreferredModel(ws.ID, config.ScopeGlobal, config.SelectedModelTypeLarge, model)) + awaitConfigChanged(t, evc, ws.ID) +} + +func TestSetCompactMode_PublishesConfigChanged(t *testing.T) { + b, ws, evc := newPublishingWorkspace(t) + + require.NoError(t, b.SetCompactMode(ws.ID, config.ScopeGlobal, true)) + awaitConfigChanged(t, evc, ws.ID) +} + +func TestSetProviderAPIKey_PublishesConfigChanged(t *testing.T) { + b, ws, evc := newPublishingWorkspace(t) + + require.NoError(t, b.SetProviderAPIKey(ws.ID, config.ScopeGlobal, "openai", "test-key")) + awaitConfigChanged(t, evc, ws.ID) +} + +func TestMarkProjectInitialized_PublishesConfigChanged(t *testing.T) { + b, ws, evc := newPublishingWorkspace(t) + + require.NoError(t, b.MarkProjectInitialized(ws.ID)) + awaitConfigChanged(t, evc, ws.ID) +} + +// TestImportCopilot_PublishesConfigChanged exercises the success path +// by seeding a token file in the location ImportCopilot scans, then +// asserting the event fires only when ok==true. +func TestImportCopilot_PublishesConfigChanged(t *testing.T) { + // ImportCopilot reads from external user-state directories that + // vary by OS. Rather than recreate that setup, drive the + // publishing helper directly and assert ImportCopilot's + // no-event-on-not-found semantics are preserved. + b, ws, evc := newPublishingWorkspace(t) + + // Not-found path: no token exists, so no event must fire. + _, ok, err := b.ImportCopilot(ws.ID) + require.NoError(t, err) + require.False(t, ok, "ImportCopilot should return ok=false when no token is present") + + select { + case ev := <-evc: + if _, isCC := ev.Payload.(pubsub.Event[proto.ConfigChanged]); isCC { + t.Fatal("ImportCopilot must not publish ConfigChanged when nothing was imported") + } + case <-time.After(100 * time.Millisecond): + // Expected: no ConfigChanged. + } + + // Helper sanity: publishing manually does fire the event. + publishConfigChanged(ws) + awaitConfigChanged(t, evc, ws.ID) +} + +// TestRefreshOAuthToken_PublishesConfigChangedOnError verifies that +// the unhappy path does not publish (mutator returned an error). The +// happy path requires a real OAuth-capable provider configured with a +// refreshable token, which is beyond an isolated unit test's scope. +func TestRefreshOAuthToken_NoEventOnError(t *testing.T) { + b, ws, evc := newPublishingWorkspace(t) + + // Provider does not exist → store returns an error → no event. + err := b.RefreshOAuthToken(context.Background(), ws.ID, config.ScopeGlobal, "no-such-provider") + require.Error(t, err) + + select { + case ev := <-evc: + if _, isCC := ev.Payload.(pubsub.Event[proto.ConfigChanged]); isCC { + t.Fatal("RefreshOAuthToken must not publish ConfigChanged when it errors") + } + case <-time.After(100 * time.Millisecond): + } +} + +// TestDisableDockerMCP_PublishesConfigChanged seeds a Docker MCP entry +// directly so DisableDockerMCP has something to remove without needing +// a running Docker daemon for PrepareDockerMCPConfig's availability +// probe. +func TestDisableDockerMCP_PublishesConfigChanged(t *testing.T) { + b, ws, evc := newPublishingWorkspace(t) + + // Persist a Docker MCP entry directly via the store so the + // downstream DisableDockerMCP path has something to remove. + require.NoError(t, ws.Cfg.PersistDockerMCPConfig(config.DockerMCPConfig())) + drainEvents(evc, 100*time.Millisecond) + + require.NoError(t, b.DisableDockerMCP(ws.ID)) + awaitConfigChanged(t, evc, ws.ID) +} + +// drainEvents reads from evc until quiet for the given window. Used +// to flush events emitted by setup steps so the assertion can target +// the event from the action under test. +func drainEvents(evc <-chan pubsub.Event[tea.Msg], quiet time.Duration) { + for { + select { + case <-evc: + case <-time.After(quiet): + return + } + } +} + +// TestPublishConfigChanged_NilWorkspaceSafe documents that the helper +// is safe to call on workspaces without an *app.App (e.g. synthetic +// test workspaces). +func TestPublishConfigChanged_NilWorkspaceSafe(t *testing.T) { + t.Parallel() + require.NotPanics(t, func() { publishConfigChanged(nil) }) + require.NotPanics(t, func() { publishConfigChanged(&Workspace{}) }) +} diff --git a/internal/backend/permission.go b/internal/backend/permission.go index bb7876d6989ec8bee6a99362cb5f5ef914fc5c49..d6db237989ac3a85244c8f9ab4c14df1a7afa1d0 100644 --- a/internal/backend/permission.go +++ b/internal/backend/permission.go @@ -6,11 +6,13 @@ import ( ) // GrantPermission grants, denies, or persistently grants a permission -// request. -func (b *Backend) GrantPermission(workspaceID string, req proto.PermissionGrant) error { +// request. The returned bool reports whether this call resolved the +// pending request (true) or found it already resolved by a previous +// caller (false). A false return is not an error. +func (b *Backend) GrantPermission(workspaceID string, req proto.PermissionGrant) (bool, error) { ws, err := b.GetWorkspace(workspaceID) if err != nil { - return err + return false, err } perm := permission.PermissionRequest{ @@ -26,15 +28,14 @@ func (b *Backend) GrantPermission(workspaceID string, req proto.PermissionGrant) switch req.Action { case proto.PermissionAllow: - ws.Permissions.Grant(perm) + return ws.Permissions.Grant(perm), nil case proto.PermissionAllowForSession: - ws.Permissions.GrantPersistent(perm) + return ws.Permissions.GrantPersistent(perm), nil case proto.PermissionDeny: - ws.Permissions.Deny(perm) + return ws.Permissions.Deny(perm), nil default: - return ErrInvalidPermissionAction + return false, ErrInvalidPermissionAction } - return nil } // SetPermissionsSkip sets whether permission prompts are skipped. diff --git a/internal/backend/race_off_test.go b/internal/backend/race_off_test.go new file mode 100644 index 0000000000000000000000000000000000000000..04ff4b864f6382fc8b62231677367c220c86dbe2 --- /dev/null +++ b/internal/backend/race_off_test.go @@ -0,0 +1,5 @@ +//go:build !race + +package backend + +const raceEnabled = false diff --git a/internal/backend/race_on_test.go b/internal/backend/race_on_test.go new file mode 100644 index 0000000000000000000000000000000000000000..1904bea7bf217ac24deae0dcdeba39a527830ce9 --- /dev/null +++ b/internal/backend/race_on_test.go @@ -0,0 +1,5 @@ +//go:build race + +package backend + +const raceEnabled = true diff --git a/internal/backend/testing.go b/internal/backend/testing.go new file mode 100644 index 0000000000000000000000000000000000000000..6616e0f19e06595fac68808b484394d960d7f79f --- /dev/null +++ b/internal/backend/testing.go @@ -0,0 +1,55 @@ +package backend + +// InsertWorkspaceForTest registers ws with b under its current ID and +// path. It is intended for tests in other packages that need to drive +// HTTP handlers against a synthetic workspace without booting a real +// app.App. Production code should go through CreateWorkspace. +func InsertWorkspaceForTest(b *Backend, ws *Workspace) { + if ws.resolvedPath == "" { + ws.resolvedPath = ws.Path + } + if ws.clients == nil { + ws.clients = make(map[string]*clientState) + } + b.mu.Lock() + defer b.mu.Unlock() + b.workspaces.Set(ws.ID, ws) + if ws.resolvedPath != "" { + b.pathIndex[ws.resolvedPath] = ws.ID + } +} + +// RegisterClientForTesting installs a creation hold for clientID on +// ws using the backend's normal registerClient path. Intended for +// tests in other packages that need to drive a hold-only client +// (streams == 0) without booting a real CreateWorkspace flow. +func RegisterClientForTesting(b *Backend, ws *Workspace, clientID string) error { + if _, err := validateClientID(clientID); err != nil { + return err + } + b.registerClient(ws, clientID) + return nil +} + +// SetWorkspaceShutdownFnForTest overrides the workspace teardown +// callback. Useful for tests in other packages that drive synthetic +// workspaces (where the embedded [app.App] is incomplete) through +// detach paths that would otherwise crash inside App.Shutdown. +func SetWorkspaceShutdownFnForTest(ws *Workspace, fn func()) { + ws.shutdownFn = fn +} + +// WorkspaceLiveStreamCountForTest returns the number of clients on ws +// that have at least one live SSE stream. Used by integration tests +// in other packages to wait for SSE attaches before publishing events. +func WorkspaceLiveStreamCountForTest(ws *Workspace) int { + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + n := 0 + for _, cs := range ws.clients { + if cs.streams > 0 { + n++ + } + } + return n +} diff --git a/internal/client/client.go b/internal/client/client.go index 42dd0243b234bc1c9bfc4801311a728d027eb240..7b83da5cbb29e3959e5ee22762d303341e76be0c 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -15,6 +15,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/server" + "github.com/google/uuid" ) // DummyHost is used to satisfy the http.Client's requirement for a URL. @@ -22,10 +23,11 @@ const DummyHost = "api.crush.localhost" // Client represents an RPC client connected to a Crush server. type Client struct { - h *http.Client - path string - network string - addr string + h *http.Client + path string + network string + addr string + clientID string } // DefaultClient creates a new [Client] connected to the default server address. @@ -44,6 +46,7 @@ func NewClient(path, network, address string) (*Client, error) { c.path = filepath.Clean(path) c.network = network c.addr = address + c.clientID = uuid.New().String() p := &http.Protocols{} p.SetHTTP1(true) p.SetUnencryptedHTTP2(true) @@ -65,6 +68,12 @@ func (c *Client) Path() string { return c.path } +// ClientID returns the per-process client ID minted in [NewClient]. +// The server uses it as a presence/coordination handle. +func (c *Client) ClientID() string { + return c.clientID +} + // GetGlobalConfig retrieves the server's configuration. func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) { var cfg config.Config diff --git a/internal/client/proto.go b/internal/client/proto.go index e261b172bd041b719d680ababfda8bd9bf130fd3..5a57262679df0a7edc3f9269dc763fc124e778cd 100644 --- a/internal/client/proto.go +++ b/internal/client/proto.go @@ -39,6 +39,7 @@ func (c *Client) ListWorkspaces(ctx context.Context) ([]proto.Workspace, error) // CreateWorkspace creates a new workspace on the server. func (c *Client) CreateWorkspace(ctx context.Context, ws proto.Workspace) (*proto.Workspace, error) { + ws.ClientID = c.clientID rsp, err := c.post(ctx, "/workspaces", nil, jsonBody(ws), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { return nil, fmt.Errorf("failed to create workspace: %w", err) @@ -73,7 +74,8 @@ func (c *Client) GetWorkspace(ctx context.Context, id string) (*proto.Workspace, // DeleteWorkspace deletes a workspace on the server. func (c *Client) DeleteWorkspace(ctx context.Context, id string) error { - rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s", id), nil, nil) + q := url.Values{"client_id": []string{c.clientID}} + rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s", id), q, nil) if err != nil { return fmt.Errorf("failed to delete workspace: %w", err) } @@ -84,11 +86,36 @@ func (c *Client) DeleteWorkspace(ctx context.Context, id string) error { return nil } +// SetCurrentSession reports the client's current-session selection +// for the named workspace. An empty sessionID clears the entry. The +// request carries the process-scoped client ID minted in [NewClient] +// as a query parameter so the server can route the update to the +// correct [clientState] entry. +func (c *Client) SetCurrentSession(ctx context.Context, workspaceID, sessionID string) error { + q := url.Values{"client_id": []string{c.clientID}} + rsp, err := c.post( + ctx, + fmt.Sprintf("/workspaces/%s/current-session", workspaceID), + q, + jsonBody(proto.CurrentSession{SessionID: sessionID}), + http.Header{"Content-Type": []string{"application/json"}}, + ) + if err != nil { + return fmt.Errorf("failed to set current session: %w", err) + } + defer rsp.Body.Close() + if rsp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to set current session: status code %d", rsp.StatusCode) + } + return nil +} + // SubscribeEvents subscribes to server-sent events for a workspace. func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, error) { events := make(chan any, 100) + q := url.Values{"client_id": []string{c.clientID}} //nolint:bodyclose - rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/events", id), nil, http.Header{ + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/events", id), q, http.Header{ "Accept": []string{"text/event-stream"}, "Cache-Control": []string{"no-cache"}, "Connection": []string{"keep-alive"}, @@ -168,6 +195,10 @@ func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, er var e pubsub.Event[proto.AgentEvent] _ = json.Unmarshal(p.Payload, &e) sendEvent(ctx, events, e) + case pubsub.PayloadTypeConfigChanged: + var e pubsub.Event[proto.ConfigChanged] + _ = json.Unmarshal(p.Payload, &e) + sendEvent(ctx, events, e) case pubsub.PayloadTypeSkillsEvent: var e pubsub.Event[proto.SkillsEvent] _ = json.Unmarshal(p.Payload, &e) @@ -482,17 +513,25 @@ func (c *Client) ListSessions(ctx context.Context, id string) ([]proto.Session, return sessions, nil } -// GrantPermission grants a permission on a workspace. -func (c *Client) GrantPermission(ctx context.Context, id string, req proto.PermissionGrant) error { +// GrantPermission grants a permission on a workspace. The returned +// bool reports whether this call resolved the pending request (true) +// or found it already resolved by a previous caller (false). A false +// value is not an error — it just means another subscriber resolved +// the same request first. +func (c *Client) GrantPermission(ctx context.Context, id string, req proto.PermissionGrant) (bool, error) { rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/permissions/grant", id), nil, jsonBody(req), http.Header{"Content-Type": []string{"application/json"}}) if err != nil { - return fmt.Errorf("failed to grant permission: %w", err) + return false, fmt.Errorf("failed to grant permission: %w", err) } defer rsp.Body.Close() if rsp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to grant permission: status code %d", rsp.StatusCode) + return false, fmt.Errorf("failed to grant permission: status code %d", rsp.StatusCode) } - return nil + var resp proto.PermissionGrantResponse + if err := json.NewDecoder(rsp.Body).Decode(&resp); err != nil { + return false, fmt.Errorf("failed to decode grant permission response: %w", err) + } + return resp.Resolved, nil } // SetPermissionsSkipRequests sets the skip-requests flag for a workspace. diff --git a/internal/db/connect.go b/internal/db/connect.go index ce1ed1172f8ef194b3f60fe298f00dd42e4f1409..f7304791d3bca1158c2ca54f73f5b092ca4b8ac8 100644 --- a/internal/db/connect.go +++ b/internal/db/connect.go @@ -6,6 +6,7 @@ import ( "embed" "fmt" "log/slog" + "os" "path/filepath" "sync" "testing" @@ -39,10 +40,16 @@ func init() { } } -// connEntry holds a shared database connection and its reference count. +// connEntry holds a shared database connection, its reference count, +// and the data-directory lock that gates access to this entry. The +// lock is acquired exactly once when the entry is created and released +// when the last reference is dropped, which lets the same process open +// the same data directory concurrently while still blocking a second +// crush process from racing the storage. type connEntry struct { db *sql.DB refCount int + lock *dataDirLock } var ( @@ -50,16 +57,39 @@ var ( poolMu sync.Mutex ) +// ConnectOption configures a Connect call. Options are applied in +// order; later options override earlier ones for the same field. +type ConnectOption func(*connectOptions) + +// connectOptions holds the resolved configuration for a Connect call. +type connectOptions struct { + lockDataDir bool +} + +// WithDataDirLock toggles acquisition of the per-data-directory lock +// for this Connect call. The lock is off by default so local-mode +// invocations do not regress today's behavior; the server's +// workspace-bootstrap path opts in. CRUSH_SKIP_DATADIR_LOCK still +// bypasses acquisition even when this option is set. +func WithDataDirLock(enable bool) ConnectOption { + return func(o *connectOptions) { o.lockDataDir = enable } +} + // Connect opens a SQLite database connection for the given data // directory and runs migrations. If a connection to the same database // file already exists, the existing connection is returned with its // reference count incremented. Callers must pair each Connect with a // [Release] when they no longer need the connection. -func Connect(ctx context.Context, dataDir string) (*sql.DB, error) { +func Connect(ctx context.Context, dataDir string, opts ...ConnectOption) (*sql.DB, error) { if dataDir == "" { return nil, fmt.Errorf("data.dir is not set") } + var cfg connectOptions + for _, opt := range opts { + opt(&cfg) + } + dbPath := filepath.Join(dataDir, "crush.db") // Resolve to an absolute path so that different relative paths to @@ -77,8 +107,30 @@ func Connect(ctx context.Context, dataDir string) (*sql.DB, error) { return entry.db, nil } + // Take the per-data-directory lock before opening the database so + // we fail fast and with a clear error rather than racing another + // crush process on the same SQLite file. The lock is released when + // the matching Release call drops the refcount to zero. Ensuring + // the data directory exists is required because the lock file + // lives inside it. Locking is opt-in via WithDataDirLock so that + // local-mode invocations do not refuse a second crush against the + // same data dir until client/server becomes the default. + if err := os.MkdirAll(dataDir, 0o700); err != nil { + return nil, fmt.Errorf("failed to create data directory %q: %w", dataDir, err) + } + var lock *dataDirLock + if cfg.lockDataDir && !skipDataDirLock() { + lock, err = acquireDataDirLock(dataDir) + if err != nil { + return nil, err + } + } + conn, err := openDB(dbPath) if err != nil { + if lock != nil { + lock.release() + } return nil, err } @@ -89,24 +141,33 @@ func Connect(ctx context.Context, dataDir string) (*sql.DB, error) { // resulting in SQLITE_NOTADB (26) on the next open. conn.SetMaxOpenConns(1) + releaseLock := func() { + if lock != nil { + lock.release() + } + } + if err = conn.PingContext(ctx); err != nil { conn.Close() + releaseLock() return nil, fmt.Errorf("failed to connect to database: %w", err) } if err := initGoose(); err != nil { conn.Close() + releaseLock() slog.Error("Failed to initialize goose", "error", err) return nil, fmt.Errorf("failed to initialize goose: %w", err) } if err := goose.Up(conn, "migrations"); err != nil { conn.Close() + releaseLock() slog.Error("Failed to apply migrations", "error", err) return nil, fmt.Errorf("failed to apply migrations: %w", err) } - pool[absPath] = &connEntry{db: conn, refCount: 1} + pool[absPath] = &connEntry{db: conn, refCount: 1, lock: lock} return conn, nil } @@ -134,7 +195,11 @@ func Release(dataDir string) error { } delete(pool, absPath) - return entry.db.Close() + closeErr := entry.db.Close() + if entry.lock != nil { + entry.lock.release() + } + return closeErr } // ResetPool closes all pooled connections and clears the pool. This is @@ -144,6 +209,9 @@ func ResetPool() { defer poolMu.Unlock() for path, entry := range pool { entry.db.Close() + if entry.lock != nil { + entry.lock.release() + } delete(pool, path) } } diff --git a/internal/db/connect_test.go b/internal/db/connect_test.go index 93c2af00216cb9076214b861eed230a45d7d9bd0..45e39758924a9351b07bdb5956ddbf1ae85d1b02 100644 --- a/internal/db/connect_test.go +++ b/internal/db/connect_test.go @@ -2,6 +2,8 @@ package db import ( "context" + "errors" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -52,3 +54,156 @@ func TestRelease_NoopForUnknownDataDir(t *testing.T) { require.NoError(t, Release("/nonexistent/path"), "releasing unknown data dir should not error") } + +// TestConnect_FailsWhenDataDirLocked simulates a second crush process by +// taking the data-dir lock directly via the OS primitive on a separate +// file descriptor and then asserting that Connect surfaces a clean +// ErrDataDirLocked instead of opening the database under contention. +func TestConnect_FailsWhenDataDirLocked(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + lockPath := filepath.Join(dataDir, dataDirLockFile) + + release, err := tryFileLock(lockPath) + require.NoError(t, err, "expected to take the data-dir lock for the first time") + t.Cleanup(release) + + _, err = Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.Error(t, err, "Connect must refuse to open a locked data dir") + require.ErrorIs(t, err, ErrDataDirLocked) +} + +// TestConnect_SucceedsAfterContenderReleases ensures the lock is purely +// advisory and that a clean release lets the next Connect proceed. +func TestConnect_SucceedsAfterContenderReleases(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + lockPath := filepath.Join(dataDir, dataDirLockFile) + + release, err := tryFileLock(lockPath) + require.NoError(t, err) + + _, err = Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.ErrorIs(t, err, ErrDataDirLocked) + + release() + + conn, err := Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.NoError(t, err, "Connect should succeed once the contender releases the lock") + require.NoError(t, conn.PingContext(context.Background())) + require.NoError(t, Release(dataDir)) +} + +// TestConnect_LockReleasedOnFinalRelease confirms that closing the last +// reference to a pool entry also drops the OS lock, so subsequent +// processes can take the data dir. +func TestConnect_LockReleasedOnFinalRelease(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + lockPath := filepath.Join(dataDir, dataDirLockFile) + + conn, err := Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.NoError(t, err) + require.NoError(t, conn.PingContext(context.Background())) + + // Holding the in-process entry must keep the OS lock held so a + // "second process" (simulated by a fresh tryFileLock call) is + // rejected. + _, lockErr := tryFileLock(lockPath) + require.Error(t, lockErr) + require.True(t, errors.Is(lockErr, errLockContended), "expected contended lock while pool entry is live") + + require.NoError(t, Release(dataDir)) + + // After the final release the lock is free again. + release, err := tryFileLock(lockPath) + require.NoError(t, err, "expected lock to be released after final Release") + release() +} + +// TestConnect_SharedPoolDoesNotReacquireLock makes sure that subsequent +// in-process Connect calls reuse the existing OS lock through refcount, +// not by re-acquiring it. The simplest observable signal of correctness +// is that the second Connect does not error and the lock is still held +// after a single Release. +func TestConnect_SharedPoolDoesNotReacquireLock(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + lockPath := filepath.Join(dataDir, dataDirLockFile) + + _, err := Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.NoError(t, err) + + _, err = Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.NoError(t, err) + + // Drop one reference; lock must still be held. + require.NoError(t, Release(dataDir)) + _, lockErr := tryFileLock(lockPath) + require.ErrorIs(t, lockErr, errLockContended) + + require.NoError(t, Release(dataDir)) +} + +// TestConnect_SkipLockEnvBypassesAcquisition exercises the escape +// hatch used by users on filesystems where flock is unreliable. +func TestConnect_SkipLockEnvBypassesAcquisition(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + lockPath := filepath.Join(dataDir, dataDirLockFile) + + release, err := tryFileLock(lockPath) + require.NoError(t, err) + t.Cleanup(release) + + t.Setenv("CRUSH_SKIP_DATADIR_LOCK", "1") + + conn, err := Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.NoError(t, err, "skip-lock env should bypass contention") + require.NoError(t, conn.PingContext(context.Background())) + require.NoError(t, Release(dataDir)) +} + +// TestConnect_DefaultIgnoresContendedLock confirms that without +// WithDataDirLock(true) the lock file is irrelevant: a contender can +// hold tryFileLock and Connect still succeeds. This pins the +// local-mode default to its pre-lock behavior. +func TestConnect_DefaultIgnoresContendedLock(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + lockPath := filepath.Join(dataDir, dataDirLockFile) + + release, err := tryFileLock(lockPath) + require.NoError(t, err, "expected to take the data-dir lock for the first time") + t.Cleanup(release) + + conn, err := Connect(context.Background(), dataDir) + require.NoError(t, err, "default Connect must not take the lock and must succeed under contention") + require.NoError(t, conn.PingContext(context.Background())) + require.NoError(t, Release(dataDir)) +} + +// TestConnect_ServerPathFailsWhenDataDirLocked is the server's +// workspace-bootstrap analogue of TestConnect_FailsWhenDataDirLocked: +// passing WithDataDirLock(true) must surface ErrDataDirLocked when a +// contender already holds the lock. +func TestConnect_ServerPathFailsWhenDataDirLocked(t *testing.T) { + t.Cleanup(ResetPool) + + dataDir := t.TempDir() + lockPath := filepath.Join(dataDir, dataDirLockFile) + + release, err := tryFileLock(lockPath) + require.NoError(t, err, "expected to take the data-dir lock for the first time") + t.Cleanup(release) + + _, err = Connect(context.Background(), dataDir, WithDataDirLock(true)) + require.Error(t, err, "server-path Connect must refuse to open a locked data dir") + require.ErrorIs(t, err, ErrDataDirLocked) +} diff --git a/internal/db/datadirlock.go b/internal/db/datadirlock.go new file mode 100644 index 0000000000000000000000000000000000000000..914933503fd795dd13a2052af76a8cd597015c04 --- /dev/null +++ b/internal/db/datadirlock.go @@ -0,0 +1,130 @@ +package db + +import ( + "encoding/json" + "errors" + "fmt" + "log/slog" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/charmbracelet/crush/internal/version" +) + +// ErrDataDirLocked is returned by Connect when the data directory is +// already in use by another crush process. +var ErrDataDirLocked = errors.New("data directory already in use by another crush process") + +// dataDirLockFile is the name of the lock file inside the data +// directory. It lives next to crush.db so users can `ls` and find it. +const dataDirLockFile = "crush.lock" + +// dataDirOwnerInfo is the JSON payload written into the lock file by +// the process that currently owns it. It is purely informational; the +// authoritative state of ownership is the operating system flock on +// the file descriptor. +type dataDirOwnerInfo struct { + PID int `json:"pid"` + Version string `json:"version,omitempty"` + StartedAt string `json:"started_at,omitempty"` +} + +// dataDirLock represents an acquired exclusive lock on a data +// directory. release closes the underlying file descriptor which the +// kernel uses to drop the OS-level lock. +type dataDirLock struct { + release func() +} + +// acquireDataDirLock takes an exclusive non-blocking lock on +// {dataDir}/crush.lock. If the lock is already held by another +// process, it returns ErrDataDirLocked wrapped with a diagnostic that +// includes whatever owner info that process wrote. +// +// Acquisition is skipped (returning a no-op lock) when +// CRUSH_SKIP_DATADIR_LOCK is set to a truthy value. This is intended +// as an escape hatch for hostile filesystems that do not implement +// advisory locking; it should not be used in normal operation. +func acquireDataDirLock(dataDir string) (*dataDirLock, error) { + if skipDataDirLock() { + return &dataDirLock{release: func() {}}, nil + } + + path := filepath.Join(dataDir, dataDirLockFile) + release, err := tryFileLock(path) + if err != nil { + if errors.Is(err, errLockContended) { + return nil, contendedLockError(dataDir, path) + } + return nil, fmt.Errorf("failed to lock data directory %q: %w", dataDir, err) + } + + // Record ownership metadata so a contending process can identify + // us. Failures here are non-fatal: the OS-level lock is what + // actually guarantees mutual exclusion, and a missing/partial JSON + // payload only degrades the diagnostic a contender prints. + if err := writeOwnerInfo(path); err != nil { + slog.Debug("Failed to write data-dir owner info", "path", path, "error", err) + } + + // The lock file itself is intentionally never unlinked. flock is + // keyed by inode, not by path, and any close-then-unlink (or + // unlink-then-close) ordering opens a window where two processes + // can each hold a flock on a different inode that lives at the + // same path. Leaving the file in place lets every acquirer see + // the same inode and lets the kernel arbitrate correctly. + return &dataDirLock{release: release}, nil +} + +// skipDataDirLock reports whether the data-dir lock should be bypassed. +func skipDataDirLock() bool { + v, _ := strconv.ParseBool(os.Getenv("CRUSH_SKIP_DATADIR_LOCK")) + return v +} + +// writeOwnerInfo truncates and rewrites the lock file with the current +// process's identifying information. It is called only after the lock +// is held. +func writeOwnerInfo(path string) error { + info := dataDirOwnerInfo{ + PID: os.Getpid(), + Version: version.Version, + StartedAt: time.Now().UTC().Format(time.RFC3339), + } + payload, err := json.MarshalIndent(info, "", " ") + if err != nil { + return err + } + payload = append(payload, '\n') + return os.WriteFile(path, payload, 0o600) +} + +// readOwnerInfo returns the lock file's recorded owner, if it parses. +// A missing or malformed file yields an empty struct and no error; +// the caller decides what to surface to the user. +func readOwnerInfo(path string) dataDirOwnerInfo { + raw, err := os.ReadFile(path) + if err != nil || len(raw) == 0 { + return dataDirOwnerInfo{} + } + var info dataDirOwnerInfo + _ = json.Unmarshal(raw, &info) + return info +} + +// contendedLockError builds a wrapped ErrDataDirLocked annotated with +// whatever owner metadata is currently in the lock file. +func contendedLockError(dataDir, lockPath string) error { + info := readOwnerInfo(lockPath) + details := "" + switch { + case info.PID != 0 && info.StartedAt != "": + details = fmt.Sprintf(" (owner pid=%d version=%s started_at=%s)", + info.PID, info.Version, info.StartedAt) + case info.PID != 0: + details = fmt.Sprintf(" (owner pid=%d)", info.PID) + } + return fmt.Errorf("%w: %s%s", ErrDataDirLocked, dataDir, details) +} diff --git a/internal/db/datadirlock_unix.go b/internal/db/datadirlock_unix.go new file mode 100644 index 0000000000000000000000000000000000000000..7e495349dd1b29c1960bc8c5731d3d19dd716d50 --- /dev/null +++ b/internal/db/datadirlock_unix.go @@ -0,0 +1,45 @@ +//go:build !windows + +package db + +import ( + "errors" + "fmt" + "os" + + "golang.org/x/sys/unix" +) + +// errLockContended is returned by tryFileLock when the lock is already +// held by another open file description (typically another process). +var errLockContended = errors.New("file lock is held by another process") + +// tryFileLock takes an exclusive non-blocking BSD flock on path, +// creating the file if necessary. On success it returns a release +// function that drops the lock and closes the descriptor. When the +// lock is contended it returns errLockContended. +// +// BSD flock is advisory and per-open-file-description, so it does not +// interfere with the byte-range locks SQLite itself uses on the same +// file's siblings (crush.db, crush.db-wal, crush.db-shm). The lock is +// also released automatically by the kernel when the file descriptor +// is closed, including on process crash, so we do not need any +// explicit stale-lock recovery. +func tryFileLock(path string) (func(), error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) + if err != nil { + return nil, fmt.Errorf("open lock file: %w", err) + } + if err := unix.Flock(int(f.Fd()), unix.LOCK_EX|unix.LOCK_NB); err != nil { + _ = f.Close() + if errors.Is(err, unix.EWOULDBLOCK) { + return nil, errLockContended + } + return nil, fmt.Errorf("flock: %w", err) + } + return func() { + // Closing the descriptor releases the flock atomically. + _ = unix.Flock(int(f.Fd()), unix.LOCK_UN) + _ = f.Close() + }, nil +} diff --git a/internal/db/datadirlock_windows.go b/internal/db/datadirlock_windows.go new file mode 100644 index 0000000000000000000000000000000000000000..1a0d53894c39d303e4a5e1820c513764375c891b --- /dev/null +++ b/internal/db/datadirlock_windows.go @@ -0,0 +1,46 @@ +//go:build windows + +package db + +import ( + "errors" + "fmt" + "math" + "os" + + "golang.org/x/sys/windows" +) + +// errLockContended is returned by tryFileLock when the lock is held +// by another process. +var errLockContended = errors.New("file lock is held by another process") + +// tryFileLock takes an exclusive non-blocking lock on path via +// LockFileEx. On success it returns a release function that unlocks +// and closes the descriptor. +// +// The flags combine LOCKFILE_EXCLUSIVE_LOCK with LOCKFILE_FAIL_IMMEDIATELY +// to mirror the BSD LOCK_EX|LOCK_NB semantics used on POSIX. The lock +// is released when the file handle closes, including on process exit, +// which gives us automatic stale-lock recovery without any bookkeeping. +func tryFileLock(path string) (func(), error) { + f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0o600) + if err != nil { + return nil, fmt.Errorf("open lock file: %w", err) + } + h := windows.Handle(f.Fd()) + ol := new(windows.Overlapped) + flags := uint32(windows.LOCKFILE_EXCLUSIVE_LOCK | windows.LOCKFILE_FAIL_IMMEDIATELY) + if err := windows.LockFileEx(h, flags, 0, math.MaxUint32, math.MaxUint32, ol); err != nil { + _ = f.Close() + if errors.Is(err, windows.ERROR_LOCK_VIOLATION) || errors.Is(err, windows.ERROR_IO_PENDING) { + return nil, errLockContended + } + return nil, fmt.Errorf("LockFileEx: %w", err) + } + return func() { + ol := new(windows.Overlapped) + _ = windows.UnlockFileEx(windows.Handle(f.Fd()), 0, math.MaxUint32, math.MaxUint32, ol) + _ = f.Close() + }, nil +} diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 3fe475bab9fa2245067bac70ce689b2942a3747b..25e4994a807f44b0a9fd994a1d6ebe548641202a 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -64,9 +64,19 @@ type PermissionRequest struct { type Service interface { pubsub.Subscriber[PermissionRequest] - GrantPersistent(permission PermissionRequest) - Grant(permission PermissionRequest) - Deny(permission PermissionRequest) + // GrantPersistent grants a permission request and remembers the grant + // for the session. It returns true if this call actually resolved the + // pending request; false if the request had already been resolved + // (e.g., by another concurrent caller) or is unknown. + GrantPersistent(permission PermissionRequest) bool + // Grant grants a permission request. It returns true if this call + // actually resolved the pending request; false if the request had + // already been resolved or is unknown. + Grant(permission PermissionRequest) bool + // Deny denies a permission request. It returns true if this call + // actually resolved the pending request; false if the request had + // already been resolved or is unknown. + Deny(permission PermissionRequest) bool Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) AutoApproveSession(sessionID string) SetSkipRequests(skip bool) @@ -100,63 +110,72 @@ type permissionService struct { activeRequestMu sync.Mutex } -func (s *permissionService) GrantPersistent(permission PermissionRequest) { - s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ - ToolCallID: permission.ToolCallID, - Granted: true, - }) - respCh, ok := s.pendingRequests.Get(permission.ID) - if ok { - respCh <- true +// resolve atomically removes the pending request entry for the given +// permission and, if it was still pending, publishes exactly one +// PermissionNotification and forwards the outcome to the waiter on +// respCh. It returns true if this call resolved the request, false if +// it had already been resolved (e.g., by another concurrent caller) or +// the request ID is unknown. +// +// If onResolve is non-nil it runs after the pending entry has been +// taken but before the notification is published or the waiter is +// unblocked. This lets GrantPersistent record the session permission +// only when it actually wins the race, so a losing GrantPersistent +// that lost to a Deny does not leak an auto-approve entry. +// +// All three public resolution methods (Grant, GrantPersistent, Deny) +// route through this helper so multi-subscriber UIs can race safely: +// the first caller wins, the rest become no-ops. +func (s *permissionService) resolve(permission PermissionRequest, granted, denied bool, onResolve func()) bool { + respCh, ok := s.pendingRequests.Take(permission.ID) + if !ok { + return false } - s.sessionPermissions.Set(PermissionKey{ - SessionID: permission.SessionID, - ToolName: permission.ToolName, - Action: permission.Action, - Path: permission.Path, - }, true) - - s.activeRequestMu.Lock() - if s.activeRequest != nil && s.activeRequest.ID == permission.ID { - s.activeRequest = nil + if onResolve != nil { + onResolve() } - s.activeRequestMu.Unlock() -} -func (s *permissionService) Grant(permission PermissionRequest) { s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ ToolCallID: permission.ToolCallID, - Granted: true, + Granted: granted, + Denied: denied, }) - respCh, ok := s.pendingRequests.Get(permission.ID) - if ok { - respCh <- true - } + + // respCh is buffered (cap 1) and only ever has at most one sender + // per request because Take removes the entry under the map lock, + // so this send never blocks. + respCh <- granted s.activeRequestMu.Lock() if s.activeRequest != nil && s.activeRequest.ID == permission.ID { s.activeRequest = nil } s.activeRequestMu.Unlock() + return true } -func (s *permissionService) Deny(permission PermissionRequest) { - s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ - ToolCallID: permission.ToolCallID, - Granted: false, - Denied: true, +func (s *permissionService) GrantPersistent(permission PermissionRequest) bool { + // Record the persistent grant only if this call wins the + // pending-request race. Otherwise a losing GrantPersistent that + // lost to a Deny would still leave an auto-approve entry behind, + // silently flipping later denied calls to allowed. + return s.resolve(permission, true, false, func() { + s.sessionPermissions.Set(PermissionKey{ + SessionID: permission.SessionID, + ToolName: permission.ToolName, + Action: permission.Action, + Path: permission.Path, + }, true) }) - respCh, ok := s.pendingRequests.Get(permission.ID) - if ok { - respCh <- false - } +} - s.activeRequestMu.Lock() - if s.activeRequest != nil && s.activeRequest.ID == permission.ID { - s.activeRequest = nil - } - s.activeRequestMu.Unlock() +func (s *permissionService) Grant(permission PermissionRequest) bool { + return s.resolve(permission, true, false, nil) +} + +func (s *permissionService) Deny(permission PermissionRequest) bool { + return s.resolve(permission, false, true, nil) } func (s *permissionService) Request(ctx context.Context, opts CreatePermissionRequest) (bool, error) { diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go index de08f6beae901172dd3c821a9ff7e544cbc7c6c5..9d464f7a0c16b04491b5a6e0e621a0b64fa94ff4 100644 --- a/internal/permission/permission_test.go +++ b/internal/permission/permission_test.go @@ -2,7 +2,9 @@ package permission import ( "sync" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -342,3 +344,272 @@ func TestPermissionService_SequentialProperties(t *testing.T) { assert.True(t, result, "Repeated request should be auto-approved due to persistent permission") }) } + +// TestPermissionService_ResolveIdempotency covers the multi-subscriber +// resolve guarantees added for client/server mode: exactly one +// notification per resolution, racing callers see "already resolved", +// and stray Grant/Deny calls for unknown IDs are safe no-ops. +func TestPermissionService_ResolveIdempotency(t *testing.T) { + t.Parallel() + + t.Run("concurrent grants resolve exactly once", func(t *testing.T) { + t.Parallel() + service := NewPermissionService("/tmp", false, nil) + + events := service.Subscribe(t.Context()) + notifications := service.SubscribeNotifications(t.Context()) + + req := CreatePermissionRequest{ + SessionID: "race-session", + ToolCallID: "race-call", + ToolName: "tool", + Action: "act", + Path: "/tmp/race", + } + + var ( + wg sync.WaitGroup + granted bool + requestErr error + ) + wg.Go(func() { + granted, requestErr = service.Request(t.Context(), req) + }) + + // Wait for the request to be published so we have a real + // PermissionRequest (with its server-side ID) to race on. + var pending PermissionRequest + select { + case ev := <-events: + pending = ev.Payload + case <-time.After(2 * time.Second): + t.Fatal("permission request was never published") + } + + // Drain the initial "request opened" notification (Granted == + // false && Denied == false) so the next read is the resolution + // itself. + select { + case ev := <-notifications: + require.False(t, ev.Payload.Granted, "initial notification must not be granted") + require.False(t, ev.Payload.Denied, "initial notification must not be denied") + case <-time.After(2 * time.Second): + t.Fatal("initial notification was never published") + } + + // Race two grants from two goroutines. + var ( + resolvedCount atomic.Int32 + start = make(chan struct{}) + racers sync.WaitGroup + ) + for range 2 { + racers.Go(func() { + <-start + if service.Grant(pending) { + resolvedCount.Add(1) + } + }) + } + close(start) + racers.Wait() + + // Original Request must return granted exactly once. + wg.Wait() + require.NoError(t, requestErr) + assert.True(t, granted, "request should observe its grant") + + // Exactly one of the two grants resolved the request. + assert.Equal(t, int32(1), resolvedCount.Load(), + "exactly one Grant should report it resolved the request") + + // Exactly one resolution notification, and no further ones. + select { + case ev := <-notifications: + assert.True(t, ev.Payload.Granted, "resolution notification should be granted") + assert.Equal(t, "race-call", ev.Payload.ToolCallID) + case <-time.After(2 * time.Second): + t.Fatal("resolution notification was never published") + } + select { + case ev := <-notifications: + t.Fatalf("unexpected duplicate notification: %+v", ev.Payload) + case <-time.After(50 * time.Millisecond): + // good: no duplicate. + } + + // pendingRequests must be empty: no goroutine is left blocked + // on a send, and a future Grant for the same ID is a no-op. + ps := service.(*permissionService) + assert.Equal(t, 0, ps.pendingRequests.Len(), + "pendingRequests must be empty after resolution") + + assert.False(t, service.Grant(pending), + "a third Grant should report already-resolved") + }) + + t.Run("grant after deny is a no-op", func(t *testing.T) { + t.Parallel() + service := NewPermissionService("/tmp", false, nil) + + events := service.Subscribe(t.Context()) + notifications := service.SubscribeNotifications(t.Context()) + + req := CreatePermissionRequest{ + SessionID: "deny-first", + ToolCallID: "df-call", + ToolName: "tool", + Action: "act", + Path: "/tmp/df", + } + + var ( + wg sync.WaitGroup + granted bool + requestErr error + ) + wg.Go(func() { + granted, requestErr = service.Request(t.Context(), req) + }) + + var pending PermissionRequest + select { + case ev := <-events: + pending = ev.Payload + case <-time.After(2 * time.Second): + t.Fatal("permission request was never published") + } + + // Drain the initial neither-granted-nor-denied notification. + <-notifications + + assert.True(t, service.Deny(pending), "Deny should resolve the request") + wg.Wait() + require.NoError(t, requestErr) + assert.False(t, granted, "request should observe denial") + + // A follow-up Grant must be a no-op and must not flip the + // outcome or publish anything new. + assert.False(t, service.Grant(pending), + "Grant after Deny should report already-resolved") + + select { + case ev := <-notifications: + // The first resolution notification (denial) is expected; + // anything after that is a bug. + require.True(t, ev.Payload.Denied, + "the only post-initial notification must be the denial") + case <-time.After(2 * time.Second): + t.Fatal("denial notification was never published") + } + select { + case ev := <-notifications: + t.Fatalf("Grant after Deny must not publish: %+v", ev.Payload) + case <-time.After(50 * time.Millisecond): + // good. + } + }) + + t.Run("losing GrantPersistent does not record session permission", func(t *testing.T) { + t.Parallel() + service := NewPermissionService("/tmp", false, nil) + + events := service.Subscribe(t.Context()) + notifications := service.SubscribeNotifications(t.Context()) + + req := CreatePermissionRequest{ + SessionID: "race-persist", + ToolCallID: "rp-call", + ToolName: "tool", + Action: "act", + Path: "/tmp/rp", + } + + var ( + wg sync.WaitGroup + granted bool + requestErr error + ) + wg.Go(func() { + granted, requestErr = service.Request(t.Context(), req) + }) + + // Wait for the request to be published so we have the real + // pending PermissionRequest to race on. + var pending PermissionRequest + select { + case ev := <-events: + pending = ev.Payload + case <-time.After(2 * time.Second): + t.Fatal("permission request was never published") + } + + // Drain the initial neither-granted-nor-denied notification. + <-notifications + + // Deny wins, then a competing GrantPersistent loses. + assert.True(t, service.Deny(pending), "Deny should resolve the request") + assert.False(t, service.GrantPersistent(pending), + "GrantPersistent after Deny should report already-resolved") + + wg.Wait() + require.NoError(t, requestErr) + assert.False(t, granted, "request should observe denial") + + // The losing GrantPersistent must not have inserted an + // auto-approve entry. Issue a matching follow-up request and + // confirm the service still publishes a pending request (i.e. + // not auto-approved). We then Deny it to drain the goroutine. + var ( + wg2 sync.WaitGroup + granted2 bool + requestErr2 error + ) + wg2.Go(func() { + granted2, requestErr2 = service.Request(t.Context(), req) + }) + + select { + case ev := <-events: + assert.Equal(t, pending.SessionID, ev.Payload.SessionID) + service.Deny(ev.Payload) + case <-time.After(2 * time.Second): + t.Fatal("follow-up request was auto-approved; persistent grant leaked") + } + + wg2.Wait() + require.NoError(t, requestErr2) + assert.False(t, granted2, "follow-up request should be denied, not auto-approved") + }) + + t.Run("grant for unknown id is a safe no-op", func(t *testing.T) { + t.Parallel() + service := NewPermissionService("/tmp", false, nil) + + notifications := service.SubscribeNotifications(t.Context()) + + bogus := PermissionRequest{ + ID: "does-not-exist", + ToolCallID: "ghost", + ToolName: "tool", + Action: "act", + Path: "/tmp/ghost", + } + + assert.NotPanics(t, func() { + assert.False(t, service.Grant(bogus), + "Grant for unknown ID should report already-resolved") + assert.False(t, service.GrantPersistent(bogus), + "GrantPersistent for unknown ID should report already-resolved") + assert.False(t, service.Deny(bogus), + "Deny for unknown ID should report already-resolved") + }) + + select { + case ev := <-notifications: + t.Fatalf("unknown-ID resolution must not publish: %+v", ev.Payload) + case <-time.After(50 * time.Millisecond): + // good: no notification. + } + }) +} diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 03afa6b1c7083ea7f55f92faaa6d4f4709311ef0..739d8ddd9ef34c40f4b1d8ca25ddc20cd8a9f581 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -13,14 +13,15 @@ import ( // Workspace represents a running app.App workspace with its associated // resources and state. type Workspace struct { - ID string `json:"id"` - Path string `json:"path"` - YOLO bool `json:"yolo,omitempty"` - Debug bool `json:"debug,omitempty"` - DataDir string `json:"data_dir,omitempty"` - Version string `json:"version,omitempty"` - Config *config.Config `json:"config,omitempty"` - Env []string `json:"env,omitempty"` + ID string `json:"id"` + Path string `json:"path"` + YOLO bool `json:"yolo,omitempty"` + Debug bool `json:"debug,omitempty"` + DataDir string `json:"data_dir,omitempty"` + Version string `json:"version,omitempty"` + ClientID string `json:"client_id,omitempty"` + Config *config.Config `json:"config,omitempty"` + Env []string `json:"env,omitempty"` // Skills carries the snapshot of skill discovery state at workspace // creation time. Subsequent updates flow through the SSE event // stream. @@ -32,6 +33,19 @@ type Error struct { Message string `json:"message"` } +// ConfigChanged is published whenever the workspace's configuration is +// mutated by a backend operation. Clients react by re-fetching the +// workspace snapshot so cached config stays in sync across subscribers. +type ConfigChanged struct { + WorkspaceID string `json:"workspace_id"` +} + +// CurrentSession is the request body for the per-client +// current-session endpoint. An empty SessionID clears the entry. +type CurrentSession struct { + SessionID string `json:"session_id"` +} + // SkillInfo describes a visible skill exposed to a frontend. type SkillInfo struct { ID string `json:"id"` @@ -118,6 +132,15 @@ type PermissionGrant struct { Action PermissionAction `json:"action"` } +// PermissionGrantResponse is the server's response to a permission +// grant call. Resolved is true when this call resolved the pending +// request, and false when the request had already been resolved by a +// previous caller (e.g., another client in a multi-subscriber UI). A +// false value is not an error. +type PermissionGrantResponse struct { + Resolved bool `json:"resolved"` +} + // PermissionSkipRequest represents a request to skip permission prompts. type PermissionSkipRequest struct { Skip bool `json:"skip"` diff --git a/internal/proto/session.go b/internal/proto/session.go index 6c7aca7bd8b010d44e39ee582e03edaa7cea5a66..9c49e439ccdfda35144740835bd7e3a25741ecb7 100644 --- a/internal/proto/session.go +++ b/internal/proto/session.go @@ -1,6 +1,18 @@ package proto // Session represents a session in the proto layer. +// +// IsBusy is computed on read (it is not persisted with the session) and +// reflects whether an agent run is currently in flight for this session. +// It is populated by REST handlers in internal/server/proto.go from the +// workspace's AgentCoordinator. The Session SSE event path does not set +// it, since SSE consumers can compute presence from other agent signals. +// +// AttachedClients counts the number of clients currently viewing this +// session — i.e. entries in the workspace's clients map whose +// currentSessionID equals this session's ID and which have at least one +// live SSE stream. Hold-only clients (streams == 0) do not contribute. +// Like IsBusy, it is computed on read by REST handlers. type Session struct { ID string `json:"id"` ParentSessionID string `json:"parent_session_id"` @@ -13,6 +25,8 @@ type Session struct { Todos []Todo `json:"todos,omitempty"` CreatedAt int64 `json:"created_at"` UpdatedAt int64 `json:"updated_at"` + IsBusy bool `json:"is_busy"` + AttachedClients int `json:"attached_clients"` } // Todo represents a single todo entry on a session in the proto layer. diff --git a/internal/pubsub/events.go b/internal/pubsub/events.go index 689d7242970953f048f539252210b5f33ab1a49a..7f75d7d19e39f2a714fccc5be0232a19fadab7b9 100644 --- a/internal/pubsub/events.go +++ b/internal/pubsub/events.go @@ -24,6 +24,7 @@ const ( PayloadTypeSession PayloadType = "session" PayloadTypeFile PayloadType = "file" PayloadTypeAgentEvent PayloadType = "agent_event" + PayloadTypeConfigChanged PayloadType = "config_changed" PayloadTypeSkillsEvent PayloadType = "skills_event" ) diff --git a/internal/server/e2e_test.go b/internal/server/e2e_test.go new file mode 100644 index 0000000000000000000000000000000000000000..08aaedf66c95edd704f18b62d83d64e79966564e --- /dev/null +++ b/internal/server/e2e_test.go @@ -0,0 +1,600 @@ +package server + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// e2eHarness wires a Server, its Backend (with a custom shutdownFn we +// can observe), an httptest.NewServer, and a synthetic Workspace whose +// embedded App has a live event broker. It is the minimum scaffolding +// the multi-client end-to-end scenarios in PLAN item 6 need. +type e2eHarness struct { + httpSrv *httptest.Server + srv *Server + backend *backend.Backend + workspace *backend.Workspace + app *app.App + shutdownHit atomic.Bool + + // sseWG tracks every SSE reader goroutine spawned by + // [e2eHarness.subscribeSSE]. The harness's cleanup hook waits on + // it after the httptest server has been closed so that the test + // cannot leave behind background readers (and therefore unclosed + // response bodies) after returning. + sseWG sync.WaitGroup +} + +// installServer attaches a fresh Server (with a custom shutdown +// callback that flips [e2eHarness.shutdownHit]) wrapped in an +// [httptest.Server] onto h. It registers the cleanup hooks for the +// httptest server and the SSE reader WaitGroup in the order required +// by the LIFO contract documented on [newE2EHarness]. +// +// Callers that want a fully synthetic workspace use [newE2EHarness]; +// callers that want to drive the real CreateWorkspace HTTP path use +// [newRealCreateHarness] and then [e2eHarness.postWorkspace]. +func (h *e2eHarness) installServer(t *testing.T) { + t.Helper() + srv := &Server{} + srv.backend = backend.New(context.Background(), nil, func() { + h.shutdownHit.Store(true) + }) + srv.installHandler() + + hs := httptest.NewServer(srv.Handler()) + // Order matters: t.Cleanup is LIFO and the test's own per- + // stream cancels (cancelA/cancelB) run first. After those, we + // want hs.Close to fire first (so any handler still parked in + // its `select` returns), THEN sseWG.Wait so every reader + // goroutine exits and closes its response body. Any caller- + // owned cleanups registered *before* installServer (e.g. App + // teardown for the synthetic harness) therefore run LAST, + // after the readers have drained. + t.Cleanup(h.sseWG.Wait) + t.Cleanup(hs.Close) + + h.httpSrv = hs + h.srv = srv + h.backend = srv.backend +} + +// newE2EHarness builds an in-process server + a synthetic Workspace +// whose embedded App is a real [app.App] constructed via +// [app.NewForTest], so its event broker delivers everything the SSE +// pipeline expects. Used by the scenarios that do not need to +// exercise the path-dedupe behavior of [backend.CreateWorkspace]. +// +// Cleanup tears down the App's broker only after sseWG.Wait and +// hs.Close have run, so SSE readers cannot observe a dead broker. +func newE2EHarness(t *testing.T) *e2eHarness { + t.Helper() + + h := &e2eHarness{} + + // Register the App teardown FIRST so LIFO order puts it AFTER + // the cleanups that installServer registers below (hs.Close + + // sseWG.Wait). + appCtx, cancel := context.WithCancel(context.Background()) + a := app.NewForTest(appCtx) + t.Cleanup(func() { + cancel() + a.ShutdownForTest() + }) + + h.installServer(t) + + ws := &backend.Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + App: a, + } + // Synthetic workspaces have an incomplete App; bypass the + // default teardown so the "last workspace removed" path can run + // without panicking inside [app.App.Shutdown]. + backend.SetWorkspaceShutdownFnForTest(ws, func() {}) + backend.InsertWorkspaceForTest(h.backend, ws) + + h.workspace = ws + h.app = a + return h +} + +// newRealCreateHarness builds an in-process server WITHOUT any +// pre-inserted workspace, intended for tests that drive the real +// [backend.CreateWorkspace] HTTP path (path-dedupe scenario). It +// isolates HOME/XDG_* via [t.Setenv] so [config.Init] doesn't read +// the host machine's config, which means callers MUST NOT mark the +// test as parallel. +func newRealCreateHarness(t *testing.T) *e2eHarness { + t.Helper() + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + h := &e2eHarness{} + h.installServer(t) + return h +} + +// postWorkspace drives the real POST /v1/workspaces handler and +// returns the resolved workspace proto. This is how scenario 1 +// exercises the path-dedupe behavior from PLAN item 1: two calls +// with the same Path and distinct ClientIDs must return the same +// workspace ID. +func (h *e2eHarness) postWorkspace(t *testing.T, args proto.Workspace) proto.Workspace { + t.Helper() + body, err := json.Marshal(args) + require.NoError(t, err) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, + h.httpSrv.URL+"/v1/workspaces", bytes.NewReader(body)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + resp, err := h.httpSrv.Client().Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode, "POST /v1/workspaces must succeed") + var out proto.Workspace + require.NoError(t, json.NewDecoder(resp.Body).Decode(&out)) + require.NotEmpty(t, out.ID, "server must return a workspace id") + return out +} + +// subscribeSSE opens an SSE stream against the test server for the +// given workspace and client ID. It returns a channel of decoded +// envelopes plus a cancel function that closes the stream. The +// returned channel is closed when the stream ends. +func (h *e2eHarness) subscribeSSE(t *testing.T, ctx context.Context, workspaceID, clientID string) (<-chan any, context.CancelFunc) { + t.Helper() + streamCtx, cancel := context.WithCancel(ctx) + + q := url.Values{"client_id": []string{clientID}} + reqURL := h.httpSrv.URL + "/v1/workspaces/" + workspaceID + "/events?" + q.Encode() + req, err := http.NewRequestWithContext(streamCtx, http.MethodGet, reqURL, nil) + require.NoError(t, err) + req.Header.Set("Accept", "text/event-stream") + + resp, err := h.httpSrv.Client().Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode, "SSE subscribe should return 200") + + out := make(chan any, 64) + h.sseWG.Go(func() { + defer resp.Body.Close() + defer close(out) + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadBytes('\n') + if err != nil { + return + } + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + data, ok := bytes.CutPrefix(line, []byte("data:")) + if !ok { + continue + } + data = bytes.TrimSpace(data) + var p pubsub.Payload + if err := json.Unmarshal(data, &p); err != nil { + continue + } + ev, decoded := decodeSSEEnvelope(p) + if !decoded { + continue + } + select { + case out <- ev: + case <-streamCtx.Done(): + return + } + } + }) + return out, cancel +} + +// decodeSSEEnvelope decodes the discriminated SSE envelope into the +// concrete pubsub.Event[proto.X] payload the e2e tests care about. +// Unknown payload types are skipped so tests can match on type +// assertions without worrying about envelope noise. +func decodeSSEEnvelope(p pubsub.Payload) (any, bool) { + switch p.Type { + case pubsub.PayloadTypePermissionRequest: + var e pubsub.Event[proto.PermissionRequest] + if err := json.Unmarshal(p.Payload, &e); err != nil { + return nil, false + } + return e, true + case pubsub.PayloadTypePermissionNotification: + var e pubsub.Event[proto.PermissionNotification] + if err := json.Unmarshal(p.Payload, &e); err != nil { + return nil, false + } + return e, true + case pubsub.PayloadTypeMessage: + var e pubsub.Event[proto.Message] + if err := json.Unmarshal(p.Payload, &e); err != nil { + return nil, false + } + return e, true + } + return nil, false +} + +// grantPermission posts a permission grant via the HTTP surface and +// returns the server's "resolved" verdict. Mirrors the client-side +// GrantPermission flow without importing internal/client (which +// would create an import cycle from this in-package test). +func (h *e2eHarness) grantPermission(t *testing.T, ctx context.Context, workspaceID string, req proto.PermissionGrant) bool { + t.Helper() + body, err := json.Marshal(req) + require.NoError(t, err) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, + h.httpSrv.URL+"/v1/workspaces/"+workspaceID+"/permissions/grant", + bytes.NewReader(body)) + require.NoError(t, err) + httpReq.Header.Set("Content-Type", "application/json") + resp, err := h.httpSrv.Client().Do(httpReq) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + var out proto.PermissionGrantResponse + require.NoError(t, json.NewDecoder(resp.Body).Decode(&out)) + return out.Resolved +} + +// waitForAttached spins until the workspace's clients map reports at +// least n entries with streams > 0. Catches the race where a test +// publishes events before the server-side AttachClient has completed. +func (h *e2eHarness) waitForAttached(t *testing.T, n int) { + t.Helper() + h.waitForAttachedOn(t, h.workspace, n) +} + +// waitForAttachedOn is the workspace-explicit form of waitForAttached. +// Tests that drive a workspace whose pointer is not stored on the +// harness (e.g. the real CreateWorkspace path) pass the workspace in. +func (h *e2eHarness) waitForAttachedOn(t *testing.T, ws *backend.Workspace, n int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if backend.WorkspaceLiveStreamCountForTest(ws) >= n { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("expected %d attached streams, have %d", n, + backend.WorkspaceLiveStreamCountForTest(ws)) +} + +// drainUntil reads from evc until it sees an event of type T that +// satisfies match, or ctx expires. Returns the matching event and +// ok=true, or the zero value and ok=false on timeout. +func drainUntil[T any](ctx context.Context, evc <-chan any, match func(T) bool) (T, bool) { + var zero T + for { + select { + case <-ctx.Done(): + return zero, false + case ev, ok := <-evc: + if !ok { + return zero, false + } + typed, isT := ev.(T) + if !isT { + continue + } + if match == nil || match(typed) { + return typed, true + } + } + } +} + +// TestE2E_TwoClientsReceiveSameMessage covers PLAN item 6 scenario 1: +// two clients POST /v1/workspaces with the same Path and observe +// that the server returns a single workspace (path-dedupe from PLAN +// item 1) and that an event published on that workspace fans out to +// both SSE streams. +// +// Cannot run in parallel: it isolates HOME/XDG_* via t.Setenv so +// config.Init does not read the host machine's real config. +func TestE2E_TwoClientsReceiveSameMessage(t *testing.T) { + h := newRealCreateHarness(t) + // Shorten the create-grace window so the workspace's pending + // creation holds release quickly during test cleanup once both + // SSE streams have been detached. + h.backend.SetCreateGrace(200 * time.Millisecond) + + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cidA := uuid.New().String() + cidB := uuid.New().String() + + // Shared workspace path. Two POSTs with this path must + // deduplicate at the backend's pathIndex and return the same + // workspace id. + wsPath := t.TempDir() + dataDir := t.TempDir() + args := proto.Workspace{Path: wsPath, DataDir: dataDir} + + argsA := args + argsA.ClientID = cidA + wsRespA := h.postWorkspace(t, argsA) + + argsB := args + argsB.ClientID = cidB + wsRespB := h.postWorkspace(t, argsB) + + require.Equal(t, wsRespA.ID, wsRespB.ID, + "POST /v1/workspaces with the same Path must return the same workspace id") + + // Look up the resulting workspace on the backend so the test + // can publish events through its real [app.App] event broker. + ws, err := h.backend.GetWorkspace(wsRespA.ID) + require.NoError(t, err) + // Override the shutdown callback so test cleanup doesn't run + // the full app.Shutdown path (which would tear down LSP/MCP + // resources the test doesn't need to exercise), but still + // release the pooled DB connection so Windows can clean up + // the temp data directory. + wsDataDir := ws.Cfg.Config().Options.DataDirectory + backend.SetWorkspaceShutdownFnForTest(ws, func() { + _ = db.Release(wsDataDir) + }) + + evcA, cancelA := h.subscribeSSE(t, ctx, ws.ID, cidA) + t.Cleanup(cancelA) + evcB, cancelB := h.subscribeSSE(t, ctx, ws.ID, cidB) + t.Cleanup(cancelB) + + h.waitForAttachedOn(t, ws, 2) + + const sessionID = "s-e2e-1" + msg := message.Message{ + ID: "m-1", + SessionID: sessionID, + Role: message.Assistant, + Parts: []message.ContentPart{message.TextContent{Text: "hello multi-client"}}, + } + ws.SendEvent(pubsub.Event[message.Message]{ + Type: pubsub.CreatedEvent, + Payload: msg, + }) + + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + gotA, okA := drainUntil(pickCtx, evcA, func(e pubsub.Event[proto.Message]) bool { + return e.Payload.ID == "m-1" + }) + require.True(t, okA, "client A must receive the MessageEvent") + require.Equal(t, sessionID, gotA.Payload.SessionID) + + gotB, okB := drainUntil(pickCtx, evcB, func(e pubsub.Event[proto.Message]) bool { + return e.Payload.ID == "m-1" + }) + require.True(t, okB, "client B must receive the same MessageEvent") + require.Equal(t, sessionID, gotB.Payload.SessionID) +} + +// TestE2E_PermissionFlowCrossClient covers PLAN item 6 scenario 2: +// a tool-driven permission request is granted by client A; client B +// observes a PermissionNotification; a redundant grant from B +// returns the "already resolved" indicator (resolved=false from the +// bool plumbing landed in item 3). +func TestE2E_PermissionFlowCrossClient(t *testing.T) { + t.Parallel() + h := newE2EHarness(t) + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + + cidA := uuid.New().String() + cidB := uuid.New().String() + + evcA, cancelA := h.subscribeSSE(t, ctx, h.workspace.ID, cidA) + t.Cleanup(cancelA) + evcB, cancelB := h.subscribeSSE(t, ctx, h.workspace.ID, cidB) + t.Cleanup(cancelB) + + h.waitForAttached(t, 2) + + // Drive the permission request from a goroutine simulating the + // tool path. Request blocks until resolved; capture the outcome. + const sessionID = "s-perm" + const toolCallID = "tc-1" + type result struct { + granted bool + err error + } + done := make(chan result, 1) + go func() { + granted, err := h.app.Permissions.Request(ctx, permission.CreatePermissionRequest{ + SessionID: sessionID, + ToolCallID: toolCallID, + ToolName: "view", + Description: "read a file", + Action: "read", + Path: h.workspace.Path, + }) + done <- result{granted: granted, err: err} + }() + + // Wait for the PermissionRequest to arrive on client A's SSE + // stream. We need its ID to drive the grant. + pickCtx, pickCancel := context.WithTimeout(ctx, 3*time.Second) + defer pickCancel() + reqEv, ok := drainUntil(pickCtx, evcA, func(e pubsub.Event[proto.PermissionRequest]) bool { + return e.Payload.ToolCallID == toolCallID + }) + require.True(t, ok, "client A must receive the PermissionRequest") + + // Client A grants — first grant must report resolved=true. + resolvedA := h.grantPermission(t, ctx, h.workspace.ID, proto.PermissionGrant{ + Permission: reqEv.Payload, + Action: proto.PermissionAllow, + }) + require.True(t, resolvedA, "client A's grant must resolve the pending request") + + // The blocked Request call must now return granted=true. + select { + case r := <-done: + require.NoError(t, r.err) + require.True(t, r.granted) + case <-pickCtx.Done(): + t.Fatal("permission Request did not return after grant") + } + + // Client B must receive a PermissionNotification with + // Granted=true for the same ToolCallID. The initial neither- + // granted-nor-denied notification published at the start of + // Request also lands on B's stream — match on the granted one. + notif, ok := drainUntil(pickCtx, evcB, func(e pubsub.Event[proto.PermissionNotification]) bool { + return e.Payload.ToolCallID == toolCallID && e.Payload.Granted + }) + require.True(t, ok, "client B must receive a granting PermissionNotification") + require.True(t, notif.Payload.Granted) + require.False(t, notif.Payload.Denied) + + // A follow-up grant from client B must report resolved=false + // (the request was already resolved by A). + resolvedB := h.grantPermission(t, ctx, h.workspace.ID, proto.PermissionGrant{ + Permission: reqEv.Payload, + Action: proto.PermissionAllow, + }) + require.False(t, resolvedB, "client B's follow-up grant must report already resolved") +} + +// TestE2E_KillingClientASSEDoesNotBreakClientB covers PLAN item 6 +// scenario 3: terminating client A's SSE stream does not affect +// client B's stream; client B continues to receive events. +func TestE2E_KillingClientASSEDoesNotBreakClientB(t *testing.T) { + t.Parallel() + h := newE2EHarness(t) + ctxB, cancelB := context.WithCancel(t.Context()) + t.Cleanup(cancelB) + ctxA, cancelA := context.WithCancel(t.Context()) + + cidA := uuid.New().String() + cidB := uuid.New().String() + + _, killA := h.subscribeSSE(t, ctxA, h.workspace.ID, cidA) + t.Cleanup(killA) + evcB, killB := h.subscribeSSE(t, ctxB, h.workspace.ID, cidB) + t.Cleanup(killB) + + h.waitForAttached(t, 2) + + // Kill A's stream. The server's deferred DetachClient should + // drop A's claim, leaving B as the sole attached client. + cancelA() + killA() + + require.Eventually(t, func() bool { + return backend.WorkspaceLiveStreamCountForTest(h.workspace) == 1 + }, 3*time.Second, 10*time.Millisecond, + "expected client A's stream to drop the attached count to 1") + + // Workspace must still exist (B is holding it open) and + // shutdown callback must not have fired yet. + _, err := h.backend.GetWorkspace(h.workspace.ID) + require.NoError(t, err, "workspace must still exist while B is attached") + require.False(t, h.shutdownHit.Load(), + "shutdown callback must not fire while B is still attached") + + // Publish a fresh event; B must still receive it. + const sessionID = "s-after-a-died" + msg := message.Message{ + ID: "m-after", + SessionID: sessionID, + Role: message.Assistant, + Parts: []message.ContentPart{message.TextContent{Text: "still alive"}}, + } + h.app.SendEvent(pubsub.Event[message.Message]{ + Type: pubsub.CreatedEvent, + Payload: msg, + }) + + pickCtx, pickCancel := context.WithTimeout(ctxB, 3*time.Second) + defer pickCancel() + got, ok := drainUntil(pickCtx, evcB, func(e pubsub.Event[proto.Message]) bool { + return e.Payload.ID == "m-after" + }) + require.True(t, ok, "client B must still receive events after A's stream is killed") + require.Equal(t, sessionID, got.Payload.SessionID) +} + +// TestE2E_ShutdownCallbackFiresWhenLastClientLeaves covers PLAN +// item 6 scenario 4: once both clients disconnect, the backend +// runs its "last workspace removed -> server shutdown" path. +func TestE2E_ShutdownCallbackFiresWhenLastClientLeaves(t *testing.T) { + t.Parallel() + h := newE2EHarness(t) + + ctxA, cancelA := context.WithCancel(t.Context()) + ctxB, cancelB := context.WithCancel(t.Context()) + t.Cleanup(cancelA) + t.Cleanup(cancelB) + + cidA := uuid.New().String() + cidB := uuid.New().String() + _, killA := h.subscribeSSE(t, ctxA, h.workspace.ID, cidA) + t.Cleanup(killA) + _, killB := h.subscribeSSE(t, ctxB, h.workspace.ID, cidB) + t.Cleanup(killB) + + h.waitForAttached(t, 2) + require.False(t, h.shutdownHit.Load(), "shutdown must not fire while clients are attached") + + cancelA() + killA() + require.Eventually(t, func() bool { + return backend.WorkspaceLiveStreamCountForTest(h.workspace) == 1 + }, 3*time.Second, 10*time.Millisecond) + require.False(t, h.shutdownHit.Load(), + "shutdown must not fire after only one client disconnects") + + cancelB() + killB() + require.Eventually(t, h.shutdownHit.Load, + 3*time.Second, 10*time.Millisecond, + "shutdown callback must fire once the last client disconnects") + + // Workspace must be gone from the index. + _, err := h.backend.GetWorkspace(h.workspace.ID) + require.ErrorIs(t, err, backend.ErrWorkspaceNotFound) + + // Subsequent GETs against the now-defunct workspace return + // 404, confirming the http surface still reflects the teardown. + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, + h.httpSrv.URL+"/v1/workspaces/"+h.workspace.ID, nil) + require.NoError(t, err) + r, err := h.httpSrv.Client().Do(req) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, r.Body) + r.Body.Close() + require.Equal(t, http.StatusNotFound, r.StatusCode) +} diff --git a/internal/server/events.go b/internal/server/events.go index e596e2d5866268fba4a4e42a98efbb4971e40f8b..f38619c52528679bf75675780eb4bb47961bd640 100644 --- a/internal/server/events.go +++ b/internal/server/events.go @@ -8,6 +8,7 @@ import ( "github.com/charmbracelet/crush/internal/agent/notify" "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/backend" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" @@ -92,6 +93,8 @@ func wrapEvent(ev any) *pubsub.Payload { Type: proto.AgentEventType(e.Payload.Type), }, }) + case pubsub.Event[proto.ConfigChanged]: + return envelope(pubsub.PayloadTypeConfigChanged, e) case pubsub.Event[skills.Event]: return envelope(pubsub.PayloadTypeSkillsEvent, pubsub.Event[proto.SkillsEvent]{ Type: e.Type, @@ -147,6 +150,29 @@ func sessionToProto(s session.Session) proto.Session { } } +// isSessionBusy reports whether the given workspace has an in-flight +// agent run for sessionID. It tolerates a nil workspace (treating it as +// "not busy") so REST handlers can pass GetWorkspace's result through +// unconditionally — the workspace lookup error is already surfaced by +// the prior ListSessions/GetSession call when relevant. +func isSessionBusy(ws *backend.Workspace, sessionID string) bool { + if ws == nil || ws.App == nil || ws.AgentCoordinator == nil { + return false + } + return ws.AgentCoordinator.IsSessionBusy(sessionID) +} + +// attachedClients returns the number of clients currently viewing +// sessionID in ws. Hold-only clients (streams == 0) do not contribute. +// A nil workspace is treated as zero so handlers can pass GetWorkspace's +// result through without an extra guard. +func attachedClients(ws *backend.Workspace, sessionID string) int { + if ws == nil { + return 0 + } + return ws.AttachedClientsForSession(sessionID) +} + func todosToProto(todos []session.Todo) []proto.Todo { if len(todos) == 0 { return nil diff --git a/internal/server/multiclient_test.go b/internal/server/multiclient_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3e11bd206764741b78054cbc070d3cbbfc2c3d74 --- /dev/null +++ b/internal/server/multiclient_test.go @@ -0,0 +1,230 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// installSyntheticWorkspace creates a synthetic [backend.Workspace] +// registered with the controller's backend, suitable for handler-level +// tests that do not need a real [app.App]. The workspace's ID is a +// fresh UUID and its path is a tempdir; teardown is the caller's +// responsibility (handlers should not rely on synthetic workspaces +// disappearing automatically). +func installSyntheticWorkspace(t *testing.T, c *controllerV1) *backend.Workspace { + t.Helper() + ws := &backend.Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + } + backend.InsertWorkspaceForTest(c.backend, ws) + return ws +} + +// newTestController builds a controllerV1 around a backend without a +// real config store, suitable for handler-level 400 tests. +func newTestController() *controllerV1 { + s := &Server{} + s.backend = backend.New(context.Background(), nil, nil) + return &controllerV1{backend: s.backend, server: s} +} + +func TestPostWorkspaces_RejectsMissingClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + body, err := json.Marshal(proto.Workspace{Path: t.TempDir()}) + require.NoError(t, err) + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/workspaces", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + c.handlePostWorkspaces(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + var perr proto.Error + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &perr)) + require.Contains(t, perr.Message, "client_id") +} + +func TestPostWorkspaces_RejectsMalformedClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + body, err := json.Marshal(proto.Workspace{Path: t.TempDir(), ClientID: "not-a-uuid"}) + require.NoError(t, err) + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/workspaces", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + c.handlePostWorkspaces(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDeleteWorkspace_RejectsMissingClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + req := httptest.NewRequestWithContext(t.Context(), http.MethodDelete, "/v1/workspaces/abc", nil) + req.SetPathValue("id", "abc") + rec := httptest.NewRecorder() + + c.handleDeleteWorkspaces(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDeleteWorkspace_RejectsMalformedClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + req := httptest.NewRequestWithContext(t.Context(), http.MethodDelete, "/v1/workspaces/abc?client_id=nope", nil) + req.SetPathValue("id", "abc") + rec := httptest.NewRecorder() + + c.handleDeleteWorkspaces(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSubscribeEvents_RejectsMissingClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/abc/events", nil) + req.SetPathValue("id", "abc") + rec := httptest.NewRecorder() + + c.handleGetWorkspaceEvents(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSubscribeEvents_RejectsMalformedClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/abc/events?client_id=nope", nil) + req.SetPathValue("id", "abc") + rec := httptest.NewRecorder() + + c.handleGetWorkspaceEvents(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +// postCurrentSession is a small helper that POSTs the JSON body to +// /v1/workspaces/{id}/current-session?client_id=cid and returns the +// recorder. It does not require a real listener. +func postCurrentSession(t *testing.T, c *controllerV1, wsID, clientID, sessionID string) *httptest.ResponseRecorder { + t.Helper() + body, err := json.Marshal(proto.CurrentSession{SessionID: sessionID}) + require.NoError(t, err) + url := "/v1/workspaces/" + wsID + "/current-session" + if clientID != "" { + url += "?client_id=" + clientID + } + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, url, bytes.NewReader(body)) + req.SetPathValue("id", wsID) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c.handlePostWorkspaceCurrentSession(rec, req) + return rec +} + +func TestPostCurrentSession_RejectsMissingClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + body, err := json.Marshal(proto.CurrentSession{SessionID: "S1"}) + require.NoError(t, err) + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/workspaces/abc/current-session", bytes.NewReader(body)) + req.SetPathValue("id", "abc") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + c.handlePostWorkspaceCurrentSession(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestPostCurrentSession_RejectsMalformedClientID(t *testing.T) { + t.Parallel() + c := newTestController() + + rec := postCurrentSession(t, c, "abc", "not-a-uuid", "S1") + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestPostCurrentSession_RejectsBadBody(t *testing.T) { + t.Parallel() + c := newTestController() + + cid := uuid.New().String() + url := "/v1/workspaces/abc/current-session?client_id=" + cid + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, url, bytes.NewReader([]byte("not-json"))) + req.SetPathValue("id", "abc") + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + c.handlePostWorkspaceCurrentSession(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestPostCurrentSession_UnknownWorkspace(t *testing.T) { + t.Parallel() + c := newTestController() + + rec := postCurrentSession(t, c, uuid.New().String(), uuid.New().String(), "S1") + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestPostCurrentSession_UnknownClient(t *testing.T) { + t.Parallel() + c := newTestController() + ws := installSyntheticWorkspace(t, c) + + rec := postCurrentSession(t, c, ws.ID, uuid.New().String(), "S1") + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestPostCurrentSession_HoldOnly(t *testing.T) { + t.Parallel() + c := newTestController() + ws := installSyntheticWorkspace(t, c) + + cid := uuid.New().String() + require.NoError(t, backend.RegisterClientForTesting(c.backend, ws, cid)) + t.Cleanup(func() { _ = c.backend.DeleteWorkspace(ws.ID, cid) }) + + rec := postCurrentSession(t, c, ws.ID, cid, "S1") + require.Equal(t, http.StatusNotFound, rec.Code, "hold-only client must be rejected") +} + +func TestPostCurrentSession_AttachedClientSucceeds(t *testing.T) { + t.Parallel() + c := newTestController() + ws := installSyntheticWorkspace(t, c) + + cid := uuid.New().String() + require.NoError(t, c.backend.AttachClient(ws.ID, cid)) + t.Cleanup(func() { c.backend.DetachClient(ws.ID, cid) }) + + rec := postCurrentSession(t, c, ws.ID, cid, "S1") + require.Equal(t, http.StatusOK, rec.Code) + + // Clearing also returns 200. + rec = postCurrentSession(t, c, ws.ID, cid, "") + require.Equal(t, http.StatusOK, rec.Code) +} diff --git a/internal/server/proto.go b/internal/server/proto.go index f30dade2c66fdd62a5caa4b80d29235ef2930c4a..6d3eebb562784adb377eebfb480852dcda53642a 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -9,6 +9,7 @@ import ( "github.com/charmbracelet/crush/internal/backend" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/session" + "github.com/google/uuid" ) type controllerV1 struct { @@ -133,6 +134,56 @@ func (c *controllerV1) handlePostWorkspaces(w http.ResponseWriter, r *http.Reque jsonEncode(w, result) } +// requireClientID reads the client_id query parameter and validates it +// as a UUID. On failure it writes a 400 and returns false. +func (c *controllerV1) requireClientID(w http.ResponseWriter, r *http.Request) (string, bool) { + cid := r.URL.Query().Get("client_id") + if cid == "" { + c.server.logError(r, "Missing client_id query parameter") + jsonError(w, http.StatusBadRequest, "client_id is required") + return "", false + } + if _, err := uuid.Parse(cid); err != nil { + c.server.logError(r, "Invalid client_id", "error", err) + jsonError(w, http.StatusBadRequest, "client_id is not a valid UUID") + return "", false + } + return cid, true +} + +// handlePostWorkspaceCurrentSession records the calling client's +// current session selection for the workspace. An empty session_id +// clears the entry (e.g. the client is on the landing screen). +// +// @Summary Set current session for a client +// @Tags workspaces +// @Accept json +// @Produce json +// @Param id path string true "Workspace ID" +// @Param client_id query string true "Client ID (UUID)" +// @Param request body proto.CurrentSession true "Current session selection" +// @Success 200 +// @Failure 400 {object} proto.Error +// @Failure 404 {object} proto.Error +// @Router /workspaces/{id}/current-session [post] +func (c *controllerV1) handlePostWorkspaceCurrentSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + clientID, ok := c.requireClientID(w, r) + if !ok { + return + } + var req proto.CurrentSession + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.server.logError(r, "Failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + if err := c.backend.SetCurrentSession(id, clientID, req.SessionID); err != nil { + c.handleError(w, r, err) + return + } +} + // handleDeleteWorkspaces deletes a workspace. // // @Summary Delete workspace @@ -143,7 +194,14 @@ func (c *controllerV1) handlePostWorkspaces(w http.ResponseWriter, r *http.Reque // @Router /workspaces/{id} [delete] func (c *controllerV1) handleDeleteWorkspaces(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") - c.backend.DeleteWorkspace(id) + clientID, ok := c.requireClientID(w, r) + if !ok { + return + } + if err := c.backend.DeleteWorkspace(id, clientID); err != nil { + c.handleError(w, r, err) + return + } } // handleGetWorkspaceConfig returns workspace configuration. @@ -199,6 +257,15 @@ func (c *controllerV1) handleGetWorkspaceProviders(w http.ResponseWriter, r *htt func (c *controllerV1) handleGetWorkspaceEvents(w http.ResponseWriter, r *http.Request) { flusher := http.NewResponseController(w) id := r.PathValue("id") + clientID, ok := c.requireClientID(w, r) + if !ok { + return + } + if err := c.backend.AttachClient(id, clientID); err != nil { + c.handleError(w, r, err) + return + } + defer c.backend.DetachClient(id, clientID) events, err := c.backend.SubscribeEvents(r.Context(), id) if err != nil { c.handleError(w, r, err) @@ -208,6 +275,12 @@ func (c *controllerV1) handleGetWorkspaceEvents(w http.ResponseWriter, r *http.R w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") + // Flush headers immediately so clients see the 200 response + // before any events arrive. Without this, a quiet workspace + // keeps the client's SubscribeEvents call blocked on the + // initial RoundTrip. + w.WriteHeader(http.StatusOK) + flusher.Flush() for { select { @@ -303,9 +376,12 @@ func (c *controllerV1) handleGetWorkspaceSessions(w http.ResponseWriter, r *http c.handleError(w, r, err) return } + ws, _ := c.backend.GetWorkspace(id) result := make([]proto.Session, len(sessions)) for i, s := range sessions { result[i] = sessionToProto(s) + result[i].IsBusy = isSessionBusy(ws, s.ID) + result[i].AttachedClients = attachedClients(ws, s.ID) } jsonEncode(w, result) } @@ -338,7 +414,11 @@ func (c *controllerV1) handlePostWorkspaceSessions(w http.ResponseWriter, r *htt c.handleError(w, r, err) return } - jsonEncode(w, sessionToProto(sess)) + ws, _ := c.backend.GetWorkspace(id) + out := sessionToProto(sess) + out.IsBusy = isSessionBusy(ws, sess.ID) + out.AttachedClients = attachedClients(ws, sess.ID) + jsonEncode(w, out) } // handleGetWorkspaceSession returns a single session. @@ -360,7 +440,11 @@ func (c *controllerV1) handleGetWorkspaceSession(w http.ResponseWriter, r *http. c.handleError(w, r, err) return } - jsonEncode(w, sessionToProto(sess)) + ws, _ := c.backend.GetWorkspace(id) + out := sessionToProto(sess) + out.IsBusy = isSessionBusy(ws, sess.ID) + out.AttachedClients = attachedClients(ws, sess.ID) + jsonEncode(w, out) } // handleGetWorkspaceSessionHistory returns the history for a session. @@ -436,7 +520,11 @@ func (c *controllerV1) handlePutWorkspaceSession(w http.ResponseWriter, r *http. c.handleError(w, r, err) return } - jsonEncode(w, sessionToProto(saved)) + ws, _ := c.backend.GetWorkspace(id) + out := sessionToProto(saved) + out.IsBusy = isSessionBusy(ws, saved.ID) + out.AttachedClients = attachedClients(ws, saved.ID) + jsonEncode(w, out) } // handleDeleteWorkspaceSession deletes a session. @@ -864,7 +952,7 @@ func (c *controllerV1) handleGetWorkspaceAgentDefaultSmallModel(w http.ResponseW // @Accept json // @Param id path string true "Workspace ID" // @Param request body proto.PermissionGrant true "Permission grant" -// @Success 200 +// @Success 200 {object} proto.PermissionGrantResponse // @Failure 400 {object} proto.Error // @Failure 404 {object} proto.Error // @Failure 500 {object} proto.Error @@ -879,11 +967,12 @@ func (c *controllerV1) handlePostWorkspacePermissionsGrant(w http.ResponseWriter return } - if err := c.backend.GrantPermission(id, req); err != nil { + resolved, err := c.backend.GrantPermission(id, req) + if err != nil { c.handleError(w, r, err) return } - w.WriteHeader(http.StatusOK) + jsonEncode(w, proto.PermissionGrantResponse{Resolved: resolved}) } // handlePostWorkspacePermissionsSkip sets whether to skip permission prompts. @@ -951,6 +1040,10 @@ func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err e status = http.StatusBadRequest case errors.Is(err, backend.ErrUnknownCommand): status = http.StatusBadRequest + case errors.Is(err, backend.ErrInvalidClientID): + status = http.StatusBadRequest + case errors.Is(err, backend.ErrClientNotAttached): + status = http.StatusNotFound } c.server.logError(r, err.Error()) jsonError(w, status, err.Error()) diff --git a/internal/server/server.go b/internal/server/server.go index 87b7009f4a80894e18a849a215072cea464592c5..7c51dab2adf7d8c715743c214d3489900c1636e3 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -100,7 +100,18 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server { } }() }) + s.installHandler() + if network == "tcp" { + s.h.Addr = address + } + return s +} +// installHandler builds the protocol/router around s.backend and +// assigns the resulting http.Server to s.h. Extracted from +// [NewServer] so test harnesses can wire a Server around a +// pre-constructed backend. +func (s *Server) installHandler() { var p http.Protocols p.SetHTTP1(true) p.SetUnencryptedHTTP2(true) @@ -113,6 +124,7 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server { mux.HandleFunc("GET /v1/workspaces", c.handleGetWorkspaces) mux.HandleFunc("POST /v1/workspaces", c.handlePostWorkspaces) mux.HandleFunc("DELETE /v1/workspaces/{id}", c.handleDeleteWorkspaces) + mux.HandleFunc("POST /v1/workspaces/{id}/current-session", c.handlePostWorkspaceCurrentSession) mux.HandleFunc("GET /v1/workspaces/{id}", c.handleGetWorkspace) mux.HandleFunc("GET /v1/workspaces/{id}/config", c.handleGetWorkspaceConfig) mux.HandleFunc("GET /v1/workspaces/{id}/events", c.handleGetWorkspaceEvents) @@ -172,10 +184,13 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server { Protocols: &p, Handler: s.recoverHandler(s.loggingHandler(mux)), } - if network == "tcp" { - s.h.Addr = address - } - return s +} + +// Handler returns the server's HTTP handler. Exposed so test harnesses +// can wrap it in an httptest.Server without going through the +// production listener setup. +func (s *Server) Handler() http.Handler { + return s.h.Handler } // Serve accepts incoming connections on the listener. diff --git a/internal/server/sessions_isbusy_test.go b/internal/server/sessions_isbusy_test.go new file mode 100644 index 0000000000000000000000000000000000000000..060c00abe9367dc7162bdb50dd77fe951041aa51 --- /dev/null +++ b/internal/server/sessions_isbusy_test.go @@ -0,0 +1,325 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/app" + "github.com/charmbracelet/crush/internal/backend" + "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/session" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// stubCoordinator is a minimal agent.Coordinator that only reports +// per-session busy state. Every other method returns a zero value so +// the type satisfies the interface without dragging in the full +// coordinator dependency graph. +type stubCoordinator struct { + busy map[string]bool +} + +func (s *stubCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) { + return nil, nil +} +func (s *stubCoordinator) Cancel(string) {} +func (s *stubCoordinator) CancelAll() {} +func (s *stubCoordinator) IsBusy() bool { return false } +func (s *stubCoordinator) IsSessionBusy(id string) bool { + return s.busy[id] +} +func (s *stubCoordinator) QueuedPrompts(string) int { return 0 } +func (s *stubCoordinator) QueuedPromptsList(string) []string { return nil } +func (s *stubCoordinator) ClearQueue(string) {} +func (s *stubCoordinator) Summarize(context.Context, string) error { + return nil +} +func (s *stubCoordinator) Model() agent.Model { return agent.Model{} } +func (s *stubCoordinator) UpdateModels(context.Context) error { return nil } + +// stubSessions is a minimal session.Service that returns a fixed list +// (and supports Get by ID). All other methods return zero values; the +// IsBusy tests do not exercise them. +type stubSessions struct { + session.Service // embed nil to inherit the unexported broker methods + all []session.Session +} + +func (s *stubSessions) List(context.Context) ([]session.Session, error) { + return s.all, nil +} + +func (s *stubSessions) Get(_ context.Context, id string) (session.Session, error) { + for _, sess := range s.all { + if sess.ID == id { + return sess, nil + } + } + return session.Session{}, errors.New("not found") +} + +// buildBusyWorkspace returns a controller wired to a backend that owns +// a single workspace whose AgentCoordinator reports the named session +// as busy. +func buildBusyWorkspace(t *testing.T, sessionID string, busy bool) (*controllerV1, string) { + t.Helper() + + b := backend.New(context.Background(), nil, nil) + wsID := uuid.New().String() + coord := &stubCoordinator{busy: map[string]bool{sessionID: busy}} + a := &app.App{AgentCoordinator: coord} + a.Sessions = &stubSessions{all: []session.Session{{ID: sessionID, Title: "t"}}} + + ws := &backend.Workspace{ + ID: wsID, + Path: t.TempDir(), + App: a, + } + backend.InsertWorkspaceForTest(b, ws) + + s := &Server{backend: b} + return &controllerV1{backend: b, server: s}, wsID +} + +func TestSessionListIncludesIsBusy(t *testing.T) { + t.Parallel() + const sid = "s-busy" + c, wsID := buildBusyWorkspace(t, sid, true) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/"+wsID+"/sessions", nil) + req.SetPathValue("id", wsID) + rec := httptest.NewRecorder() + c.handleGetWorkspaceSessions(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var got []proto.Session + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Len(t, got, 1) + require.Equal(t, sid, got[0].ID) + require.True(t, got[0].IsBusy, "expected IsBusy=true for the busy session") +} + +func TestSessionListIdleSessionIsNotBusy(t *testing.T) { + t.Parallel() + const sid = "s-idle" + c, wsID := buildBusyWorkspace(t, sid, false) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/"+wsID+"/sessions", nil) + req.SetPathValue("id", wsID) + rec := httptest.NewRecorder() + c.handleGetWorkspaceSessions(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var got []proto.Session + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Len(t, got, 1) + require.False(t, got[0].IsBusy, "expected IsBusy=false for idle session") +} + +func TestSessionGetIncludesIsBusy(t *testing.T) { + t.Parallel() + const sid = "s-busy" + c, wsID := buildBusyWorkspace(t, sid, true) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/"+wsID+"/sessions/"+sid, nil) + req.SetPathValue("id", wsID) + req.SetPathValue("sid", sid) + rec := httptest.NewRecorder() + c.handleGetWorkspaceSession(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var got proto.Session + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, sid, got.ID) + require.True(t, got.IsBusy) +} + +// TestIsSessionBusyNilSafe verifies the helper tolerates a missing +// workspace, app, or coordinator — phase A handlers rely on this so +// they can pass GetWorkspace's result through without an extra guard. +func TestIsSessionBusyNilSafe(t *testing.T) { + t.Parallel() + + require.False(t, isSessionBusy(nil, "x")) + require.False(t, isSessionBusy(&backend.Workspace{}, "x")) + require.False(t, isSessionBusy(&backend.Workspace{App: &app.App{}}, "x")) +} + +// TestProtoSessionIsBusyBackwardCompat verifies older consumers that +// unmarshal proto.Session without knowing about IsBusy still succeed +// and ignore the new field harmlessly. +func TestProtoSessionIsBusyBackwardCompat(t *testing.T) { + t.Parallel() + + wire := proto.Session{ID: "s1", Title: "t", IsBusy: true} + raw, err := json.Marshal(wire) + require.NoError(t, err) + + // Old client shape: same struct minus IsBusy. We model this by + // unmarshaling into a struct that doesn't declare the field. + type oldSession struct { + ID string `json:"id"` + Title string `json:"title"` + } + var old oldSession + require.NoError(t, json.Unmarshal(raw, &old)) + require.Equal(t, "s1", old.ID) + require.Equal(t, "t", old.Title) +} + +// buildMultiSessionWorkspace returns a controller wired to a backend +// that owns a workspace with the given session IDs. Used to exercise +// AttachedClients counts across more than one session. +func buildMultiSessionWorkspace(t *testing.T, sessionIDs ...string) (*controllerV1, *backend.Workspace) { + t.Helper() + + b := backend.New(context.Background(), nil, nil) + a := &app.App{AgentCoordinator: &stubCoordinator{}} + sessions := make([]session.Session, len(sessionIDs)) + for i, sid := range sessionIDs { + sessions[i] = session.Session{ID: sid, Title: sid} + } + a.Sessions = &stubSessions{all: sessions} + + ws := &backend.Workspace{ + ID: uuid.New().String(), + Path: t.TempDir(), + App: a, + } + backend.InsertWorkspaceForTest(b, ws) + // Synthetic workspaces have an incomplete App; bypass the + // default teardown to avoid panics when the last client detaches. + backend.SetWorkspaceShutdownFnForTest(ws, func() {}) + + s := &Server{backend: b} + return &controllerV1{backend: b, server: s}, ws +} + +// listSessions invokes handleGetWorkspaceSessions and returns the +// decoded response so tests can assert per-session counts. +func listSessions(t *testing.T, c *controllerV1, wsID string) []proto.Session { + t.Helper() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/"+wsID+"/sessions", nil) + req.SetPathValue("id", wsID) + rec := httptest.NewRecorder() + c.handleGetWorkspaceSessions(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + var got []proto.Session + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + return got +} + +func countsBySessionID(sessions []proto.Session) map[string]int { + out := make(map[string]int, len(sessions)) + for _, s := range sessions { + out[s.ID] = s.AttachedClients + } + return out +} + +// TestSessionListIncludesAttachedClients walks two sessions through +// the same lifecycle covered by TestAttachedClients_BasicLifecycle in +// the backend package, but observed at the handler boundary. +func TestSessionListIncludesAttachedClients(t *testing.T) { + t.Parallel() + c, ws := buildMultiSessionWorkspace(t, "S1", "S2") + + // No attached clients yet. + counts := countsBySessionID(listSessions(t, c, ws.ID)) + require.Equal(t, 0, counts["S1"]) + require.Equal(t, 0, counts["S2"]) + + // Attach A, set to S1: S1=1. + cidA := uuid.New().String() + require.NoError(t, c.backend.AttachClient(ws.ID, cidA)) + t.Cleanup(func() { c.backend.DetachClient(ws.ID, cidA) }) + require.NoError(t, c.backend.SetCurrentSession(ws.ID, cidA, "S1")) + counts = countsBySessionID(listSessions(t, c, ws.ID)) + require.Equal(t, 1, counts["S1"]) + require.Equal(t, 0, counts["S2"]) + + // Attach B, set to S1: S1=2. + cidB := uuid.New().String() + require.NoError(t, c.backend.AttachClient(ws.ID, cidB)) + require.NoError(t, c.backend.SetCurrentSession(ws.ID, cidB, "S1")) + counts = countsBySessionID(listSessions(t, c, ws.ID)) + require.Equal(t, 2, counts["S1"]) + require.Equal(t, 0, counts["S2"]) + + // B switches to S2: counts redistribute. + require.NoError(t, c.backend.SetCurrentSession(ws.ID, cidB, "S2")) + counts = countsBySessionID(listSessions(t, c, ws.ID)) + require.Equal(t, 1, counts["S1"]) + require.Equal(t, 1, counts["S2"]) + + // B detaches: S2 drops to 0. + c.backend.DetachClient(ws.ID, cidB) + counts = countsBySessionID(listSessions(t, c, ws.ID)) + require.Equal(t, 1, counts["S1"]) + require.Equal(t, 0, counts["S2"]) +} + +// TestSessionListExcludesHoldOnlyClient verifies that a registered +// client without an SSE stream (streams == 0) does not contribute to +// AttachedClients, even though it has an entry in the workspace's +// clients map. +func TestSessionListExcludesHoldOnlyClient(t *testing.T) { + t.Parallel() + c, ws := buildMultiSessionWorkspace(t, "S1") + + cid := uuid.New().String() + require.NoError(t, backend.RegisterClientForTesting(c.backend, ws, cid)) + t.Cleanup(func() { _ = c.backend.DeleteWorkspace(ws.ID, cid) }) + + counts := countsBySessionID(listSessions(t, c, ws.ID)) + require.Equal(t, 0, counts["S1"], "hold-only client must not be counted") +} + +// TestSessionListExcludesUnselectedAttachedClient verifies that a +// client with a live SSE stream but no current session +// (currentSessionID == "") does not show up under any session's count. +func TestSessionListExcludesUnselectedAttachedClient(t *testing.T) { + t.Parallel() + c, ws := buildMultiSessionWorkspace(t, "S1") + + cid := uuid.New().String() + require.NoError(t, c.backend.AttachClient(ws.ID, cid)) + t.Cleanup(func() { c.backend.DetachClient(ws.ID, cid) }) + // Intentionally do NOT call SetCurrentSession. + + counts := countsBySessionID(listSessions(t, c, ws.ID)) + require.Equal(t, 0, counts["S1"], + "attached client with no current session must not contribute to S1") +} + +// TestSessionGetIncludesAttachedClients verifies the single-session +// handler also populates AttachedClients. +func TestSessionGetIncludesAttachedClients(t *testing.T) { + t.Parallel() + c, ws := buildMultiSessionWorkspace(t, "S1") + + cid := uuid.New().String() + require.NoError(t, c.backend.AttachClient(ws.ID, cid)) + t.Cleanup(func() { c.backend.DetachClient(ws.ID, cid) }) + require.NoError(t, c.backend.SetCurrentSession(ws.ID, cid, "S1")) + + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, + "/v1/workspaces/"+ws.ID+"/sessions/S1", nil) + req.SetPathValue("id", ws.ID) + req.SetPathValue("sid", "S1") + rec := httptest.NewRecorder() + c.handleGetWorkspaceSession(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + var got proto.Session + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, 1, got.AttachedClients) +} diff --git a/internal/ui/dialog/permissions.go b/internal/ui/dialog/permissions.go index 4f158211514bfab5ac0ee4e857686b8105c359f3..4dfb9ba6655a863468a87de5cd1ce90faada6c70 100644 --- a/internal/ui/dialog/permissions.go +++ b/internal/ui/dialog/permissions.go @@ -224,6 +224,12 @@ func (*Permissions) ID() string { return PermissionsID } +// ToolCallID returns the tool call ID associated with this dialog's +// permission request. +func (p *Permissions) ToolCallID() string { + return p.permission.ToolCallID +} + // HandleMsg implements [Dialog]. func (p *Permissions) HandleMsg(msg tea.Msg) Action { switch msg := msg.(type) { diff --git a/internal/ui/model/permission_test.go b/internal/ui/model/permission_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c209e211223bf9e09f26eaeff9098b54206456a2 --- /dev/null +++ b/internal/ui/model/permission_test.go @@ -0,0 +1,96 @@ +package model + +import ( + "testing" + + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/ui/dialog" + "github.com/stretchr/testify/require" +) + +// newTestUIForPermissions builds a UI with a chat, dialog overlay, and +// common context sufficient to exercise handlePermissionNotification. +func newTestUIForPermissions() *UI { + u := newTestUI() + u.dialog = dialog.NewOverlay() + return u +} + +func TestHandlePermissionNotification_RemoteGrantClosesDialog(t *testing.T) { + t.Parallel() + + u := newTestUIForPermissions() + perm := permission.PermissionRequest{ + ID: "perm-1", + ToolCallID: "tool-call-X", + ToolName: "bash", + } + u.dialog.OpenDialog(dialog.NewPermissions(u.com, perm)) + require.True(t, u.dialog.ContainsDialog(dialog.PermissionsID)) + + u.handlePermissionNotification(permission.PermissionNotification{ + ToolCallID: "tool-call-X", + Granted: true, + }) + + require.False(t, u.dialog.ContainsDialog(dialog.PermissionsID), + "granted notification should close matching permissions dialog") +} + +func TestHandlePermissionNotification_RemoteDenyClosesDialog(t *testing.T) { + t.Parallel() + + u := newTestUIForPermissions() + perm := permission.PermissionRequest{ + ID: "perm-2", + ToolCallID: "tool-call-Y", + } + u.dialog.OpenDialog(dialog.NewPermissions(u.com, perm)) + + u.handlePermissionNotification(permission.PermissionNotification{ + ToolCallID: "tool-call-Y", + Denied: true, + }) + + require.False(t, u.dialog.ContainsDialog(dialog.PermissionsID), + "denied notification should close matching permissions dialog") +} + +func TestHandlePermissionNotification_InitialPendingDoesNotClose(t *testing.T) { + t.Parallel() + + u := newTestUIForPermissions() + perm := permission.PermissionRequest{ + ID: "perm-3", + ToolCallID: "tool-call-Z", + } + u.dialog.OpenDialog(dialog.NewPermissions(u.com, perm)) + + // The initial notification published by permission.Request is + // neither granted nor denied; it must not dismiss the dialog. + u.handlePermissionNotification(permission.PermissionNotification{ + ToolCallID: "tool-call-Z", + }) + + require.True(t, u.dialog.ContainsDialog(dialog.PermissionsID), + "initial pending notification must not close the dialog") +} + +func TestHandlePermissionNotification_DifferentToolCallIDDoesNotClose(t *testing.T) { + t.Parallel() + + u := newTestUIForPermissions() + perm := permission.PermissionRequest{ + ID: "perm-4", + ToolCallID: "tool-call-A", + } + u.dialog.OpenDialog(dialog.NewPermissions(u.com, perm)) + + u.handlePermissionNotification(permission.PermissionNotification{ + ToolCallID: "tool-call-B", + Granted: true, + }) + + require.True(t, u.dialog.ContainsDialog(dialog.PermissionsID), + "notification for unrelated tool call must not close the dialog") +} diff --git a/internal/ui/model/session.go b/internal/ui/model/session.go index 17172d87f9f7f46d63768512055604bce8adf262..aa31009b89ac6f7d14480fb1e607560021f88a3c 100644 --- a/internal/ui/model/session.go +++ b/internal/ui/model/session.go @@ -64,8 +64,13 @@ type SessionFile struct { // the diff statistics (additions and deletions) for each file in the session. // It returns a tea.Cmd that, when executed, fetches the session data and // returns a sessionFilesLoadedMsg containing the processed session files. +// +// The returned batch also reports the new current-session selection to +// the workspace so the server can update its per-client presence map. +// That report is fire-and-forget: errors are logged at debug and the +// UI never blocks on the call. func (m *UI) loadSession(sessionID string) tea.Cmd { - return func() tea.Msg { + load := func() tea.Msg { session, err := m.com.Workspace.GetSession(context.Background(), sessionID) if err != nil { return util.ReportError(err) @@ -87,6 +92,21 @@ func (m *UI) loadSession(sessionID string) tea.Cmd { readFiles: readFiles, } } + return tea.Batch(load, m.reportCurrentSession(sessionID)) +} + +// reportCurrentSession returns a fire-and-forget tea.Cmd that +// informs the workspace which session this client is currently +// viewing. Errors are logged at debug only; the call is a hint +// for server-side presence tracking, not correctness-critical +// state. +func (m *UI) reportCurrentSession(sessionID string) tea.Cmd { + return func() tea.Msg { + if err := m.com.Workspace.SetCurrentSession(context.Background(), sessionID); err != nil { + slog.Debug("Failed to report current session", "session_id", sessionID, "error", err) + } + return nil + } } func (m *UI) loadSessionFiles(sessionID string) ([]SessionFile, error) { diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index a0883082efc653e5472c5cfaf79a8c387ef9c29e..890dfc7de8a97eae13c4ecbd56ca07b566061408 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -3523,16 +3523,25 @@ func (m *UI) openPermissionsDialog(perm permission.PermissionRequest) tea.Cmd { // handlePermissionNotification updates tool items when permission state changes. func (m *UI) handlePermissionNotification(notification permission.PermissionNotification) { - toolItem := m.chat.MessageItem(notification.ToolCallID) - if toolItem == nil { - return + if toolItem := m.chat.MessageItem(notification.ToolCallID); toolItem != nil { + if permItem, ok := toolItem.(chat.ToolMessageItem); ok { + if notification.Granted { + permItem.SetStatus(chat.ToolStatusRunning) + } else { + permItem.SetStatus(chat.ToolStatusAwaitingPermission) + } + } } - if permItem, ok := toolItem.(chat.ToolMessageItem); ok { - if notification.Granted { - permItem.SetStatus(chat.ToolStatusRunning) - } else { - permItem.SetStatus(chat.ToolStatusAwaitingPermission) + // If this notification reflects a final resolution (granted or denied), + // dismiss any open permissions dialog whose tool call ID matches. This + // covers the case where another client resolved the request remotely. + if !notification.Granted && !notification.Denied { + return + } + if d := m.dialog.Dialog(dialog.PermissionsID); d != nil { + if perm, ok := d.(*dialog.Permissions); ok && perm.ToolCallID() == notification.ToolCallID { + m.dialog.CloseDialog(dialog.PermissionsID) } } } @@ -3601,6 +3610,7 @@ func (m *UI) newSession() tea.Cmd { return nil }, m.loadPromptHistory(), + m.reportCurrentSession(""), ) } diff --git a/internal/workspace/app_workspace.go b/internal/workspace/app_workspace.go index 17d5903793d3a42b38014d122be0c8d11216d803..c35a9f59fe2cb6b20ab74f21d45649046130b8b1 100644 --- a/internal/workspace/app_workspace.go +++ b/internal/workspace/app_workspace.go @@ -68,6 +68,13 @@ func (w *AppWorkspace) ParseAgentToolSessionID(sessionID string) (string, string return w.app.Sessions.ParseAgentToolSessionID(sessionID) } +// SetCurrentSession is a no-op in single-client local mode. The +// presence concept only matters when multiple clients can share a +// workspace via the HTTP server. +func (w *AppWorkspace) SetCurrentSession(ctx context.Context, sessionID string) error { + return nil +} + // -- Messages -- func (w *AppWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { @@ -174,16 +181,16 @@ func (w *AppWorkspace) GetDefaultSmallModel(providerID string) config.SelectedMo // -- Permissions -- -func (w *AppWorkspace) PermissionGrant(perm permission.PermissionRequest) { - w.app.Permissions.Grant(perm) +func (w *AppWorkspace) PermissionGrant(perm permission.PermissionRequest) bool { + return w.app.Permissions.Grant(perm) } -func (w *AppWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) { - w.app.Permissions.GrantPersistent(perm) +func (w *AppWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) bool { + return w.app.Permissions.GrantPersistent(perm) } -func (w *AppWorkspace) PermissionDeny(perm permission.PermissionRequest) { - w.app.Permissions.Deny(perm) +func (w *AppWorkspace) PermissionDeny(perm permission.PermissionRequest) bool { + return w.app.Permissions.Deny(perm) } func (w *AppWorkspace) PermissionSkipRequests() bool { diff --git a/internal/workspace/client_workspace.go b/internal/workspace/client_workspace.go index a959cb3842b891000195f7b51dcd0e2a7e0b240e..f4bd4ba35a6eb26db98420215f2a6282ebac0f9f 100644 --- a/internal/workspace/client_workspace.go +++ b/internal/workspace/client_workspace.go @@ -63,7 +63,7 @@ func NewClientWorkspace(c *client.Client, ws proto.Workspace) *ClientWorkspace { // refreshWorkspace re-fetches the workspace from the server, updating // the cached snapshot. Called after config-mutating operations. func (w *ClientWorkspace) refreshWorkspace() { - updated, err := w.client.GetWorkspace(context.Background(), w.ws.ID) + updated, err := w.client.GetWorkspace(context.Background(), w.workspaceID()) if err != nil { slog.Error("Failed to refresh workspace", "error", err) return @@ -142,6 +142,14 @@ func (w *ClientWorkspace) ParseAgentToolSessionID(sessionID string) (string, str return parts[0], parts[1], true } +// SetCurrentSession reports the session this client is currently +// viewing to the server. Empty sessionID clears the entry. Errors +// are propagated to the caller; the TUI logs and ignores them since +// the presence record is a hint, not correctness-critical state. +func (w *ClientWorkspace) SetCurrentSession(ctx context.Context, sessionID string) error { + return w.client.SetCurrentSession(ctx, w.workspaceID(), sessionID) +} + // -- Messages -- func (w *ClientWorkspace) ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) { @@ -255,8 +263,8 @@ func (w *ClientWorkspace) GetDefaultSmallModel(providerID string) config.Selecte // -- Permissions -- -func (w *ClientWorkspace) PermissionGrant(perm permission.PermissionRequest) { - _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ +func (w *ClientWorkspace) PermissionGrant(perm permission.PermissionRequest) bool { + resolved, _ := w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ Permission: proto.PermissionRequest{ ID: perm.ID, SessionID: perm.SessionID, @@ -267,12 +275,13 @@ func (w *ClientWorkspace) PermissionGrant(perm permission.PermissionRequest) { Path: perm.Path, Params: perm.Params, }, - Action: proto.PermissionAllowForSession, + Action: proto.PermissionAllow, }) + return resolved } -func (w *ClientWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) { - _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ +func (w *ClientWorkspace) PermissionGrantPersistent(perm permission.PermissionRequest) bool { + resolved, _ := w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ Permission: proto.PermissionRequest{ ID: perm.ID, SessionID: perm.SessionID, @@ -283,12 +292,13 @@ func (w *ClientWorkspace) PermissionGrantPersistent(perm permission.PermissionRe Path: perm.Path, Params: perm.Params, }, - Action: proto.PermissionAllow, + Action: proto.PermissionAllowForSession, }) + return resolved } -func (w *ClientWorkspace) PermissionDeny(perm permission.PermissionRequest) { - _ = w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ +func (w *ClientWorkspace) PermissionDeny(perm permission.PermissionRequest) bool { + resolved, _ := w.client.GrantPermission(context.Background(), w.workspaceID(), proto.PermissionGrant{ Permission: proto.PermissionRequest{ ID: perm.ID, SessionID: perm.SessionID, @@ -301,6 +311,7 @@ func (w *ClientWorkspace) PermissionDeny(perm permission.PermissionRequest) { }, Action: proto.PermissionDeny, }) + return resolved } func (w *ClientWorkspace) PermissionSkipRequests() bool { @@ -593,10 +604,22 @@ func (w *ClientWorkspace) Subscribe(program *tea.Program) { return } + w.consumeEvents(evc, program.Send) +} + +// consumeEvents drives the workspace event loop. It is split out from +// Subscribe so tests can drive it without a real *tea.Program. +// ConfigChanged events trigger a workspace refresh; all other events +// are translated into domain types and forwarded to send. +func (w *ClientWorkspace) consumeEvents(evc <-chan any, send func(tea.Msg)) { for ev := range evc { + if _, ok := ev.(pubsub.Event[proto.ConfigChanged]); ok { + w.refreshWorkspace() + continue + } translated := w.translateEvent(ev) - if translated != nil { - program.Send(translated) + if translated != nil && send != nil { + send(translated) } } } @@ -714,6 +737,13 @@ func protoToMCPEventType(t proto.MCPEventType) mcp.EventType { } } +// protoToSession converts a wire-level proto.Session into the domain +// session.Session. Fields that exist only on the wire (computed-on-read +// signals like IsBusy, and any future presence counters) are +// intentionally dropped here: session.Session models persisted state, +// not transient runtime signals. UI features that need those signals +// should either extend session.Session or read them from the proto +// payload directly before this conversion runs. func protoToSession(s proto.Session) session.Session { return session.Session{ ID: s.ID, diff --git a/internal/workspace/client_workspace_test.go b/internal/workspace/client_workspace_test.go index 0e51d21b97df4502147b0aca0e1ca0477b196640..d88100457ed006fa537e66592fb04a6c190ad214 100644 --- a/internal/workspace/client_workspace_test.go +++ b/internal/workspace/client_workspace_test.go @@ -1,9 +1,16 @@ package workspace import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" "testing" + "github.com/charmbracelet/crush/internal/client" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/skills" @@ -47,6 +54,87 @@ func TestProtoToMessageToolResult(t *testing.T) { require.False(t, tr.IsError) } +// TestClientWorkspace_PermissionGrantMapping verifies that +// PermissionGrant on the ClientWorkspace serializes a one-time grant +// (proto.PermissionAllow) and PermissionGrantPersistent serializes a +// persistent grant (proto.PermissionAllowForSession). A swap between +// these two would silently flip "allow once" into "remember for the +// session", and vice versa, so we pin the wire mapping here. +func TestClientWorkspace_PermissionGrantMapping(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + call func(*ClientWorkspace, permission.PermissionRequest) + want proto.PermissionAction + }{ + { + name: "Grant -> PermissionAllow", + call: func(w *ClientWorkspace, p permission.PermissionRequest) { + w.PermissionGrant(p) + }, + want: proto.PermissionAllow, + }, + { + name: "GrantPersistent -> PermissionAllowForSession", + call: func(w *ClientWorkspace, p permission.PermissionRequest) { + w.PermissionGrantPersistent(p) + }, + want: proto.PermissionAllowForSession, + }, + { + name: "Deny -> PermissionDeny", + call: func(w *ClientWorkspace, p permission.PermissionRequest) { + w.PermissionDeny(p) + }, + want: proto.PermissionDeny, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var got proto.PermissionGrant + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/v1/workspaces/ws-1/permissions/grant", r.URL.Path) + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(body, &got)) + require.NoError(t, json.NewEncoder(w).Encode(proto.PermissionGrantResponse{Resolved: true})) + })) + defer srv.Close() + + u, err := url.Parse(srv.URL) + require.NoError(t, err) + c, err := client.NewClient(t.TempDir(), "tcp", u.Host) + require.NoError(t, err) + + ws := NewClientWorkspace(c, proto.Workspace{ID: "ws-1"}) + + perm := permission.PermissionRequest{ + ID: "req-1", + SessionID: "sess-1", + ToolCallID: "tc-1", + ToolName: "tool", + Description: "do thing", + Action: "act", + Path: "/tmp/p", + } + tc.call(ws, perm) + + require.Equal(t, tc.want, got.Action) + require.Equal(t, "req-1", got.Permission.ID) + require.Equal(t, "sess-1", got.Permission.SessionID) + require.Equal(t, "tc-1", got.Permission.ToolCallID) + require.Equal(t, "tool", got.Permission.ToolName) + require.Equal(t, "act", got.Permission.Action) + require.Equal(t, "/tmp/p", got.Permission.Path) + }) + } +} + // TestProtoToSkillStates verifies that the wire representation of skill // discovery states reconstructs identical values on the client, // including synthetic errors derived from Error strings. diff --git a/internal/workspace/export_test.go b/internal/workspace/export_test.go new file mode 100644 index 0000000000000000000000000000000000000000..0395020899f91043d94454b42c3c92587fb4e506 --- /dev/null +++ b/internal/workspace/export_test.go @@ -0,0 +1,14 @@ +package workspace + +import ( + tea "charm.land/bubbletea/v2" +) + +// ConsumeEventsForTest runs the event-handling loop on the given +// channel, invoking send for translated domain messages and refreshing +// the cached workspace snapshot on ConfigChanged. Exposed for +// cross-package integration tests that cannot rely on a real +// *tea.Program. It returns when evc is closed. +func (w *ClientWorkspace) ConsumeEventsForTest(evc <-chan any, send func(tea.Msg)) { + w.consumeEvents(evc, send) +} diff --git a/internal/workspace/multiclient_integration_test.go b/internal/workspace/multiclient_integration_test.go new file mode 100644 index 0000000000000000000000000000000000000000..98f1603f519a5295f061a09023031848c73eb13b --- /dev/null +++ b/internal/workspace/multiclient_integration_test.go @@ -0,0 +1,176 @@ +package workspace_test + +import ( + "context" + "net/http/httptest" + "net/url" + "testing" + "time" + + tea "charm.land/bubbletea/v2" + "github.com/charmbracelet/crush/internal/client" + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/server" + "github.com/charmbracelet/crush/internal/workspace" + "github.com/stretchr/testify/require" +) + +// xdgIsolate redirects HOME and XDG_* to fresh temp dirs so config +// loading does not touch the host's real config. +func xdgIsolate(t *testing.T) { + t.Helper() + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) +} + +// runtimeServer wires the production server handler around an +// httptest.NewServer for integration testing. +type runtimeServer struct { + httpSrv *httptest.Server + host string +} + +func newRuntimeServer(t *testing.T) *runtimeServer { + t.Helper() + s := server.NewServer(nil, "tcp", "127.0.0.1:0") + hs := httptest.NewServer(s.Handler()) + t.Cleanup(hs.Close) + + u, err := url.Parse(hs.URL) + require.NoError(t, err) + return &runtimeServer{httpSrv: hs, host: u.Host} +} + +func (r *runtimeServer) newClient(t *testing.T, path string) *client.Client { + t.Helper() + c, err := client.NewClient(path, "tcp", r.host) + require.NoError(t, err) + return c +} + +// TestClientWorkspace_ConfigChangedRefreshesSiblingCache is the +// cross-client refresh end-to-end test required by PLAN item 4. Two +// ClientWorkspace instances pointed at the same backend workspace +// subscribe to events; when one mutates configuration via the server, +// the other's cached Config snapshot reflects the new value without +// a manual refresh. +func TestClientWorkspace_ConfigChangedRefreshesSiblingCache(t *testing.T) { + xdgIsolate(t) + rt := newRuntimeServer(t) + + cwd := t.TempDir() + dataDir := t.TempDir() + + cA := rt.newClient(t, cwd) + cB := rt.newClient(t, cwd) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + wsProto, err := cA.CreateWorkspace(ctx, proto.Workspace{Path: cwd, DataDir: dataDir}) + require.NoError(t, err) + // Client B joins the same workspace by path; the server + // deduplicates and returns the existing workspace. + wsProtoB, err := cB.CreateWorkspace(ctx, proto.Workspace{Path: cwd, DataDir: dataDir}) + require.NoError(t, err) + require.Equal(t, wsProto.ID, wsProtoB.ID) + + wsA := workspace.NewClientWorkspace(cA, *wsProto) + wsB := workspace.NewClientWorkspace(cB, *wsProtoB) + + // Both clients attach event streams. They run for the + // lifetime of the test; cancelling via context tears them + // down. consumeEvents is exercised by Subscribe in production; + // here we run it inline so we don't need a real *tea.Program. + evcA, err := cA.SubscribeEvents(ctx, wsProto.ID) + require.NoError(t, err) + evcB, err := cB.SubscribeEvents(ctx, wsProto.ID) + require.NoError(t, err) + + go wsA.ConsumeEventsForTest(evcA, func(tea.Msg) {}) + go wsB.ConsumeEventsForTest(evcB, func(tea.Msg) {}) + + // Pre-condition: neither cache has compact mode enabled yet. + require.NotNil(t, wsA.Config()) + require.NotNil(t, wsB.Config()) + require.False(t, compactMode(wsA.Config()), "compact mode must start disabled on client A") + require.False(t, compactMode(wsB.Config()), "compact mode must start disabled on client B") + + // Client A flips a real config-mutating workspace operation + // (SetCompactMode) via the server. PLAN item 4 acceptance: + // B's cached ws.Config must reflect this change without restart. + // SetCompactMode is used over UpdatePreferredModel because the + // latter's autoReload reverts unknown-provider models back to + // defaults during configureSelectedModels, which would make the + // assertion test infrastructure rather than the cache wiring. + require.NoError(t, wsA.SetCompactMode(config.ScopeGlobal, true)) + + // Client A writes and refreshes synchronously inside + // SetCompactMode, so its cache must already reflect the change. + // Eventually here absorbs any background work but should pass + // immediately. + require.Eventually(t, func() bool { return compactMode(wsA.Config()) }, + 3*time.Second, 25*time.Millisecond, + "client A cache must reflect its own compact-mode mutation") + + // Client B must see the same change via the ConfigChanged SSE + // event triggering its own cached refresh. + require.Eventually(t, func() bool { return compactMode(wsB.Config()) }, + 3*time.Second, 25*time.Millisecond, + "client B cache must reflect A's compact-mode mutation via SSE") +} + +// compactMode is a tiny accessor that survives nil intermediates so +// the Eventually polling loop can call it on a transient cache state. +func compactMode(cfg *config.Config) bool { + if cfg == nil || cfg.Options == nil { + return false + } + return cfg.Options.TUI.CompactMode +} + +// TestClientWorkspace_ConfigChangedSignalArrives is a smaller test +// that asserts the SSE wiring delivers a ConfigChanged event to the +// raw client subscription. It catches breakage in the +// wrapEvent/decoder bridge independent of the workspace cache. +func TestClientWorkspace_ConfigChangedSignalArrives(t *testing.T) { + xdgIsolate(t) + rt := newRuntimeServer(t) + + cwd := t.TempDir() + dataDir := t.TempDir() + + c := rt.newClient(t, cwd) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + wsProto, err := c.CreateWorkspace(ctx, proto.Workspace{Path: cwd, DataDir: dataDir}) + require.NoError(t, err) + + evc, err := c.SubscribeEvents(ctx, wsProto.ID) + require.NoError(t, err) + + require.NoError(t, c.SetConfigField(ctx, wsProto.ID, config.ScopeGlobal, "options.debug", true)) + + gotConfigChanged := false + deadline := time.After(3 * time.Second) +loop: + for !gotConfigChanged { + select { + case ev, ok := <-evc: + if !ok { + break loop + } + if cc, isCC := ev.(pubsub.Event[proto.ConfigChanged]); isCC { + require.Equal(t, wsProto.ID, cc.Payload.WorkspaceID) + gotConfigChanged = true + } + case <-deadline: + break loop + } + } + require.True(t, gotConfigChanged, "expected ConfigChanged event over SSE") +} diff --git a/internal/workspace/workspace.go b/internal/workspace/workspace.go index 0434ba21512ea9b732d6b94edb3015a1cd26a1e6..9049b7bc682836345fc37c8e8efc3588cf4d0e06 100644 --- a/internal/workspace/workspace.go +++ b/internal/workspace/workspace.go @@ -68,6 +68,12 @@ type Workspace interface { DeleteSession(ctx context.Context, sessionID string) error CreateAgentToolSessionID(messageID, toolCallID string) string ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) + // SetCurrentSession reports the session this client is currently + // viewing. Empty sessionID clears the entry (e.g. landing screen). + // In single-client local mode this is a no-op. In client/server + // mode it informs the server's per-client presence map so other + // observers can compute attached-client counts per session. + SetCurrentSession(ctx context.Context, sessionID string) error // Messages ListMessages(ctx context.Context, sessionID string) ([]message.Message, error) @@ -90,9 +96,17 @@ type Workspace interface { GetDefaultSmallModel(providerID string) config.SelectedModel // Permissions - PermissionGrant(perm permission.PermissionRequest) - PermissionGrantPersistent(perm permission.PermissionRequest) - PermissionDeny(perm permission.PermissionRequest) + // + // PermissionGrant, PermissionGrantPersistent, and PermissionDeny + // return true if the call resolved the pending request and false if + // it had already been resolved by another subscriber (or is no + // longer pending). A false return is not an error; the modal can + // still close locally because the resolution will arrive via the + // PermissionNotification event stream regardless of which client + // won the race. + PermissionGrant(perm permission.PermissionRequest) bool + PermissionGrantPersistent(perm permission.PermissionRequest) bool + PermissionDeny(perm permission.PermissionRequest) bool PermissionSkipRequests() bool PermissionSetSkipRequests(skip bool)