fix: prevent goroutine orphaning in mcp.Close() and shell.KillAll() (#2159)


Change summary

internal/agent/tools/mcp/init.go  | 29 ++++++++++++++++-------------
internal/app/app.go               |  8 +++++---
internal/shell/background.go      | 32 +++++++++++++-------------------
internal/shell/background_test.go | 26 +++++++++++++++++++++++++-
4 files changed, 59 insertions(+), 36 deletions(-)

Detailed changes

internal/agent/tools/mcp/init.go

@@ -124,26 +124,29 @@ func GetState(name string) (ClientInfo, bool) {
 
 // Close closes all MCP clients. This should be called during application shutdown.
 func Close() error {
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
 	var wg sync.WaitGroup
-	done := make(chan struct{}, 1)
-	go func() {
-		for name, session := range sessions.Seq2() {
-			wg.Go(func() {
-				if err := session.Close(); err != nil &&
+	for name, session := range sessions.Seq2() {
+		wg.Go(func() {
+			done := make(chan error, 1)
+			go func() {
+				done <- session.Close()
+			}()
+			select {
+			case err := <-done:
+				if err != nil &&
 					!errors.Is(err, io.EOF) &&
 					!errors.Is(err, context.Canceled) &&
 					err.Error() != "signal: killed" {
 					slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
 				}
-			})
-		}
-		wg.Wait()
-		done <- struct{}{}
-	}()
-	select {
-	case <-done:
-	case <-time.After(5 * time.Second):
+			case <-ctx.Done():
+			}
+		})
 	}
+	wg.Wait()
 	broker.Shutdown()
 	return nil
 }
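
The heart of the fix is the per-session select against a buffered result channel: even if the select gives up at the deadline, the goroutine running session.Close() can still complete its send into the buffered channel and exit, so nothing is left blocked. A minimal standalone sketch of that pattern (closeWithin is an illustrative name, not part of the patch):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// closeWithin runs closeFn in its own goroutine and waits for it until ctx
// is done. If the deadline hits first, the goroutine is abandoned but not
// leaked: the buffered channel guarantees its final send never blocks.
func closeWithin(ctx context.Context, closeFn func() error) error {
	done := make(chan error, 1) // capacity 1: the send below never blocks
	go func() {
		done <- closeFn()
	}()
	select {
	case err := <-done:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	// A close that outlives the deadline: closeWithin still returns promptly.
	err := closeWithin(ctx, func() error {
		time.Sleep(time.Second)
		return nil
	})
	fmt.Println(errors.Is(err, context.DeadlineExceeded)) // prints: true
}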

internal/app/app.go

@@ -540,14 +540,16 @@ func (app *App) Shutdown() {
 	// Now run remaining cleanup tasks in parallel.
 	var wg sync.WaitGroup
 
+	// Shared shutdown context for all timeout-bounded cleanup.
+	shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
+	defer cancel()
+
 	// Kill all background shells.
 	wg.Go(func() {
-		shell.GetBackgroundShellManager().KillAll()
+		shell.GetBackgroundShellManager().KillAll(shutdownCtx)
 	})
 
 	// Shutdown all LSP clients.
-	shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
-	defer cancel()
 	wg.Go(func() {
 		app.LSPManager.StopAll(shutdownCtx)
 	})
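
At the app level the change is mostly about sharing one deadline: every cleanup task races the same 5-second budget, so total shutdown time is bounded at roughly five seconds overall rather than five seconds per subsystem. A compilable sketch of that shape, with hypothetical task functions standing in for the real cleanup (sync.WaitGroup.Go requires Go 1.25+):

package app

import (
	"context"
	"sync"
	"time"
)

// shutdown fans cleanup tasks out in parallel under one shared deadline.
// Each task is responsible for observing ctx.Done() and returning early.
func shutdown(parent context.Context, tasks ...func(context.Context)) {
	ctx, cancel := context.WithTimeout(parent, 5*time.Second)
	defer cancel()

	var wg sync.WaitGroup
	for _, task := range tasks {
		wg.Go(func() { task(ctx) }) // WaitGroup.Go is new in Go 1.25
	}
	wg.Wait() // returns once every task finishes or gives up at the deadline
}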

internal/shell/background.go

@@ -191,29 +191,23 @@ func (m *BackgroundShellManager) Cleanup() int {
 	return len(toRemove)
 }
 
-// KillAll terminates all background shells.
-func (m *BackgroundShellManager) KillAll() {
+// KillAll terminates all background shells. The provided context bounds how
+// long the function waits for each shell to exit.
+func (m *BackgroundShellManager) KillAll(ctx context.Context) {
 	shells := slices.Collect(m.shells.Seq())
 	m.shells.Reset(map[string]*BackgroundShell{})
-	done := make(chan struct{}, 1)
-	go func() {
-		var wg sync.WaitGroup
-		for _, shell := range shells {
-			wg.Go(func() {
-				shell.cancel()
-				<-shell.done
-			})
-		}
-		wg.Wait()
-		done <- struct{}{}
-	}()
 
-	select {
-	case <-done:
-		return
-	case <-time.After(time.Second * 5):
-		return
+	var wg sync.WaitGroup
+	for _, shell := range shells {
+		wg.Go(func() {
+			shell.cancel()
+			select {
+			case <-shell.done:
+			case <-ctx.Done():
+			}
+		})
 	}
+	wg.Wait()
 }
 
 // GetOutput returns the current output of a background shell.
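
For context on why the old KillAll orphaned goroutines: it spawned one waiter per shell plus a collector, then walked away after time.After(5 * time.Second). A shell that trapped its signals never closed shell.done, so its waiter (and the collector stuck in wg.Wait()) blocked for the life of the process. A small repro of that failure mode, where the never-closed stuck channel stands in for such a shell's done channel:

package main

import (
	"fmt"
	"runtime"
	"time"
)

// leakyWait mimics the old KillAll shape: a worker blocks on a done
// channel, and the caller abandons it after a timeout. Nothing ever
// unblocks the worker, so it survives as an orphan.
func leakyWait(stuck <-chan struct{}) {
	done := make(chan struct{}, 1)
	go func() {
		<-stuck // stands in for <-shell.done on a signal-trapping shell
		done <- struct{}{}
	}()
	select {
	case <-done:
	case <-time.After(time.Millisecond): // give up; the worker lives on
	}
}

func main() {
	before := runtime.NumGoroutine()
	for range 100 {
		leakyWait(make(chan struct{})) // channel is never closed
	}
	time.Sleep(100 * time.Millisecond)
	fmt.Println(runtime.NumGoroutine() - before) // ~100 orphaned goroutines
}

The new version makes each waiter select on ctx.Done() directly, so by the deadline every waiter has returned and wg.Wait() unblocks.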

internal/shell/background_test.go

@@ -6,6 +6,8 @@ import (
 	"strings"
 	"testing"
 	"time"
+
+	"github.com/stretchr/testify/require"
 )
 
 func TestBackgroundShellManager_Start(t *testing.T) {
@@ -248,7 +250,7 @@ func TestBackgroundShellManager_KillAll(t *testing.T) {
 	}
 
 	// Kill all shells
-	manager.KillAll()
+	manager.KillAll(context.Background())
 
 	// Verify all shells are done
 	if !shell1.IsDone() {
@@ -280,3 +282,25 @@ func TestBackgroundShellManager_KillAll(t *testing.T) {
 		}
 	}
 }
+
+func TestBackgroundShellManager_KillAll_Timeout(t *testing.T) {
+	t.Parallel()
+
+	workingDir := t.TempDir()
+	manager := newBackgroundShellManager()
+
+	// Start a shell that traps signals and ignores cancellation.
+	_, err := manager.Start(context.Background(), workingDir, nil, "trap '' TERM INT; sleep 60", "")
+	require.NoError(t, err)
+
+	// Short timeout to test the timeout path.
+	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+	defer cancel()
+
+	start := time.Now()
+	manager.KillAll(ctx)
+	elapsed := time.Since(start)
+
+	// Must return promptly after timeout, not hang for 60 seconds.
+	require.Less(t, elapsed, 2*time.Second)
+}
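
A possible follow-up, not part of this patch: assert the no-leak property directly with go.uber.org/goleak, which fails a test if goroutines started during it are still running when it ends. A sketch only, assuming goleak is added to the module's imports and that the manager leaves no long-lived goroutines of its own:

// Hypothetical extra test; goleak.VerifyNone fails the test if any
// goroutine outlives it.
func TestBackgroundShellManager_KillAll_NoLeak(t *testing.T) {
	defer goleak.VerifyNone(t)

	manager := newBackgroundShellManager()
	_, err := manager.Start(context.Background(), t.TempDir(), nil, "sleep 60", "")
	require.NoError(t, err)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// sleep 60 does not trap signals, so done fires well before the deadline.
	manager.KillAll(ctx)
}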