From 5fcea4422b2dc0356dfe19914fdef018128694f0 Mon Sep 17 00:00:00 2001 From: Christian Rocha Date: Thu, 19 Mar 2026 21:24:08 -0400 Subject: [PATCH] fix(mcp): safely teardown partially initialized mcp sessions --- internal/agent/tools/mcp/init.go | 86 ++++++++++++++++++------- internal/agent/tools/mcp/init_test.go | 92 +++++++++++++++++++++++++++ internal/ui/model/ui.go | 4 +- 3 files changed, 156 insertions(+), 26 deletions(-) diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index 39bd05254109ec25cf68bda288c89aca25d72168..007579fe3896f8a2869798f85e0b5885f9817893 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -14,6 +14,7 @@ import ( "os/exec" "strings" "sync" + "syscall" "time" "github.com/charmbracelet/crush/internal/config" @@ -136,6 +137,26 @@ func GetState(name string) (ClientInfo, bool) { return states.Get(name) } +// isIgnorableCloseErr returns true for errors that are expected during MCP +// session shutdown and can be safely suppressed. +func isIgnorableCloseErr(err error) bool { + return err == nil || + errors.Is(err, io.EOF) || + errors.Is(err, context.Canceled) || + isKilledErr(err) +} + +// isKilledErr returns true if the error is an exec.ExitError caused by +// SIGKILL. +func isKilledErr(err error) bool { + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + return false + } + ws, ok := exitErr.Sys().(syscall.WaitStatus) + return ok && ws.Signaled() && ws.Signal() == syscall.SIGKILL +} + // Close closes all MCP clients. This should be called during application shutdown. func Close(ctx context.Context) error { var wg sync.WaitGroup @@ -147,10 +168,7 @@ func Close(ctx context.Context) error { }() select { case err := <-done: - if err != nil && - !errors.Is(err, io.EOF) && - !errors.Is(err, context.Canceled) && - err.Error() != "signal: killed" { + if !isIgnorableCloseErr(err) { slog.Warn("Failed to shutdown MCP client", "name", name, "error", err) } case <-ctx.Done(): @@ -195,7 +213,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config }() if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil { - slog.Debug("failed to initialize mcp client", "name", name, "error", err) + slog.Debug("Failed to initialize MCP client", "name", name, "error", err) } }(name, m) } @@ -223,7 +241,7 @@ func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) if m.Disabled { updateState(name, StateDisabled, nil, nil, Counts{}) - slog.Debug("skipping disabled mcp", "name", name) + slog.Debug("Skipping disabled MCP", "name", name) return nil } @@ -232,10 +250,8 @@ func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) // initClient initializes a single MCP client with the given configuration. func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error { - // Set initial starting state. updateState(name, StateStarting, nil, nil, Counts{}) - // createSession handles its own timeout internally. session, err := createSession(ctx, name, m, resolver) if err != nil { return err @@ -243,25 +259,25 @@ func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m con tools, err := getTools(ctx, session) if err != nil { - slog.Error("Error listing tools", "error", err) + slog.Error("Error listing tools", "name", name, "error", err) updateState(name, StateError, err, nil, Counts{}) - session.Close() + closeSessionOnInitError(name, session) return err } prompts, err := getPrompts(ctx, session) if err != nil { - slog.Error("Error listing prompts", "error", err) + slog.Error("Error listing prompts", "name", name, "error", err) updateState(name, StateError, err, nil, Counts{}) - session.Close() + closeSessionOnInitError(name, session) return err } resources, err := getResources(ctx, session) if err != nil { - slog.Error("Error listing resources", "error", err) + slog.Error("Error listing resources", "name", name, "error", err) updateState(name, StateError, err, nil, Counts{}) - session.Close() + closeSessionOnInitError(name, session) return err } @@ -279,28 +295,52 @@ func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m con return nil } +// closeSessionOnInitError closes a session that failed during initialization, +// suppressing expected shutdown errors. Uses a fixed timeout to avoid blocking +// indefinitely if the parent context has no deadline. +// +// On timeout the Close goroutine may outlive this function, but since +// session.Close cancels the session context internally, it will unblock +// shortly after. +func closeSessionOnInitError(name string, session *ClientSession) { + const closeTimeout = 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), closeTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- session.Close() + }() + + select { + case err := <-done: + if !isIgnorableCloseErr(err) { + slog.Warn("Failed to close MCP session after init error", "name", name, "error", err) + } + case <-ctx.Done(): + slog.Warn("Timed out waiting to close MCP session after init error", "name", name, "error", ctx.Err()) + } +} + // DisableSingle disables and closes a single MCP client by name. -func DisableSingle(cfg *config.ConfigStore, name string) error { +func DisableSingle(cfg *config.ConfigStore, name string) { session, ok := sessions.Get(name) if ok { - if err := session.Close(); err != nil && - !errors.Is(err, io.EOF) && - !errors.Is(err, context.Canceled) && - err.Error() != "signal: killed" { - slog.Warn("error closing mcp session", "name", name, "error", err) + if err := session.Close(); !isIgnorableCloseErr(err) { + slog.Warn("Error closing MCP session", "name", name, "error", err) } sessions.Del(name) } - // Clear tools and prompts for this MCP. + // Clear tools, prompts, and resources for this MCP. updateTools(cfg, name, nil) updatePrompts(name, nil) + updateResources(name, nil) // Update state to disabled. updateState(name, StateDisabled, nil, nil, Counts{}) - slog.Info("Disabled mcp client", "name", name) - return nil + slog.Info("Disabled MCP client", "name", name) } func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) { diff --git a/internal/agent/tools/mcp/init_test.go b/internal/agent/tools/mcp/init_test.go index 94958593750852d30ff96734ada23671252e508e..07e067545dc31f98a12690deb9338548367763db 100644 --- a/internal/agent/tools/mcp/init_test.go +++ b/internal/agent/tools/mcp/init_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/charmbracelet/crush/internal/config" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -36,3 +37,94 @@ func TestMCPSession_CancelOnClose(t *testing.T) { // After Close, the context must be cancelled. require.ErrorIs(t, ctx.Err(), context.Canceled) } + +func TestInitClient_PopulatesResources(t *testing.T) { + defer goleak.VerifyNone(t, + goleak.IgnoreAnyFunction("net/http.(*http2Transport).newClientConn"), + goleak.IgnoreAnyFunction("internal/poll.runtime_pollWait"), + goleak.IgnoreAnyFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"), + ) + + const name = "test-resources" + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil) + server.AddResource( + &mcp.Resource{URI: "file:///readme.md", Name: "readme"}, + func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{URI: "file:///readme.md"}}, + }, nil + }, + ) + server.AddResource( + &mcp.Resource{URI: "file:///license", Name: "license"}, + func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { + return &mcp.ReadResourceResult{ + Contents: []*mcp.ResourceContents{{URI: "file:///license"}}, + }, nil + }, + ) + + serverSession, err := server.Connect(context.Background(), serverTransport, nil) + require.NoError(t, err) + defer serverSession.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil) + clientSession, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + session := &ClientSession{clientSession, cancel} + + cfg, err := config.Init(t.TempDir(), "", false) + require.NoError(t, err) + + // Clean up any prior state for this name. + t.Cleanup(func() { + allTools.Del(name) + allPrompts.Del(name) + allResources.Del(name) + sessions.Del(name) + states.Del(name) + }) + + toolCount := updateTools(cfg, name, nil) + updatePrompts(name, nil) + resourceCount := updateResources(name, nil) + require.Equal(t, 0, toolCount) + require.Equal(t, 0, resourceCount) + + // Simulate what initClient does after creating a session. + tools, err := getTools(ctx, session) + require.NoError(t, err) + + prompts, err := getPrompts(ctx, session) + require.NoError(t, err) + + resources, err := getResources(ctx, session) + require.NoError(t, err) + require.Len(t, resources, 2) + + toolCount = updateTools(cfg, name, tools) + updatePrompts(name, prompts) + resourceCount = updateResources(name, resources) + sessions.Set(name, session) + + updateState(name, StateConnected, nil, session, Counts{ + Tools: toolCount, + Prompts: len(prompts), + Resources: resourceCount, + }) + + // Verify resources are stored and counts are correct. + storedResources, ok := allResources.Get(name) + require.True(t, ok) + require.Len(t, storedResources, 2) + + state, ok := states.Get(name) + require.True(t, ok) + require.Equal(t, StateConnected, state.State) + require.Equal(t, 2, state.Counts.Resources) +} diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 7e7c36348337bf0d8299fa9d9eba7e52176be284..02d2fd5c5193fd73968ffdb4a29f26bedc7b3532 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -3481,9 +3481,7 @@ func (m *UI) enableDockerMCP() tea.Msg { func (m *UI) disableDockerMCP() tea.Msg { store := m.com.Store() // Close the Docker MCP client. - if err := mcp.DisableSingle(store, config.DockerMCPName); err != nil { - return util.ReportError(fmt.Errorf("failed to disable docker MCP: %w", err))() - } + mcp.DisableSingle(store, config.DockerMCPName) // Remove from config and persist. if err := store.DisableDockerMCP(); err != nil {