Detailed changes
@@ -124,26 +124,29 @@ func GetState(name string) (ClientInfo, bool) {
// Close closes all MCP clients. This should be called during application shutdown.
func Close() error {
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
var wg sync.WaitGroup
- done := make(chan struct{}, 1)
- go func() {
- for name, session := range sessions.Seq2() {
- wg.Go(func() {
- if err := session.Close(); err != nil &&
+ for name, session := range sessions.Seq2() {
+ wg.Go(func() {
+ done := make(chan error, 1)
+ go func() {
+ done <- session.Close()
+ }()
+ select {
+ case err := <-done:
+ if err != nil &&
!errors.Is(err, io.EOF) &&
!errors.Is(err, context.Canceled) &&
err.Error() != "signal: killed" {
slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
}
- })
- }
- wg.Wait()
- done <- struct{}{}
- }()
- select {
- case <-done:
- case <-time.After(5 * time.Second):
+ case <-ctx.Done():
+ }
+ })
}
+ wg.Wait()
broker.Shutdown()
return nil
}
@@ -540,14 +540,16 @@ func (app *App) Shutdown() {
// Now run remaining cleanup tasks in parallel.
var wg sync.WaitGroup
+ // Shared shutdown context for all timeout-bounded cleanup.
+ shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
+ defer cancel()
+
// Kill all background shells.
wg.Go(func() {
- shell.GetBackgroundShellManager().KillAll()
+ shell.GetBackgroundShellManager().KillAll(shutdownCtx)
})
// Shutdown all LSP clients.
- shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second)
- defer cancel()
wg.Go(func() {
app.LSPManager.StopAll(shutdownCtx)
})
@@ -191,29 +191,23 @@ func (m *BackgroundShellManager) Cleanup() int {
return len(toRemove)
}
-// KillAll terminates all background shells.
-func (m *BackgroundShellManager) KillAll() {
+// KillAll terminates all background shells. The provided context bounds how
+// long the function waits for each shell to exit.
+func (m *BackgroundShellManager) KillAll(ctx context.Context) {
shells := slices.Collect(m.shells.Seq())
m.shells.Reset(map[string]*BackgroundShell{})
- done := make(chan struct{}, 1)
- go func() {
- var wg sync.WaitGroup
- for _, shell := range shells {
- wg.Go(func() {
- shell.cancel()
- <-shell.done
- })
- }
- wg.Wait()
- done <- struct{}{}
- }()
- select {
- case <-done:
- return
- case <-time.After(time.Second * 5):
- return
+ var wg sync.WaitGroup
+ for _, shell := range shells {
+ wg.Go(func() {
+ shell.cancel()
+ select {
+ case <-shell.done:
+ case <-ctx.Done():
+ }
+ })
}
+ wg.Wait()
}
// GetOutput returns the current output of a background shell.
@@ -6,6 +6,8 @@ import (
"strings"
"testing"
"time"
+
+ "github.com/stretchr/testify/require"
)
func TestBackgroundShellManager_Start(t *testing.T) {
@@ -248,7 +250,7 @@ func TestBackgroundShellManager_KillAll(t *testing.T) {
}
// Kill all shells
- manager.KillAll()
+ manager.KillAll(context.Background())
// Verify all shells are done
if !shell1.IsDone() {
@@ -280,3 +282,25 @@ func TestBackgroundShellManager_KillAll(t *testing.T) {
}
}
}
+
+func TestBackgroundShellManager_KillAll_Timeout(t *testing.T) {
+ t.Parallel()
+
+ workingDir := t.TempDir()
+ manager := newBackgroundShellManager()
+
+ // Start a shell that traps signals and ignores cancellation.
+ _, err := manager.Start(context.Background(), workingDir, nil, "trap '' TERM INT; sleep 60", "")
+ require.NoError(t, err)
+
+ // Short timeout to test the timeout path.
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+
+ start := time.Now()
+ manager.KillAll(ctx)
+ elapsed := time.Since(start)
+
+ // Must return promptly after timeout, not hang for 60 seconds.
+ require.Less(t, elapsed, 2*time.Second)
+}