fix(mcp): safely teardown partially initialized mcp sessions

Christian Rocha created

Change summary

internal/agent/tools/mcp/init.go      | 86 +++++++++++++++++++-------
internal/agent/tools/mcp/init_test.go | 92 +++++++++++++++++++++++++++++
internal/ui/model/ui.go               |  4 
3 files changed, 156 insertions(+), 26 deletions(-)

Detailed changes

internal/agent/tools/mcp/init.go 🔗

@@ -14,6 +14,7 @@ import (
 	"os/exec"
 	"strings"
 	"sync"
+	"syscall"
 	"time"
 
 	"github.com/charmbracelet/crush/internal/config"
@@ -136,6 +137,26 @@ func GetState(name string) (ClientInfo, bool) {
 	return states.Get(name)
 }
 
+// isIgnorableCloseErr returns true for errors that are expected during MCP
+// session shutdown and can be safely suppressed.
+func isIgnorableCloseErr(err error) bool {
+	return err == nil ||
+		errors.Is(err, io.EOF) ||
+		errors.Is(err, context.Canceled) ||
+		isKilledErr(err)
+}
+
+// isKilledErr returns true if the error is an exec.ExitError caused by
+// SIGKILL.
+func isKilledErr(err error) bool {
+	var exitErr *exec.ExitError
+	if !errors.As(err, &exitErr) {
+		return false
+	}
+	ws, ok := exitErr.Sys().(syscall.WaitStatus)
+	return ok && ws.Signaled() && ws.Signal() == syscall.SIGKILL
+}
+
 // Close closes all MCP clients. This should be called during application shutdown.
 func Close(ctx context.Context) error {
 	var wg sync.WaitGroup
@@ -147,10 +168,7 @@ func Close(ctx context.Context) error {
 			}()
 			select {
 			case err := <-done:
-				if err != nil &&
-					!errors.Is(err, io.EOF) &&
-					!errors.Is(err, context.Canceled) &&
-					err.Error() != "signal: killed" {
+				if !isIgnorableCloseErr(err) {
 					slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
 				}
 			case <-ctx.Done():
@@ -195,7 +213,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
 			}()
 
 			if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
-				slog.Debug("failed to initialize mcp client", "name", name, "error", err)
+				slog.Debug("Failed to initialize MCP client", "name", name, "error", err)
 			}
 		}(name, m)
 	}
@@ -223,7 +241,7 @@ func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore)
 
 	if m.Disabled {
 		updateState(name, StateDisabled, nil, nil, Counts{})
-		slog.Debug("skipping disabled mcp", "name", name)
+		slog.Debug("Skipping disabled MCP", "name", name)
 		return nil
 	}
 
@@ -232,10 +250,8 @@ func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore)
 
 // initClient initializes a single MCP client with the given configuration.
 func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
-	// Set initial starting state.
 	updateState(name, StateStarting, nil, nil, Counts{})
 
-	// createSession handles its own timeout internally.
 	session, err := createSession(ctx, name, m, resolver)
 	if err != nil {
 		return err
@@ -243,25 +259,25 @@ func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m con
 
 	tools, err := getTools(ctx, session)
 	if err != nil {
-		slog.Error("Error listing tools", "error", err)
+		slog.Error("Error listing tools", "name", name, "error", err)
 		updateState(name, StateError, err, nil, Counts{})
-		session.Close()
+		closeSessionOnInitError(name, session)
 		return err
 	}
 
 	prompts, err := getPrompts(ctx, session)
 	if err != nil {
-		slog.Error("Error listing prompts", "error", err)
+		slog.Error("Error listing prompts", "name", name, "error", err)
 		updateState(name, StateError, err, nil, Counts{})
-		session.Close()
+		closeSessionOnInitError(name, session)
 		return err
 	}
 
 	resources, err := getResources(ctx, session)
 	if err != nil {
-		slog.Error("Error listing resources", "error", err)
+		slog.Error("Error listing resources", "name", name, "error", err)
 		updateState(name, StateError, err, nil, Counts{})
-		session.Close()
+		closeSessionOnInitError(name, session)
 		return err
 	}
 
@@ -279,28 +295,52 @@ func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m con
 	return nil
 }
 
+// closeSessionOnInitError closes a session that failed during initialization,
+// suppressing expected shutdown errors. Uses a fixed timeout to avoid blocking
+// indefinitely if the parent context has no deadline.
+//
+// On timeout the Close goroutine may outlive this function, but since
+// session.Close cancels the session context internally, it will unblock
+// shortly after.
+func closeSessionOnInitError(name string, session *ClientSession) {
+	const closeTimeout = 5 * time.Second
+	ctx, cancel := context.WithTimeout(context.Background(), closeTimeout)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		done <- session.Close()
+	}()
+
+	select {
+	case err := <-done:
+		if !isIgnorableCloseErr(err) {
+			slog.Warn("Failed to close MCP session after init error", "name", name, "error", err)
+		}
+	case <-ctx.Done():
+		slog.Warn("Timed out waiting to close MCP session after init error", "name", name, "error", ctx.Err())
+	}
+}
+
 // DisableSingle disables and closes a single MCP client by name.
