From 1ce40dfc00b9394a8212ef008f05e7fd7efc8fba Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Tue, 19 May 2026 22:23:33 -0400 Subject: [PATCH] feat(server): share one workspace per directory across clients Multiple Crush clients connecting to the same server with the same working directory now share a single underlying workspace. Conflicting startup flags follow a first wins rule. Workspace lifetime is tied to live event streams plus a short grace window after creation, so a workspace stays alive as long as any client is attached and is torn down only after the last one disconnects. Co-Authored-By: Charm Crush --- internal/backend/backend.go | 445 ++++++++++++- internal/backend/backend_test.go | 953 ++++++++++++++++++++++++++++ internal/client/client.go | 17 +- internal/client/proto.go | 7 +- internal/proto/proto.go | 17 +- internal/server/multiclient_test.go | 107 ++++ internal/server/proto.go | 38 +- 7 files changed, 1538 insertions(+), 46 deletions(-) create mode 100644 internal/backend/backend_test.go create mode 100644 internal/server/multiclient_test.go diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 642b48a9222de132ffd24f9c356d4b7152a38591..4d377ac4d983c076ff86a970250c20b1d7adbe4b 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -8,7 +8,10 @@ import ( "errors" "fmt" "log/slog" + "path/filepath" "runtime" + "sync" + "time" "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" @@ -28,19 +31,59 @@ var ( ErrPathRequired = errors.New("path is required") ErrInvalidPermissionAction = errors.New("invalid permission action") ErrUnknownCommand = errors.New("unknown command") + ErrInvalidClientID = errors.New("invalid client_id") ) +// DefaultCreateGrace is the window in which a client must open an SSE +// stream after creating a workspace before its creation hold is +// released. Exposed as a package variable so tests can shorten it. 
+var DefaultCreateGrace = 30 * time.Second + // ShutdownFunc is called when the backend needs to trigger a server // shutdown (e.g. when the last workspace is removed). type ShutdownFunc func() // Backend provides transport-agnostic business logic for the Crush // server. It manages workspaces and delegates to [app.App] services. +// +// Locking order: when both [Backend.mu] and [Workspace.clientsMu] are +// held at once, [Backend.mu] is acquired first. Detach paths +// ([detachStream], [releaseHoldLocked], [expireHold]) only hold +// [Workspace.clientsMu] briefly, drop it, then call [teardown] which +// takes [Backend.mu] (and then re-takes [Workspace.clientsMu] to +// re-check that the workspace has not been re-claimed). This avoids +// the AB/BA hazard with [CreateWorkspace], which holds [Backend.mu] +// while calling [registerClient] so that a workspace cannot be torn +// down beneath it. type Backend struct { workspaces *csync.Map[string, *Workspace] - cfg *config.ConfigStore - ctx context.Context - shutdownFn ShutdownFunc + // pathIndex maps a resolved absolute workspace path to its + // workspace ID. Reads and writes are serialised via mu so + // concurrent CreateWorkspace calls at the same path deduplicate + // deterministically. + pathIndex map[string]string + mu sync.Mutex + + cfg *config.ConfigStore + ctx context.Context + shutdownFn ShutdownFunc + createGrace time.Duration +} + +// clientState tracks one client's claim on a workspace. +// +// - streams counts the number of live SSE event streams the client +// currently has open against the workspace. +// - holdTimer is non-nil iff the client created the workspace but has +// not yet attached an SSE stream; it fires after createGrace and +// releases the hold. +// +// The two are mutually exclusive in practice (the hold timer is stopped +// the moment an SSE stream attaches), but both being zero/nil means the +// entry has been released and should be removed. 
+type clientState struct { + streams int + holdTimer *time.Timer } // Workspace represents a running [app.App] workspace with its @@ -51,18 +94,57 @@ type Workspace struct { Path string Cfg *config.ConfigStore Env []string + + // resolvedPath is the path used as the dedup key in + // Backend.pathIndex. It is filepath.EvalSymlinks(filepath.Abs(Path)) + // with fallback to the cleaned absolute path. + resolvedPath string + + // clientsMu guards clients. It is held only briefly (no IO). + clientsMu sync.Mutex + // clients tracks each client's claim on this workspace. Refcount + // is a derived value: len(clients). + clients map[string]*clientState + + // shutdownFn is the function invoked by [Backend.teardown] to + // release the workspace's underlying resources. It defaults to the + // embedded [app.App.Shutdown]; tests may override it to avoid + // driving a full [app.App] through shutdown. + shutdownFn func() +} + +// invokeShutdown calls the workspace shutdown hook if set, falling +// back to the embedded [app.App.Shutdown] when not. +func (w *Workspace) invokeShutdown() { + if w.shutdownFn != nil { + w.shutdownFn() + return + } + if w.App != nil { + w.Shutdown() + } } // New creates a new [Backend]. func New(ctx context.Context, cfg *config.ConfigStore, shutdownFn ShutdownFunc) *Backend { return &Backend{ - workspaces: csync.NewMap[string, *Workspace](), - cfg: cfg, - ctx: ctx, - shutdownFn: shutdownFn, + workspaces: csync.NewMap[string, *Workspace](), + pathIndex: make(map[string]string), + cfg: cfg, + ctx: ctx, + shutdownFn: shutdownFn, + createGrace: DefaultCreateGrace, } } +// SetCreateGrace overrides the create-grace window. Intended for tests +// that need short timeouts. +func (b *Backend) SetCreateGrace(d time.Duration) { + b.mu.Lock() + defer b.mu.Unlock() + b.createGrace = d +} + // GetWorkspace retrieves a workspace by ID. 
func (b *Backend) GetWorkspace(id string) (*Workspace, error) { ws, ok := b.workspaces.Get(id) @@ -82,12 +164,46 @@ func (b *Backend) ListWorkspaces() []proto.Workspace { } // CreateWorkspace initializes a new workspace from the given -// parameters. It creates the config, database connection, and -// [app.App] instance. +// parameters, or returns an existing workspace if one already exists at +// the same resolved path (first-wins semantics). +// +// args.ClientID must be a valid UUID identifying the calling client; +// the resulting workspace registers a creation hold on behalf of that +// client which is released either by the first SSE attach (which +// converts it into a stream claim) or by the grace window expiring. func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Workspace, error) { if args.Path == "" { return nil, proto.Workspace{}, ErrPathRequired } + clientID, err := validateClientID(args.ClientID) + if err != nil { + return nil, proto.Workspace{}, err + } + + key, err := resolveWorkspaceKey(args.Path) + if err != nil { + return nil, proto.Workspace{}, fmt.Errorf("failed to resolve workspace path: %w", err) + } + + b.mu.Lock() + if existingID, ok := b.pathIndex[key]; ok { + if ws, found := b.workspaces.Get(existingID); found { + // Hold b.mu while registering: teardown also + // acquires b.mu before tearing the workspace + // down, so this guarantees the workspace we + // return cannot be torn out from under us + // between lookup and registerClient. Lock order + // here is b.mu -> ws.clientsMu. + logFirstWinsMismatch(ws, args) + b.registerClient(ws, clientID) + b.mu.Unlock() + return ws, workspaceToProto(ws), nil + } + // pathIndex referenced a workspace that has since been + // removed; clean the stale entry and fall through. 
+ delete(b.pathIndex, key) + } + b.mu.Unlock() id := uuid.New().String() cfg, err := config.Init(args.Path, args.DataDir, args.Debug) @@ -112,14 +228,38 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works } ws := &Workspace{ - App: appWorkspace, - ID: id, - Path: args.Path, - Cfg: cfg, - Env: args.Env, + App: appWorkspace, + ID: id, + Path: args.Path, + Cfg: cfg, + Env: args.Env, + resolvedPath: key, + clients: make(map[string]*clientState), } + b.mu.Lock() + // Re-check the index under the lock: a concurrent caller may have + // won the race between the initial unlock and here. + if existingID, ok := b.pathIndex[key]; ok { + if existing, found := b.workspaces.Get(existingID); found { + // Register under b.mu so teardown cannot run + // between lookup and registerClient. Lock order + // is b.mu -> ws.clientsMu. + logFirstWinsMismatch(existing, args) + b.registerClient(existing, clientID) + b.mu.Unlock() + ws.invokeShutdown() + return existing, workspaceToProto(existing), nil + } + delete(b.pathIndex, key) + } b.workspaces.Set(id, ws) + b.pathIndex[key] = id + // Register the originating client's hold while still holding + // b.mu so the workspace is observable with its claim from the + // moment it appears in the index. + b.registerClient(ws, clientID) + b.mu.Unlock() if args.Version != "" && args.Version != version.Version { slog.Warn( @@ -133,34 +273,201 @@ func (b *Backend) CreateWorkspace(args proto.Workspace) (*Workspace, proto.Works ))) } - result := proto.Workspace{ - ID: id, - Path: args.Path, - DataDir: cfg.Config().Options.DataDirectory, - Debug: cfg.Config().Options.Debug, - YOLO: cfg.Overrides().SkipPermissionRequests, - Config: cfg.Config(), - Env: args.Env, + return ws, workspaceToProto(ws), nil +} + +// AttachClient registers a new SSE stream for the given client on the +// workspace. The stream's deferred cleanup must call DetachClient with +// the same arguments to release the claim. 
+// +// The lookup and the clients-map mutation are performed under +// [Backend.mu] so that AttachClient cannot race with [Backend.teardown]: +// teardown also holds [Backend.mu] while removing the workspace from +// b.workspaces, so once AttachClient observes the workspace and takes +// ws.clientsMu (under b.mu), no concurrent teardown can succeed without +// re-checking the (now non-empty) clients map. Lock order is the +// canonical b.mu -> ws.clientsMu. +func (b *Backend) AttachClient(workspaceID, clientID string) error { + if _, err := validateClientID(clientID); err != nil { + return err + } + + b.mu.Lock() + defer b.mu.Unlock() + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return ErrWorkspaceNotFound } - return ws, result, nil + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + cs, ok := ws.clients[clientID] + if !ok { + // Defensive: SSE attach without a prior CreateWorkspace by + // this client still installs a stream claim so the stream + // stays alive for its duration. + ws.clients[clientID] = &clientState{streams: 1} + return nil + } + if cs.holdTimer != nil { + cs.holdTimer.Stop() + cs.holdTimer = nil + } + cs.streams++ + return nil } -// DeleteWorkspace shuts down and removes a workspace. If it was the -// last workspace, the shutdown callback is invoked. -func (b *Backend) DeleteWorkspace(id string) { - ws, ok := b.workspaces.Get(id) - if ok { - ws.Shutdown() +// DetachClient releases one SSE stream's hold on the workspace. If the +// client has no other streams and no pending creation hold, its claim +// is removed and the workspace is torn down once refcount hits zero. +func (b *Backend) DetachClient(workspaceID, clientID string) { + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return + } + b.detachStream(ws, clientID) +} + +// releaseHold releases the creation hold for a client, if any. Active +// stream claims are unaffected. Idempotent: returns nil if the +// workspace or the client's hold no longer exist. 
+func (b *Backend) releaseHold(workspaceID, clientID string) error { + if _, err := validateClientID(clientID); err != nil { + return err + } + ws, ok := b.workspaces.Get(workspaceID) + if !ok { + return nil + } + b.releaseHoldLocked(ws, clientID) + return nil +} + +// registerClient installs (idempotently) the given client's claim on +// the workspace and starts a grace timer if the entry is fresh. +func (b *Backend) registerClient(ws *Workspace, clientID string) { + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + if _, ok := ws.clients[clientID]; ok { + // Idempotent: a duplicate CreateWorkspace from the same + // client does not add a second claim. + return + } + cs := &clientState{} + cs.holdTimer = time.AfterFunc(b.createGrace, func() { + b.expireHold(ws, clientID, cs) + }) + ws.clients[clientID] = cs +} + +// expireHold is the body of the grace timer. It runs in its own +// goroutine and races against AttachClient/releaseHold; the timer +// stays valid only while the entry's holdTimer still points at it. 
+func (b *Backend) expireHold(ws *Workspace, clientID string, claim *clientState) {
+	ws.clientsMu.Lock()
+	cs, ok := ws.clients[clientID]
+	if !ok || cs != claim || cs.holdTimer == nil || cs.streams > 0 {
+		ws.clientsMu.Unlock()
+		return
+	}
+	cs.holdTimer = nil
+	delete(ws.clients, clientID)
+	teardown := len(ws.clients) == 0
+	ws.clientsMu.Unlock()
+	if teardown {
+		b.teardown(ws)
+	}
+}
+
+func (b *Backend) releaseHoldLocked(ws *Workspace, clientID string) {
+	ws.clientsMu.Lock()
+	cs, ok := ws.clients[clientID]
+	if !ok {
+		ws.clientsMu.Unlock()
+		return
+	}
+	if cs.holdTimer != nil {
+		cs.holdTimer.Stop()
+		cs.holdTimer = nil
+	}
+	teardown := false
+	if cs.streams == 0 {
+		delete(ws.clients, clientID)
+		teardown = len(ws.clients) == 0
+	}
+	ws.clientsMu.Unlock()
+	if teardown {
+		b.teardown(ws)
+	}
+}
+
+func (b *Backend) detachStream(ws *Workspace, clientID string) {
+	ws.clientsMu.Lock()
+	cs, ok := ws.clients[clientID]
+	if !ok {
+		ws.clientsMu.Unlock()
+		return
+	}
+	if cs.streams > 0 {
+		cs.streams--
+	}
+	teardown := false
+	if cs.streams == 0 && cs.holdTimer == nil {
+		delete(ws.clients, clientID)
+		teardown = len(ws.clients) == 0
+	}
+	ws.clientsMu.Unlock()
+	if teardown {
+		b.teardown(ws)
+	}
+}
+
+// teardown removes the workspace from the index, shuts down its
+// underlying [app.App], and triggers a server shutdown if it was the
+// last workspace alive.
+//
+// Callers reach teardown after observing len(ws.clients) == 0 while
+// holding ws.clientsMu and then releasing it. Between that release
+// and the b.mu.Lock below, a concurrent CreateWorkspace or AttachClient
+// may have added a client (both hold b.mu while doing so, so they are
+// mutually exclusive with this critical section). teardown
+// re-checks under both locks (in the canonical b.mu -> ws.clientsMu
+// order) and aborts if the workspace has been re-claimed.
+func (b *Backend) teardown(ws *Workspace) {
+	b.mu.Lock()
+	ws.clientsMu.Lock()
+	current, registered := b.workspaces.Get(ws.ID)
+	if !registered || current != ws || len(ws.clients) > 0 {
+		// Abort: a racing teardown already removed ws (never shut down
+		// twice), or a client re-claimed it before we took b.mu.
+		ws.clientsMu.Unlock()
+		b.mu.Unlock()
+		return
 	}
-	b.workspaces.Del(id)
+	ws.clientsMu.Unlock()
+	if existing, ok := b.pathIndex[ws.resolvedPath]; ok && existing == ws.ID {
+		delete(b.pathIndex, ws.resolvedPath)
+	}
+	b.workspaces.Del(ws.ID)
+	remaining := b.workspaces.Len()
+	b.mu.Unlock()
+
+	ws.invokeShutdown()
 
-	if b.workspaces.Len() == 0 && b.shutdownFn != nil {
+	if remaining == 0 && b.shutdownFn != nil {
 		slog.Info("Last workspace removed, shutting down server...")
 		b.shutdownFn()
 	}
 }
 
+// DeleteWorkspace is the public entry point used by the HTTP DELETE
+// handler. It releases the named client's creation hold; live streams
+// from the same client remain attached and continue holding the
+// workspace open until their own deferred DetachClient runs.
+func (b *Backend) DeleteWorkspace(id, clientID string) error {
+	return b.releaseHold(id, clientID)
+}
+
 // GetWorkspaceProto returns the proto representation of a workspace.
 func (b *Backend) GetWorkspaceProto(id string) (proto.Workspace, error) {
 	ws, err := b.GetWorkspace(id)
@@ -193,6 +500,33 @@ func (b *Backend) Shutdown() {
 	}
 }
 
+// resolveWorkspaceKey returns a stable canonical form of path suitable
+// for use as a dedup key. It applies filepath.Abs, then attempts
+// filepath.EvalSymlinks; because EvalSymlinks errors on non-existent
+// paths, it falls back to the cleaned absolute path in that case.
+func resolveWorkspaceKey(path string) (string, error) {
+	abs, err := filepath.Abs(path)
+	if err != nil {
+		return "", err
+	}
+	if resolved, err := filepath.EvalSymlinks(abs); err == nil {
+		return resolved, nil
+	}
+	return abs, nil
+}
+
+// validateClientID returns id unchanged if it parses as a UUID, or an
+// error matching ErrInvalidClientID if it is empty or fails uuid.Parse.
+func validateClientID(id string) (string, error) {
+	if id == "" {
+		return "", ErrInvalidClientID
+	}
+	if _, err := uuid.Parse(id); err != nil {
+		return "", fmt.Errorf("%w: %v", ErrInvalidClientID, err)
+	}
+	return id, nil
+}
+
 func workspaceToProto(ws *Workspace) proto.Workspace {
 	cfg := ws.Cfg.Config()
 	return proto.Workspace{
@@ -202,5 +536,54 @@ func workspaceToProto(ws *Workspace) proto.Workspace {
 		DataDir: cfg.Options.DataDirectory,
 		Debug:   cfg.Options.Debug,
 		Config:  cfg,
+		Env:     ws.Env,
+		Version: version.Version,
+	}
+}
+
+// logFirstWinsMismatch emits a debug line whenever a second
+// CreateWorkspace at the same resolved path arrives with flags that
+// differ from the originating workspace. The existing workspace wins;
+// the incoming flags are silently ignored.
+//
+// The comparison is done against the incoming args as the caller sent
+// them — including empty/zero values — rather than after defaulting.
+// This means that, for example, a second caller who omits DataDir
+// while the first set one will still log the mismatch.
+func logFirstWinsMismatch(existing *Workspace, args proto.Workspace) { + existingCfg := existing.Cfg.Config() + existingYOLO := existing.Cfg.Overrides().SkipPermissionRequests + if existingYOLO == args.YOLO && + existingCfg.Options.Debug == args.Debug && + existingCfg.Options.DataDirectory == args.DataDir && + stringSlicesEqual(existing.Env, args.Env) { + return + } + slog.Debug( + "Workspace flag mismatch on duplicate create; first wins", + "workspace_id", existing.ID, + "path", existing.Path, + "existing_yolo", existingYOLO, + "requested_yolo", args.YOLO, + "existing_debug", existingCfg.Options.Debug, + "requested_debug", args.Debug, + "existing_data_dir", existingCfg.Options.DataDirectory, + "requested_data_dir", args.DataDir, + "existing_env", existing.Env, + "requested_env", args.Env, + ) +} + +// stringSlicesEqual reports whether a and b contain the same strings +// in the same order. nil and empty are treated as equal. +func stringSlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } } + return true } diff --git a/internal/backend/backend_test.go b/internal/backend/backend_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d1a6f78abb9cc38c1f8a463ed7be548c22e332a1 --- /dev/null +++ b/internal/backend/backend_test.go @@ -0,0 +1,953 @@ +package backend + +import ( + "bytes" + "context" + "errors" + "log/slog" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/proto" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// newTestBackend returns a Backend whose teardown path skips any +// real [app.App] shutdown work. Useful for state-machine tests that +// install synthetic workspaces directly via insertTestWorkspace. 
+func newTestBackend(t *testing.T) (*Backend, *atomic.Int32) { + t.Helper() + var shutdownCount atomic.Int32 + b := &Backend{ + workspaces: csync.NewMap[string, *Workspace](), + pathIndex: make(map[string]string), + ctx: context.Background(), + createGrace: 50 * time.Millisecond, + shutdownFn: func() { shutdownCount.Add(1) }, + } + return b, &shutdownCount +} + +// insertTestWorkspace installs a synthetic workspace into b at the +// given resolved path. Its shutdownFn is recorded in the returned +// counter so tests can assert it ran exactly once. +func insertTestWorkspace(t *testing.T, b *Backend, key string) (*Workspace, *atomic.Int32) { + t.Helper() + var shutdowns atomic.Int32 + ws := &Workspace{ + ID: uuid.New().String(), + Path: key, + resolvedPath: key, + clients: make(map[string]*clientState), + shutdownFn: func() { shutdowns.Add(1) }, + } + b.mu.Lock() + b.workspaces.Set(ws.ID, ws) + b.pathIndex[key] = ws.ID + b.mu.Unlock() + return ws, &shutdowns +} + +func newClientID(t *testing.T) string { + t.Helper() + return uuid.New().String() +} + +func TestResolveWorkspaceKey_AbsoluteAndSymlink(t *testing.T) { + t.Parallel() + + tmp := t.TempDir() + real, err := filepath.EvalSymlinks(tmp) + require.NoError(t, err) + + got, err := resolveWorkspaceKey(tmp) + require.NoError(t, err) + require.Equal(t, real, got) +} + +func TestResolveWorkspaceKey_NonExistentFallback(t *testing.T) { + t.Parallel() + + missing := filepath.Join(t.TempDir(), "does", "not", "exist") + got, err := resolveWorkspaceKey(missing) + require.NoError(t, err) + abs, err := filepath.Abs(missing) + require.NoError(t, err) + require.Equal(t, abs, got) +} + +func TestValidateClientID(t *testing.T) { + t.Parallel() + + _, err := validateClientID("") + require.ErrorIs(t, err, ErrInvalidClientID) + _, err = validateClientID("not-a-uuid") + require.ErrorIs(t, err, ErrInvalidClientID) + + id := uuid.New().String() + got, err := validateClientID(id) + require.NoError(t, err) + require.Equal(t, id, got) +} + 
+func TestRegisterClient_Idempotent(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + b.registerClient(ws, cid) + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.Len(t, ws.clients, 1) + require.NotNil(t, ws.clients[cid].holdTimer) + require.Equal(t, 0, ws.clients[cid].streams) +} + +func TestAttachClient_ConsumesHold(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) + + ws.clientsMu.Lock() + require.Len(t, ws.clients, 1) + require.Nil(t, ws.clients[cid].holdTimer, "attach must stop the grace timer") + require.Equal(t, 1, ws.clients[cid].streams) + ws.clientsMu.Unlock() + + // Wait past the grace window: a stopped timer must not fire. + time.Sleep(150 * time.Millisecond) + require.Equal(t, int32(0), shutdowns.Load(), "workspace must not be torn down while attached") +} + +func TestAttachClient_WithoutPriorCreate(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cid)) + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.Len(t, ws.clients, 1) + require.Equal(t, 1, ws.clients[cid].streams) + require.Nil(t, ws.clients[cid].holdTimer) +} + +func TestAttachClient_DuplicateStreams(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + require.NoError(t, b.AttachClient(ws.ID, cid)) + require.NoError(t, b.AttachClient(ws.ID, cid)) + + ws.clientsMu.Lock() + require.Equal(t, 2, ws.clients[cid].streams) + ws.clientsMu.Unlock() + + b.DetachClient(ws.ID, cid) + ws.clientsMu.Lock() + require.Equal(t, 1, ws.clients[cid].streams) + ws.clientsMu.Unlock() + 
require.Equal(t, int32(0), shutdowns.Load()) + + b.DetachClient(ws.ID, cid) + require.Equal(t, int32(1), shutdowns.Load(), "second detach tears down the workspace") +} + +func TestDetachClient_LastStreamTearsDown(t *testing.T) { + t.Parallel() + + b, srvShutdowns := newTestBackend(t) + ws, wsShutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) + b.DetachClient(ws.ID, cid) + + require.Equal(t, int32(1), wsShutdowns.Load()) + require.Equal(t, int32(1), srvShutdowns.Load(), "last workspace shut down must trigger server shutdown") + _, err := b.GetWorkspace(ws.ID) + require.ErrorIs(t, err, ErrWorkspaceNotFound) +} + +func TestHoldExpiry_TearsDown(t *testing.T) { + t.Parallel() + + b, srvShutdowns := newTestBackend(t) + ws, wsShutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + + require.Eventually(t, func() bool { + return wsShutdowns.Load() == 1 && srvShutdowns.Load() == 1 + }, 1*time.Second, 5*time.Millisecond) +} + +func TestReleaseHold_NoStreams(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.releaseHold(ws.ID, cid)) + + require.Equal(t, int32(1), shutdowns.Load()) + // Idempotent. 
+ require.NoError(t, b.releaseHold(ws.ID, cid)) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestReleaseHold_WithActiveStream(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) + require.NoError(t, b.releaseHold(ws.ID, cid)) + + ws.clientsMu.Lock() + require.Equal(t, 1, ws.clients[cid].streams) + require.Nil(t, ws.clients[cid].holdTimer) + ws.clientsMu.Unlock() + require.Equal(t, int32(0), shutdowns.Load()) + + b.DetachClient(ws.ID, cid) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestReleaseHoldThenAttach(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + require.NoError(t, b.releaseHold(ws.ID, cid)) // no entry yet — no-op. + require.NoError(t, b.AttachClient(ws.ID, cid)) + ws.clientsMu.Lock() + require.Equal(t, 1, ws.clients[cid].streams) + ws.clientsMu.Unlock() + require.NoError(t, b.releaseHold(ws.ID, cid)) // hold-only no-op (no hold timer). 
+ require.Equal(t, int32(0), shutdowns.Load()) + b.DetachClient(ws.ID, cid) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestRefcountWithSecondClient(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/a") + + cidA := newClientID(t) + cidB := newClientID(t) + b.registerClient(ws, cidA) + require.NoError(t, b.AttachClient(ws.ID, cidA)) + b.registerClient(ws, cidB) + require.NoError(t, b.AttachClient(ws.ID, cidB)) + + b.DetachClient(ws.ID, cidA) + ws.clientsMu.Lock() + require.Contains(t, ws.clients, cidB) + require.NotContains(t, ws.clients, cidA) + ws.clientsMu.Unlock() + require.Equal(t, int32(0), shutdowns.Load(), "workspace survives while second client attached") + + b.DetachClient(ws.ID, cidB) + require.Equal(t, int32(1), shutdowns.Load()) +} + +func TestAttachClient_InvalidID(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + require.ErrorIs(t, b.AttachClient(ws.ID, ""), ErrInvalidClientID) + require.ErrorIs(t, b.AttachClient(ws.ID, "not-a-uuid"), ErrInvalidClientID) +} + +func TestDeleteWorkspace_RejectsBadClientID(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + require.ErrorIs(t, b.DeleteWorkspace(ws.ID, ""), ErrInvalidClientID) + require.ErrorIs(t, b.DeleteWorkspace(ws.ID, "not-a-uuid"), ErrInvalidClientID) +} + +// TestHoldExpiry_RaceWithAttach checks that, when the grace timer fires +// while a concurrent AttachClient call is in flight, the workspace ends +// up either fully attached or fully torn down — never in a half-state. +func TestHoldExpiry_RaceWithAttach(t *testing.T) { + t.Parallel() + + for i := range 50 { + b, _ := newTestBackend(t) + // Tighten the grace window further to force the race. 
+ b.createGrace = 1 * time.Millisecond + ws, shutdowns := insertTestWorkspace(t, b, "/tmp/race") + + cid := newClientID(t) + b.registerClient(ws, cid) + // Attach concurrently with the very short grace timer. + errCh := make(chan error, 1) + go func() { errCh <- b.AttachClient(ws.ID, cid) }() + <-errCh + + // Wait for any pending timer to settle. + time.Sleep(10 * time.Millisecond) + + ws.clientsMu.Lock() + gotShutdown := shutdowns.Load() == 1 + cs, present := ws.clients[cid] + var ( + gotStreams int + gotHoldTimer *time.Timer + ) + if present { + gotStreams = cs.streams + gotHoldTimer = cs.holdTimer + } + ws.clientsMu.Unlock() + // Either the workspace was torn down OR the client is + // attached with streams==1 and the hold timer cleared. + // The state must be consistent: if shutdown, client is + // gone; if attached, no teardown and streams==1. + if gotShutdown { + require.False(t, present, "iter %d: shutdown but client still present", i) + } else { + require.True(t, present, "iter %d: not shutdown but client missing", i) + require.Equal(t, 1, gotStreams, "iter %d: attach winner must leave streams=1", i) + require.Nil(t, gotHoldTimer, "iter %d: attach winner must clear holdTimer", i) + } + } +} + +// TestConcurrentAttachDetach exercises the state machine under +// parallel attach/detach pressure with the race detector. +func TestConcurrentAttachDetach(t *testing.T) { + t.Parallel() + + b, _ := newTestBackend(t) + ws, _ := insertTestWorkspace(t, b, "/tmp/a") + + cid := newClientID(t) + b.registerClient(ws, cid) + require.NoError(t, b.AttachClient(ws.ID, cid)) // ensure refcount stays > 0. 
+ + const n = 50 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + cid2 := newClientID(t) + _ = b.AttachClient(ws.ID, cid2) + b.DetachClient(ws.ID, cid2) + }() + } + wg.Wait() + + ws.clientsMu.Lock() + defer ws.clientsMu.Unlock() + require.Len(t, ws.clients, 1) + require.Contains(t, ws.clients, cid) +} + +// TestPathDedupe_FullCreate exercises CreateWorkspace end-to-end +// (config init, real app.App). Two CreateWorkspace calls at the same +// path return the same workspace ID and share the clients map. +func TestPathDedupe_FullCreate(t *testing.T) { + t.Setenv("HOME", t.TempDir()) + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + t.Setenv("XDG_CONFIG_HOME", t.TempDir()) + t.Setenv("XDG_DATA_HOME", t.TempDir()) + + cwd := t.TempDir() + dataDir := t.TempDir() + + b := New(context.Background(), nil, func() {}) + b.SetCreateGrace(2 * time.Second) + t.Cleanup(func() { drainBackend(t, b) }) + + cidA := uuid.New().String() + cidB := uuid.New().String() + + wsA, protoA, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidA)) + require.NoError(t, err) + require.NotEmpty(t, protoA.ID) + require.Equal(t, protoA.DataDir, wsA.Cfg.Config().Options.DataDirectory) + + wsB, protoB, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidB)) + require.NoError(t, err) + require.Equal(t, wsA.ID, wsB.ID, "second create at same path must return existing workspace") + require.Equal(t, protoA.ID, protoB.ID) + + wsA.clientsMu.Lock() + require.Contains(t, wsA.clients, cidA) + require.Contains(t, wsA.clients, cidB) + wsA.clientsMu.Unlock() +} + +// TestPathDedupe_DifferentPaths_DifferentWorkspaces confirms that two +// CreateWorkspace calls at distinct paths produce distinct workspaces. 
+func TestPathDedupe_DifferentPaths_DifferentWorkspaces(t *testing.T) {
+	t.Setenv("HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+
+	cwdA := t.TempDir()
+	cwdB := t.TempDir()
+	dataA := t.TempDir()
+	dataB := t.TempDir()
+
+	b := New(context.Background(), nil, func() {})
+	b.SetCreateGrace(2 * time.Second)
+	t.Cleanup(func() { drainBackend(t, b) })
+
+	wsA, _, err := b.CreateWorkspace(protoWS(cwdA, dataA, uuid.New().String()))
+	require.NoError(t, err)
+	wsB, _, err := b.CreateWorkspace(protoWS(cwdB, dataB, uuid.New().String()))
+	require.NoError(t, err)
+	require.NotEqual(t, wsA.ID, wsB.ID)
+}
+
+// TestPathDedupe_FirstWinsKeepsOriginalEnv verifies that the second
+// create at the same path returns the *originating* client's Env in
+// its proto and does not mutate the existing workspace's YOLO/Debug
+// flags.
+func TestPathDedupe_FirstWinsKeepsOriginalEnv(t *testing.T) {
+	t.Setenv("HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+
+	cwd := t.TempDir()
+	dataDir := t.TempDir()
+
+	b := New(context.Background(), nil, func() {})
+	b.SetCreateGrace(2 * time.Second)
+	t.Cleanup(func() { drainBackend(t, b) })
+
+	originalEnv := []string{"FOO=bar"}
+	argsA := protoWS(cwd, dataDir, uuid.New().String())
+	argsA.YOLO = true
+	argsA.Env = originalEnv
+	wsA, protoA, err := b.CreateWorkspace(argsA)
+	require.NoError(t, err)
+	require.True(t, protoA.YOLO)
+	require.Equal(t, originalEnv, protoA.Env)
+
+	argsB := protoWS(cwd, dataDir, uuid.New().String())
+	argsB.YOLO = false
+	argsB.Debug = true
+	argsB.Env = []string{"BAZ=qux"}
+	_, protoB, err := b.CreateWorkspace(argsB)
+	require.NoError(t, err)
+	require.Equal(t, protoA.ID, protoB.ID)
+	require.True(t, protoB.YOLO, "first wins: YOLO must remain true")
+	require.Equal(t, originalEnv, protoB.Env, "proto must carry the originating client's Env")
+	require.True(t, wsA.Cfg.Overrides().SkipPermissionRequests, "first wins: existing workspace must keep YOLO enabled")
+}
+
+// TestPathDedupe_Symlink confirms two paths that resolve to the same
+// target share a workspace.
+func TestPathDedupe_Symlink(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("symlink semantics differ on windows")
+	}
+	t.Setenv("HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+
+	target := t.TempDir() // not "real": that name would shadow the predeclared builtin
+	link := filepath.Join(t.TempDir(), "link")
+	require.NoError(t, os.Symlink(target, link))
+	dataDir := t.TempDir()
+
+	b := New(context.Background(), nil, func() {})
+	b.SetCreateGrace(2 * time.Second)
+	t.Cleanup(func() { drainBackend(t, b) })
+
+	wsA, _, err := b.CreateWorkspace(protoWS(target, dataDir, uuid.New().String()))
+	require.NoError(t, err)
+	wsB, _, err := b.CreateWorkspace(protoWS(link, dataDir, uuid.New().String()))
+	require.NoError(t, err)
+	require.Equal(t, wsA.ID, wsB.ID)
+}
+
+// TestPathDedupe_NonExistentPath ensures CreateWorkspace tolerates a
+// path that does not yet exist (EvalSymlinks falls back to Abs).
+func TestPathDedupe_NonExistentPath(t *testing.T) {
+	t.Setenv("HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+
+	parent := t.TempDir()
+	missing := filepath.Join(parent, "does-not-exist")
+	dataDir := t.TempDir()
+
+	b := New(context.Background(), nil, func() {})
+	b.SetCreateGrace(2 * time.Second)
+	t.Cleanup(func() { drainBackend(t, b) })
+
+	_, p, err := b.CreateWorkspace(protoWS(missing, dataDir, uuid.New().String()))
+	require.NoError(t, err)
+	require.NotEmpty(t, p.ID)
+}
+
+// TestCreateWorkspace_IdempotentSameClient checks that a duplicate
+// create from the same client at the same path does not produce a
+// second claim.
+func TestCreateWorkspace_IdempotentSameClient(t *testing.T) {
+	t.Setenv("HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+
+	cwd := t.TempDir()
+	dataDir := t.TempDir()
+	b := New(context.Background(), nil, func() {})
+	b.SetCreateGrace(2 * time.Second)
+	t.Cleanup(func() { drainBackend(t, b) })
+
+	cid := uuid.New().String()
+	ws1, _, err := b.CreateWorkspace(protoWS(cwd, dataDir, cid))
+	require.NoError(t, err)
+	ws2, _, err := b.CreateWorkspace(protoWS(cwd, dataDir, cid))
+	require.NoError(t, err)
+	require.Equal(t, ws1.ID, ws2.ID)
+
+	ws1.clientsMu.Lock()
+	defer ws1.clientsMu.Unlock() // a failed assertion must not leave the mutex held for the drainBackend cleanup
+	require.Len(t, ws1.clients, 1, "duplicate create from same client must not double the claim")
+}
+
+// TestPathDedupe_ParallelCreates ensures two simultaneous CreateWorkspace
+// calls at the same path produce the same workspace and the clients map
+// contains both client IDs.
+func TestPathDedupe_ParallelCreates(t *testing.T) {
+	t.Setenv("HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+
+	cwd := t.TempDir()
+	dataDir := t.TempDir()
+
+	b := New(context.Background(), nil, func() {})
+	b.SetCreateGrace(2 * time.Second)
+	t.Cleanup(func() { drainBackend(t, b) })
+
+	cidA := uuid.New().String()
+	cidB := uuid.New().String()
+
+	type result struct {
+		ws    *Workspace
+		proto proto.Workspace
+		err   error
+	}
+	ch := make(chan result, 2)
+	start := make(chan struct{})
+	go func() {
+		<-start
+		ws, p, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidA))
+		ch <- result{ws, p, err}
+	}()
+	go func() {
+		<-start
+		ws, p, err := b.CreateWorkspace(protoWS(cwd, dataDir, cidB))
+		ch <- result{ws, p, err}
+	}()
+	close(start)
+	r1 := <-ch
+	r2 := <-ch
+	require.NoError(t, r1.err)
+	require.NoError(t, r2.err)
+	require.Equal(t, r1.ws.ID, r2.ws.ID, "both creates must converge on one workspace ID")
+
+	ws := r1.ws
+	ws.clientsMu.Lock()
+	defer ws.clientsMu.Unlock()
+	require.Contains(t, ws.clients, cidA)
+	require.Contains(t, ws.clients, cidB)
+}
+
+// TestCreateWorkspace_RejectsBadClientID covers the 400 path from the
+// backend side.
+func TestCreateWorkspace_RejectsBadClientID(t *testing.T) {
+	t.Parallel()
+
+	b := New(context.Background(), nil, func() {})
+
+	_, _, err := b.CreateWorkspace(protoWS("/tmp/x", t.TempDir(), ""))
+	require.ErrorIs(t, err, ErrInvalidClientID)
+	_, _, err = b.CreateWorkspace(protoWS("/tmp/x", t.TempDir(), "not-a-uuid"))
+	require.ErrorIs(t, err, ErrInvalidClientID)
+}
+
+// drainBackend ends a test by releasing every client's hold on each
+// remaining workspace. Necessary so the test process doesn't
+// leak goroutines or DB handles from the embedded [app.App] instances.
+func drainBackend(t *testing.T, b *Backend) {
+	t.Helper()
+	for _, ws := range b.workspaces.Seq2() {
+		ws.clientsMu.Lock()
+		ids := make([]string, 0, len(ws.clients))
+		for id := range ws.clients {
+			ids = append(ids, id)
+		}
+		ws.clientsMu.Unlock()
+		for _, cid := range ids {
+			_ = b.releaseHold(ws.ID, cid) // best-effort: cleanup has nothing useful to do with an error
+		}
+	}
+}
+
+func protoWS(path, dataDir, clientID string) proto.Workspace {
+	return proto.Workspace{Path: path, DataDir: dataDir, ClientID: clientID}
+}
+
+// captureDebugLogs installs a buffer-backed slog handler at Debug
+// level for the duration of the test, returning the buffer. The
+// previous default logger is restored via t.Cleanup.
+func captureDebugLogs(t *testing.T) *bytes.Buffer {
+	t.Helper()
+	var buf bytes.Buffer
+	prev := slog.Default()
+	handler := slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelDebug})
+	slog.SetDefault(slog.New(handler))
+	t.Cleanup(func() { slog.SetDefault(prev) })
+	return &buf
+}
+
+// xdgIsolated points HOME and XDG_* variables at fresh tempdirs so
+// CreateWorkspace's config loading does not interfere with the host
+// machine's real config.
+func xdgIsolated(t *testing.T) {
+	t.Helper()
+	t.Setenv("HOME", t.TempDir())
+	t.Setenv("XDG_CACHE_HOME", t.TempDir())
+	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
+	t.Setenv("XDG_DATA_HOME", t.TempDir())
+}
+
+// TestFirstWinsMismatch_LogsOnFlagDifferences verifies that the
+// debug mismatch line is emitted when any of YOLO, Debug, DataDir,
+// or Env differs between the first and second CreateWorkspace at
+// the same path, and that the existing workspace's YOLO and Debug
+// flags are not overwritten.
+func TestFirstWinsMismatch_LogsOnFlagDifferences(t *testing.T) {
+	tests := []struct {
+		name   string
+		mutate func(*proto.Workspace)
+	}{
+		{
+			name:   "yolo",
+			mutate: func(p *proto.Workspace) { p.YOLO = true },
+		},
+		{
+			name:   "debug",
+			mutate: func(p *proto.Workspace) { p.Debug = true },
+		},
+		{
+			name:   "datadir",
+			mutate: func(p *proto.Workspace) { p.DataDir = "" },
+		},
+		{
+			name:   "env",
+			mutate: func(p *proto.Workspace) { p.Env = []string{"NEW=val"} },
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			xdgIsolated(t)
+			cwd := t.TempDir()
+			dataDir := t.TempDir()
+
+			buf := captureDebugLogs(t)
+			b := New(context.Background(), nil, func() {})
+			b.SetCreateGrace(2 * time.Second)
+			t.Cleanup(func() { drainBackend(t, b) })
+
+			argsA := protoWS(cwd, dataDir, uuid.New().String())
+			argsA.Env = []string{"FOO=bar"}
+			wsA, _, err := b.CreateWorkspace(argsA)
+			require.NoError(t, err)
+			originalDebug := wsA.Cfg.Config().Options.Debug
+			originalYOLO := wsA.Cfg.Overrides().SkipPermissionRequests
+
+			argsB := protoWS(cwd, dataDir, uuid.New().String())
+			argsB.Env = []string{"FOO=bar"} // same Env as argsA; tc.mutate introduces the single difference
+			tc.mutate(&argsB)
+			_, _, err = b.CreateWorkspace(argsB)
+			require.NoError(t, err)
+
+			require.Contains(
+				t, buf.String(),
+				"Workspace flag mismatch on duplicate create",
+				"expected debug log for mismatching %s", tc.name,
+			)
+			// Existing workspace's YOLO and Debug must not change.
+			require.Equal(t, originalYOLO, wsA.Cfg.Overrides().SkipPermissionRequests, "YOLO must be immutable on first-wins")
+			require.Equal(t, originalDebug, wsA.Cfg.Config().Options.Debug, "Debug must be immutable on first-wins")
+		})
+	}
+}
+
+// TestFirstWinsMismatch_NoLogWhenIdentical confirms identical args
+// do not emit the mismatch log line.
+func TestFirstWinsMismatch_NoLogWhenIdentical(t *testing.T) {
+	xdgIsolated(t)
+	cwd := t.TempDir()
+	dataDir := t.TempDir()
+
+	buf := captureDebugLogs(t)
+	b := New(context.Background(), nil, func() {})
+	b.SetCreateGrace(2 * time.Second)
+	t.Cleanup(func() { drainBackend(t, b) })
+
+	argsA := protoWS(cwd, dataDir, uuid.New().String())
+	argsA.Env = []string{"FOO=bar"}
+	_, _, err := b.CreateWorkspace(argsA)
+	require.NoError(t, err)
+
+	argsB := protoWS(cwd, dataDir, uuid.New().String())
+	argsB.Env = []string{"FOO=bar"}
+	_, _, err = b.CreateWorkspace(argsB)
+	require.NoError(t, err)
+
+	require.False(t,
+		strings.Contains(buf.String(), "Workspace flag mismatch on duplicate create"),
+		"identical args must not log a mismatch: %s", buf.String())
+}
+
+// TestRaceTwoClientsAttachOneDetaches exercises the PLAN-required
+// race scenario: two clients attach concurrently, then one detaches.
+// The workspace must remain alive with refcount==1 and the clients
+// map must reflect the remaining client only.
+func TestRaceTwoClientsAttachOneDetaches(t *testing.T) {
+	t.Parallel()
+
+	b, _ := newTestBackend(t)
+	ws, shutdowns := insertTestWorkspace(t, b, "/tmp/race-two")
+
+	cidA := newClientID(t)
+	cidB := newClientID(t)
+
+	var wg sync.WaitGroup
+	errs := make([]error, 2) // checked after Wait: require must not run off the test goroutine
+	for i, cid := range []string{cidA, cidB} {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			errs[i] = b.AttachClient(ws.ID, cid)
+		}()
+	}
+	wg.Wait()
+	require.NoError(t, errors.Join(errs...))
+
+	ws.clientsMu.Lock()
+	require.Len(t, ws.clients, 2, "both clients must be attached")
+	ws.clientsMu.Unlock()
+
+	b.DetachClient(ws.ID, cidA)
+
+	ws.clientsMu.Lock()
+	require.Len(t, ws.clients, 1, "refcount must be 1 after one detach")
+	require.Contains(t, ws.clients, cidB, "remaining client must be cidB")
+	require.NotContains(t, ws.clients, cidA, "detached client must be removed")
+	ws.clientsMu.Unlock()
+	require.Equal(t, int32(0), shutdowns.Load(), "workspace must remain alive")
+
+	// Drain.
+	b.DetachClient(ws.ID, cidB)
+	require.Equal(t, int32(1), shutdowns.Load())
+}
+
+// TestExplicitDeleteThenAttach reproduces the PLAN scenario: start
+// with a real hold, releaseHold consumes it, AttachClient from the
+// same clientID creates a fresh entry with streams==1, and calling
+// releaseHold again is a no-op. A second client keeps the workspace
+// alive so AttachClient can still resolve the workspace ID after the
+// first client's hold is released.
+func TestExplicitDeleteThenAttach(t *testing.T) {
+	t.Parallel()
+
+	// Large grace window so timers cannot fire during the test
+	// — we want to exercise the explicit releaseHold path.
+	b, _ := newTestBackend(t)
+	b.createGrace = time.Hour
+	ws, shutdowns := insertTestWorkspace(t, b, "/tmp/delete-then-attach")
+
+	// Anchor client keeps the workspace registered in
+	// b.workspaces across the cid's releaseHold below.
+	anchor := newClientID(t)
+	require.NoError(t, b.AttachClient(ws.ID, anchor))
+
+	cid := newClientID(t)
+	// Real hold via registerClient (mirrors CreateWorkspace).
+	b.registerClient(ws, cid)
+	ws.clientsMu.Lock()
+	require.Contains(t, ws.clients, cid)
+	require.NotNil(t, ws.clients[cid].holdTimer, "hold must be live")
+	require.Equal(t, 0, ws.clients[cid].streams)
+	ws.clientsMu.Unlock()
+
+	// releaseHold: consumes the hold and removes the entry
+	// (streams == 0). The anchor client keeps the workspace
+	// alive.
+	require.NoError(t, b.releaseHold(ws.ID, cid))
+	require.Equal(t, int32(0), shutdowns.Load(), "anchor must keep workspace alive")
+	ws.clientsMu.Lock()
+	require.NotContains(t, ws.clients, cid, "entry must be removed by releaseHold")
+	ws.clientsMu.Unlock()
+
+	// AttachClient creates a fresh entry with streams==1 and no
+	// hold timer.
+	require.NoError(t, b.AttachClient(ws.ID, cid))
+	ws.clientsMu.Lock()
+	require.Contains(t, ws.clients, cid, "fresh entry must be created")
+	require.Equal(t, 1, ws.clients[cid].streams, "fresh attach must start at streams=1")
+	require.Nil(t, ws.clients[cid].holdTimer, "fresh attach must have no hold timer")
+	ws.clientsMu.Unlock()
+
+	// Calling releaseHold again is a no-op (no hold timer to
+	// stop, streams > 0 so the entry stays).
+	require.NoError(t, b.releaseHold(ws.ID, cid))
+	ws.clientsMu.Lock()
+	require.Contains(t, ws.clients, cid, "releaseHold must not touch a stream-only entry")
+	require.Equal(t, 1, ws.clients[cid].streams)
+	require.Nil(t, ws.clients[cid].holdTimer)
+	ws.clientsMu.Unlock()
+
+	// Drain.
+	b.DetachClient(ws.ID, cid)
+	b.DetachClient(ws.ID, anchor)
+	require.Equal(t, int32(1), shutdowns.Load())
+}
+
+// TestAttachClient_RacesWithTeardown forces AttachClient to compete
+// with the teardown path triggered by DetachClient. Before the fix,
+// AttachClient could observe a workspace after teardown had already
+// decided to remove it (because AttachClient did not synchronize with
+// Backend.mu), leaving a live stream claim attached to a workspace
+// that was then removed and shut down. With the fix, the outcome must
+// be deterministic: either AttachClient won and the workspace is
+// alive with the client registered, or teardown won and AttachClient
+// returns ErrWorkspaceNotFound — never a half-state where the
+// workspace is gone but ws.clients still contains the new client.
+func TestAttachClient_RacesWithTeardown(t *testing.T) {
+	t.Parallel()
+
+	for i := range 200 {
+		b, _ := newTestBackend(t)
+		// Keep the grace window long so it can't fire during the
+		// test and confuse the bookkeeping.
+		b.createGrace = time.Hour
+		ws, shutdowns := insertTestWorkspace(t, b, "/tmp/race-teardown")
+
+		// Seed: cidA holds the workspace open via a stream. The
+		// imminent DetachClient(cidA) will be the *only* claim
+		// drop, so teardown will run.
+		cidA := newClientID(t)
+		require.NoError(t, b.AttachClient(ws.ID, cidA))
+
+		// cidB attempts to attach concurrently with the detach
+		// that will tear the workspace down.
+		cidB := newClientID(t)
+		start := make(chan struct{})
+		errCh := make(chan error, 1)
+		detachDone := make(chan struct{})
+		go func() {
+			<-start
+			errCh <- b.AttachClient(ws.ID, cidB)
+		}()
+		go func() {
+			<-start
+			b.DetachClient(ws.ID, cidA)
+			close(detachDone)
+		}()
+		close(start)
+
+		// Wait for both goroutines so teardown (including
+		// shutdownFn) has fully run before we read state.
+		attachErr := <-errCh
+		<-detachDone
+
+		_, wsStillRegistered := b.workspaces.Get(ws.ID)
+		ws.clientsMu.Lock()
+		_, hasA := ws.clients[cidA]
+		_, hasB := ws.clients[cidB]
+		clientCount := len(ws.clients)
+		ws.clientsMu.Unlock()
+		shutdownCount := shutdowns.Load()
+
+		switch {
+		case attachErr == nil:
+			// AttachClient won. The workspace must be alive
+			// (registered) with cidB in its clients map. cidA
+			// may or may not still be there depending on who
+			// took clientsMu first, but the workspace must
+			// not have been torn down.
+			require.True(t, wsStillRegistered,
+				"iter %d: attach succeeded but workspace was removed", i)
+			require.True(t, hasB,
+				"iter %d: attach succeeded but cidB missing from clients", i)
+			require.Equal(t, int32(0), shutdownCount,
+				"iter %d: attach succeeded but workspace was shut down", i)
+		case errors.Is(attachErr, ErrWorkspaceNotFound):
+			// Teardown won. The workspace must be removed,
+			// shut down exactly once, and ws.clients must be
+			// empty (no half-state with cidB inserted into a
+			// dead workspace's clients map).
+			require.False(t, wsStillRegistered,
+				"iter %d: ErrWorkspaceNotFound but workspace still registered", i)
+			require.Equal(t, int32(1), shutdownCount,
+				"iter %d: ErrWorkspaceNotFound but shutdown count = %d", i, shutdownCount)
+			require.False(t, hasA,
+				"iter %d: teardown won but cidA still in clients", i)
+			require.False(t, hasB,
+				"iter %d: teardown won but cidB still in clients (would be the leaked attach)", i)
+			require.Zero(t, clientCount,
+				"iter %d: teardown won but clients map is non-empty", i)
+		default:
+			t.Fatalf("iter %d: unexpected AttachClient error: %v", i, attachErr)
+		}
+	}
+}
diff --git a/internal/client/client.go b/internal/client/client.go
index 42dd0243b234bc1c9bfc4801311a728d027eb240..7b83da5cbb29e3959e5ee22762d303341e76be0c 100644
--- a/internal/client/client.go
+++ b/internal/client/client.go
@@ -15,6 +15,7 @@ import (
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/proto"
 	"github.com/charmbracelet/crush/internal/server"
+	"github.com/google/uuid"
 )
 
 // DummyHost is used to satisfy the http.Client's requirement for a URL.
