diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 7fdd2cd0a6477fce7a9dea85473e87e83d8e1a35..7ddc2a7ee44eaad15b1177f98141161359bb6b4c 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -124,26 +124,29 @@ func GetState(name string) (ClientInfo, bool) { // Close closes all MCP clients. This should be called during application shutdown. func Close() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var wg sync.WaitGroup - done := make(chan struct{}, 1) - go func() { - for name, session := range sessions.Seq2() { - wg.Go(func() { - if err := session.Close(); err != nil && + for name, session := range sessions.Seq2() { + wg.Go(func() { + done := make(chan error, 1) + go func() { + done <- session.Close() + }() + select { + case err := <-done: + if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, context.Canceled) && err.Error() != "signal: killed" { slog.Warn("Failed to shutdown MCP client", "name", name, "error", err) } - }) - } - wg.Wait() - done <- struct{}{} - }() - select { - case <-done: - case <-time.After(5 * time.Second): + case <-ctx.Done(): + } + }) } + wg.Wait() broker.Shutdown() return nil } diff --git a/internal/app/app.go b/internal/app/app.go index 7e16e294c17553a030d412ecde4ad95a90d53ecc..9993e3ee80732a47ad98aa00d23e98a5438bb2b8 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -540,14 +540,16 @@ func (app *App) Shutdown() { // Now run remaining cleanup tasks in parallel. var wg sync.WaitGroup + // Shared shutdown context for all timeout-bounded cleanup. + shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second) + defer cancel() + // Kill all background shells. wg.Go(func() { - shell.GetBackgroundShellManager().KillAll() + shell.GetBackgroundShellManager().KillAll(shutdownCtx) }) // Shutdown all LSP clients. - shutdownCtx, cancel := context.WithTimeout(app.globalCtx, 5*time.Second) - defer cancel() wg.Go(func() { app.LSPManager.StopAll(shutdownCtx) }) diff --git a/internal/shell/background.go b/internal/shell/background.go index cb1855836f64bdd56a90802c2bbb939a5a514100..c6a0f81e2c4c0b9de19a599b07f58cf7225d32a2 100644 --- a/internal/shell/background.go +++ b/internal/shell/background.go @@ -191,29 +191,23 @@ func (m *BackgroundShellManager) Cleanup() int { return len(toRemove) } -// KillAll terminates all background shells. -func (m *BackgroundShellManager) KillAll() { +// KillAll terminates all background shells. The provided context bounds how +// long the function waits for each shell to exit. +func (m *BackgroundShellManager) KillAll(ctx context.Context) { shells := slices.Collect(m.shells.Seq()) m.shells.Reset(map[string]*BackgroundShell{}) - done := make(chan struct{}, 1) - go func() { - var wg sync.WaitGroup - for _, shell := range shells { - wg.Go(func() { - shell.cancel() - <-shell.done - }) - } - wg.Wait() - done <- struct{}{} - }() - select { - case <-done: - return - case <-time.After(time.Second * 5): - return + var wg sync.WaitGroup + for _, shell := range shells { + wg.Go(func() { + shell.cancel() + select { + case <-shell.done: + case <-ctx.Done(): + } + }) } + wg.Wait() } // GetOutput returns the current output of a background shell. diff --git a/internal/shell/background_test.go b/internal/shell/background_test.go index 7c521bc1477b07775cffb69f310fa83d710d4634..f3a8cb9f7db442be67fc1ac7f2898fd6d1d2a87e 100644 --- a/internal/shell/background_test.go +++ b/internal/shell/background_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestBackgroundShellManager_Start(t *testing.T) { @@ -248,7 +250,7 @@ func TestBackgroundShellManager_KillAll(t *testing.T) { } // Kill all shells - manager.KillAll() + manager.KillAll(context.Background()) // Verify all shells are done if !shell1.IsDone() { @@ -280,3 +282,25 @@ func TestBackgroundShellManager_KillAll(t *testing.T) { } } } + +func TestBackgroundShellManager_KillAll_Timeout(t *testing.T) { + t.Parallel() + + workingDir := t.TempDir() + manager := newBackgroundShellManager() + + // Start a shell that traps signals and ignores cancellation. + _, err := manager.Start(context.Background(), workingDir, nil, "trap '' TERM INT; sleep 60", "") + require.NoError(t, err) + + // Short timeout to test the timeout path. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + manager.KillAll(ctx) + elapsed := time.Since(start) + + // Must return promptly after timeout, not hang for 60 seconds. + require.Less(t, elapsed, 2*time.Second) +}