-func DisableSingle(cfg *config.ConfigStore, name string) error {
+func DisableSingle(cfg *config.ConfigStore, name string) {
 	session, ok := sessions.Get(name)
 	if ok {
-		if err := session.Close(); err != nil &&
-			!errors.Is(err, io.EOF) &&
-			!errors.Is(err, context.Canceled) &&
-			err.Error() != "signal: killed" {
-			slog.Warn("error closing mcp session", "name", name, "error", err)
+		if err := session.Close(); !isIgnorableCloseErr(err) {
+			slog.Warn("Error closing MCP session", "name", name, "error", err)
 		}
 		sessions.Del(name)
 	}
 
-	// Clear tools and prompts for this MCP.
+	// Clear tools, prompts, and resources for this MCP.
 	updateTools(cfg, name, nil)
 	updatePrompts(name, nil)
+	updateResources(name, nil)
 
 	// Update state to disabled.
 	updateState(name, StateDisabled, nil, nil, Counts{})
 
-	slog.Info("Disabled mcp client", "name", name)
-	return nil
+	slog.Info("Disabled MCP client", "name", name)
 }
 
 func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {

internal/agent/tools/mcp/init_test.go 🔗

@@ -4,6 +4,7 @@ import (
 	"context"
 	"testing"
 
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/modelcontextprotocol/go-sdk/mcp"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/goleak"
@@ -36,3 +37,94 @@ func TestMCPSession_CancelOnClose(t *testing.T) {
 	// After Close, the context must be cancelled.
 	require.ErrorIs(t, ctx.Err(), context.Canceled)
 }
+
+func TestInitClient_PopulatesResources(t *testing.T) {
+	defer goleak.VerifyNone(t,
+		goleak.IgnoreAnyFunction("net/http.(*http2Transport).newClientConn"),
+		goleak.IgnoreAnyFunction("internal/poll.runtime_pollWait"),
+		goleak.IgnoreAnyFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
+	)
+
+	const name = "test-resources"
+
+	serverTransport, clientTransport := mcp.NewInMemoryTransports()
+	server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
+	server.AddResource(
+		&mcp.Resource{URI: "file:///readme.md", Name: "readme"},
+		func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
+			return &mcp.ReadResourceResult{
+				Contents: []*mcp.ResourceContents{{URI: "file:///readme.md"}},
+			}, nil
+		},
+	)
+	server.AddResource(
+		&mcp.Resource{URI: "file:///license", Name: "license"},
+		func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
+			return &mcp.ReadResourceResult{
+				Contents: []*mcp.ResourceContents{{URI: "file:///license"}},
+			}, nil
+		},
+	)
+
+	serverSession, err := server.Connect(context.Background(), serverTransport, nil)
+	require.NoError(t, err)
+	defer serverSession.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
+	clientSession, err := client.Connect(ctx, clientTransport, nil)
+	require.NoError(t, err)
+	session := &ClientSession{clientSession, cancel}
+
+	cfg, err := config.Init(t.TempDir(), "", false)
+	require.NoError(t, err)
+
+	// Clean up any prior state for this name.
+	t.Cleanup(func() {
+		allTools.Del(name)
+		allPrompts.Del(name)
+		allResources.Del(name)
+		sessions.Del(name)
+		states.Del(name)
+	})
+
+	toolCount := updateTools(cfg, name, nil)
+	updatePrompts(name, nil)
+	resourceCount := updateResources(name, nil)
+	require.Equal(t, 0, toolCount)
+	require.Equal(t, 0, resourceCount)
+
+	// Simulate what initClient does after creating a session.
+	tools, err := getTools(ctx, session)
+	require.NoError(t, err)
+
+	prompts, err := getPrompts(ctx, session)
+	require.NoError(t, err)
+
+	resources, err := getResources(ctx, session)
+	require.NoError(t, err)
+	require.Len(t, resources, 2)
+
+	toolCount = updateTools(cfg, name, tools)
+	updatePrompts(name, prompts)
+	resourceCount = updateResources(name, resources)
+	sessions.Set(name, session)
+
+	updateState(name, StateConnected, nil, session, Counts{
+		Tools:     toolCount,
+		Prompts:   len(prompts),
+		Resources: resourceCount,
+	})
+
+	// Verify resources are stored and counts are correct.
+	storedResources, ok := allResources.Get(name)
+	require.True(t, ok)
+	require.Len(t, storedResources, 2)
+
+	state, ok := states.Get(name)
+	require.True(t, ok)
+	require.Equal(t, StateConnected, state.State)
+	require.Equal(t, 2, state.Counts.Resources)
+}

internal/ui/model/ui.go 🔗

@@ -3481,9 +3481,7 @@ func (m *UI) enableDockerMCP() tea.Msg {
 func (m *UI) disableDockerMCP() tea.Msg {
 	store := m.com.Store()
 	// Close the Docker MCP client.
-	if err := mcp.DisableSingle(store, config.DockerMCPName); err != nil {
-		return util.ReportError(fmt.Errorf("failed to disable docker MCP: %w", err))()
-	}
+	mcp.DisableSingle(store, config.DockerMCPName)
 
 	// Remove from config and persist.
 	if err := store.DisableDockerMCP(); err != nil {