@@ -22,10 +23,11 @@ const DummyHost = "api.crush.localhost"
 
 // Client represents an RPC client connected to a Crush server.
type Client struct {
-	h       *http.Client
-	path    string
-	network string
-	addr    string
+	h        *http.Client
+	path     string
+	network  string
+	addr     string
+	clientID string
 }
 
 // DefaultClient creates a new [Client] connected to the default server address.
@@ -44,6 +46,7 @@ func NewClient(path, network, address string) (*Client, error) {
 	c.path = filepath.Clean(path)
 	c.network = network
 	c.addr = address
+	c.clientID = uuid.New().String()
 	p := &http.Protocols{}
 	p.SetHTTP1(true)
 	p.SetUnencryptedHTTP2(true)
@@ -65,6 +68,12 @@ func (c *Client) Path() string {
 	return c.path
 }
 
+// ClientID returns the random UUID minted for this client in [NewClient].
+// The server uses it as a presence/coordination handle.
+func (c *Client) ClientID() string {
+	return c.clientID
+}
+
 // GetGlobalConfig retrieves the server's configuration.
 func (c *Client) GetGlobalConfig(ctx context.Context) (*config.Config, error) {
 	var cfg config.Config
diff --git a/internal/client/proto.go b/internal/client/proto.go
index 442a4f0f3a8ff90981ab90e24fcdcdd98adf4004..e17f08dc8b836e7066476ad354c9ea3229e0bfb1 100644
--- a/internal/client/proto.go
+++ b/internal/client/proto.go
@@ -39,6 +39,7 @@ func (c *Client) ListWorkspaces(ctx context.Context) ([]proto.Workspace, error)
 
 // CreateWorkspace creates a new workspace on the server.
 func (c *Client) CreateWorkspace(ctx context.Context, ws proto.Workspace) (*proto.Workspace, error) {
+	ws.ClientID = c.clientID
 	rsp, err := c.post(ctx, "/workspaces", nil, jsonBody(ws), http.Header{"Content-Type": []string{"application/json"}})
 	if err != nil {
 		return nil, fmt.Errorf("failed to create workspace: %w", err)
@@ -73,7 +74,8 @@ func (c *Client) GetWorkspace(ctx context.Context, id string) (*proto.Workspace,
 
 // DeleteWorkspace deletes a workspace on the server.
func (c *Client) DeleteWorkspace(ctx context.Context, id string) error { - rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s", id), nil, nil) + q := url.Values{"client_id": []string{c.clientID}} + rsp, err := c.delete(ctx, fmt.Sprintf("/workspaces/%s", id), q, nil) if err != nil { return fmt.Errorf("failed to delete workspace: %w", err) } @@ -87,8 +89,9 @@ func (c *Client) DeleteWorkspace(ctx context.Context, id string) error { // SubscribeEvents subscribes to server-sent events for a workspace. func (c *Client) SubscribeEvents(ctx context.Context, id string) (<-chan any, error) { events := make(chan any, 100) + q := url.Values{"client_id": []string{c.clientID}} //nolint:bodyclose - rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/events", id), nil, http.Header{ + rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/events", id), q, http.Header{ "Accept": []string{"text/event-stream"}, "Cache-Control": []string{"no-cache"}, "Connection": []string{"keep-alive"}, diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 87dafd24abc44dabff608ed6744c17703c244a37..fbd71f33da3a330cfe7c14112ead7763d4b4d948 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -13,14 +13,15 @@ import ( // Workspace represents a running app.App workspace with its associated // resources and state. 
type Workspace struct {
-	ID      string         `json:"id"`
-	Path    string         `json:"path"`
-	YOLO    bool           `json:"yolo,omitempty"`
-	Debug   bool           `json:"debug,omitempty"`
-	DataDir string         `json:"data_dir,omitempty"`
-	Version string         `json:"version,omitempty"`
-	Config  *config.Config `json:"config,omitempty"`
-	Env     []string       `json:"env,omitempty"`
+	ID       string         `json:"id"`
+	Path     string         `json:"path"`
+	YOLO     bool           `json:"yolo,omitempty"`
+	Debug    bool           `json:"debug,omitempty"`
+	DataDir  string         `json:"data_dir,omitempty"`
+	Version  string         `json:"version,omitempty"`
+	ClientID string         `json:"client_id,omitempty"` // UUID of the requesting client; create requests without a valid one are rejected
+	Config   *config.Config `json:"config,omitempty"`
+	Env      []string       `json:"env,omitempty"`
 }
 
 // Error represents an error response.
diff --git a/internal/server/multiclient_test.go b/internal/server/multiclient_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..bd62ee8e6e0f36c151fad8e590716966236d5ba7
--- /dev/null
+++ b/internal/server/multiclient_test.go
@@ -0,0 +1,107 @@
+package server
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/backend"
+	"github.com/charmbracelet/crush/internal/proto"
+	"github.com/stretchr/testify/require"
+)
+
+// newTestController builds a controllerV1 around a backend without a
+// real config store, suitable for handler-level 400 tests.
+func newTestController() *controllerV1 {
+	s := &Server{}
+	s.backend = backend.New(context.Background(), nil, nil)
+	return &controllerV1{backend: s.backend, server: s}
+}
+
+func TestPostWorkspaces_RejectsMissingClientID(t *testing.T) {
+	t.Parallel()
+	c := newTestController()
+
+	body, err := json.Marshal(proto.Workspace{Path: t.TempDir()})
+	require.NoError(t, err)
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/workspaces", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+
+	c.handlePostWorkspaces(rec, req)
+
+	require.Equal(t, http.StatusBadRequest, rec.Code)
+	var perr proto.Error
+	require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &perr))
+	require.Contains(t, perr.Message, "client_id")
+}
+
+func TestPostWorkspaces_RejectsMalformedClientID(t *testing.T) {
+	t.Parallel()
+	c := newTestController()
+
+	body, err := json.Marshal(proto.Workspace{Path: t.TempDir(), ClientID: "not-a-uuid"})
+	require.NoError(t, err)
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/v1/workspaces", bytes.NewReader(body))
+	req.Header.Set("Content-Type", "application/json")
+	rec := httptest.NewRecorder()
+
+	c.handlePostWorkspaces(rec, req)
+
+	require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestDeleteWorkspace_RejectsMissingClientID(t *testing.T) {
+	t.Parallel()
+	c := newTestController()
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodDelete, "/v1/workspaces/abc", nil)
+	req.SetPathValue("id", "abc")
+	rec := httptest.NewRecorder()
+
+	c.handleDeleteWorkspaces(rec, req)
+
+	require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestDeleteWorkspace_RejectsMalformedClientID(t *testing.T) {
+	t.Parallel()
+	c := newTestController()
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodDelete, "/v1/workspaces/abc?client_id=nope", nil)
+	req.SetPathValue("id", "abc")
+	rec := httptest.NewRecorder()
+
+	c.handleDeleteWorkspaces(rec, req)
+
+	require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSubscribeEvents_RejectsMissingClientID(t *testing.T) {
+	t.Parallel()
+	c := newTestController()
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/abc/events", nil)
+	req.SetPathValue("id", "abc")
+	rec := httptest.NewRecorder()
+
+	c.handleGetWorkspaceEvents(rec, req)
+
+	require.Equal(t, http.StatusBadRequest, rec.Code)
+}
+
+func TestSubscribeEvents_RejectsMalformedClientID(t *testing.T) {
+	t.Parallel()
+	c := newTestController()
+
+	req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/v1/workspaces/abc/events?client_id=nope", nil)
+	req.SetPathValue("id", "abc")
+	rec := httptest.NewRecorder()
+
+	c.handleGetWorkspaceEvents(rec, req)
+
+	require.Equal(t, http.StatusBadRequest, rec.Code)
+}
diff --git a/internal/server/proto.go b/internal/server/proto.go
index f30dade2c66fdd62a5caa4b80d29235ef2930c4a..0523904a3d2d4317da9a4afcea40985b675b41ba 100644
--- a/internal/server/proto.go
+++ b/internal/server/proto.go
@@ -9,6 +9,7 @@ import (
 	"github.com/charmbracelet/crush/internal/backend"
 	"github.com/charmbracelet/crush/internal/proto"
 	"github.com/charmbracelet/crush/internal/session"
+	"github.com/google/uuid"
 )
 
 type controllerV1 struct {
@@ -133,6 +134,23 @@ func (c *controllerV1) handlePostWorkspaces(w http.ResponseWriter, r *http.Reque
 	jsonEncode(w, result)
 }
 
+// requireClientID reads the client_id query parameter and checks that it
+// parses as a UUID. On failure it logs, writes a 400, and returns "", false.
+func (c *controllerV1) requireClientID(w http.ResponseWriter, r *http.Request) (string, bool) {
+	cid := r.URL.Query().Get("client_id")
+	if cid == "" {
+		c.server.logError(r, "Missing client_id query parameter")
+		jsonError(w, http.StatusBadRequest, "client_id is required")
+		return "", false
+	}
+	if _, err := uuid.Parse(cid); err != nil {
+		c.server.logError(r, "Invalid client_id", "error", err)
+		jsonError(w, http.StatusBadRequest, "client_id is not a valid UUID")
+		return "", false
+	}
+	return cid, true
+}
+
 // handleDeleteWorkspaces deletes a workspace.
 //
 // @Summary Delete workspace
@@ -143,7 +161,14 @@ func (c *controllerV1) handlePostWorkspaces(w http.ResponseWriter, r *http.Reque
 // @Router /workspaces/{id} [delete]
 func (c *controllerV1) handleDeleteWorkspaces(w http.ResponseWriter, r *http.Request) {
 	id := r.PathValue("id")
-	c.backend.DeleteWorkspace(id)
+	clientID, ok := c.requireClientID(w, r)
+	if !ok {
+		return
+	}
+	if err := c.backend.DeleteWorkspace(id, clientID); err != nil {
+		c.handleError(w, r, err)
+		return
+	}
 }
 
 // handleGetWorkspaceConfig returns workspace configuration.
@@ -199,6 +224,15 @@ func (c *controllerV1) handleGetWorkspaceProviders(w http.ResponseWriter, r *htt func (c *controllerV1) handleGetWorkspaceEvents(w http.ResponseWriter, r *http.Request) { flusher := http.NewResponseController(w) id := r.PathValue("id") + clientID, ok := c.requireClientID(w, r) + if !ok { + return + } + if err := c.backend.AttachClient(id, clientID); err != nil { + c.handleError(w, r, err) + return + } + defer c.backend.DetachClient(id, clientID) events, err := c.backend.SubscribeEvents(r.Context(), id) if err != nil { c.handleError(w, r, err) @@ -951,6 +985,8 @@ func (c *controllerV1) handleError(w http.ResponseWriter, r *http.Request, err e status = http.StatusBadRequest case errors.Is(err, backend.ErrUnknownCommand): status = http.StatusBadRequest + case errors.Is(err, backend.ErrInvalidClientID): + status = http.StatusBadRequest } c.server.logError(r, err.Error()) jsonError(w, status, err.Error())