fix(mcp): cancel context on MCP session close to prevent leak (#2157)

M1xA created

* fix(mcp): cancel context on MCP session close to prevent leak

* refactor: rename unexported mcpSession to exported ClientSession

Change summary

internal/agent/tools/mcp/init.go      | 26 +++++++++++++++----
internal/agent/tools/mcp/init_test.go | 38 +++++++++++++++++++++++++++++
internal/agent/tools/mcp/prompts.go   |  2 
internal/agent/tools/mcp/resources.go |  2 
internal/agent/tools/mcp/tools.go     |  2 
5 files changed, 61 insertions(+), 9 deletions(-)

Detailed changes

internal/agent/tools/mcp/init.go 🔗

@@ -38,8 +38,22 @@ func parseLevel(level mcp.LoggingLevel) slog.Level {
 	}
 }
 
+// ClientSession wraps an mcp.ClientSession with a context cancel function so
+// that the context created during session establishment is properly cleaned up
+// on close.
+type ClientSession struct {
+	*mcp.ClientSession
+	cancel context.CancelFunc
+}
+
+// Close cancels the session context and then closes the underlying session.
+func (s *ClientSession) Close() error {
+	s.cancel()
+	return s.ClientSession.Close()
+}
+
 var (
-	sessions = csync.NewMap[string, *mcp.ClientSession]()
+	sessions = csync.NewMap[string, *ClientSession]()
 	states   = csync.NewMap[string, ClientInfo]()
 	broker   = pubsub.NewBroker[Event]()
 	initOnce sync.Once
@@ -102,7 +116,7 @@ type ClientInfo struct {
 	Name        string
 	State       State
 	Error       error
-	Client      *mcp.ClientSession
+	Client      *ClientSession
 	Counts      Counts
 	ConnectedAt time.Time
 }
@@ -239,7 +253,7 @@ func WaitForInit(ctx context.Context) error {
 	}
 }
 
-func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mcp.ClientSession, error) {
+func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) {
 	sess, ok := sessions.Get(name)
 	if !ok {
 		return nil, fmt.Errorf("mcp '%s' not available", name)
@@ -268,7 +282,7 @@ func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mc
 }
 
 // updateState updates the state of an MCP client and publishes an event
-func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
+func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
 	info := ClientInfo{
 		Name:   name,
 		State:  state,
@@ -294,7 +308,7 @@ func updateState(name string, state State, err error, client *mcp.ClientSession,
 	})
 }
 
-func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
+func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
 	timeout := mcpTimeout(m)
 	mcpCtx, cancel := context.WithCancel(ctx)
 	cancelTimer := time.AfterFunc(timeout, cancel)
@@ -352,7 +366,7 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve
 
 	cancelTimer.Stop()
 	slog.Debug("MCP client initialized", "name", name)
-	return session, nil
+	return &ClientSession{session, cancel}, nil
 }
 
 // maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail

internal/agent/tools/mcp/init_test.go 🔗

@@ -0,0 +1,38 @@
+package mcp
+
+import (
+	"context"
+	"testing"
+
+	"github.com/modelcontextprotocol/go-sdk/mcp"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/goleak"
+)
+
+func TestMCPSession_CancelOnClose(t *testing.T) {
+	defer goleak.VerifyNone(t)
+
+	serverTransport, clientTransport := mcp.NewInMemoryTransports()
+
+	server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
+	serverSession, err := server.Connect(context.Background(), serverTransport, nil)
+	require.NoError(t, err)
+	defer serverSession.Close()
+
+	ctx, cancel := context.WithCancel(context.Background())
+
+	client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
+	clientSession, err := client.Connect(ctx, clientTransport, nil)
+	require.NoError(t, err)
+
+	sess := &ClientSession{clientSession, cancel}
+
+	// Verify the context is not cancelled before close.
+	require.NoError(t, ctx.Err())
+
+	err = sess.Close()
+	require.NoError(t, err)
+
+	// After Close, the context must be cancelled.
+	require.ErrorIs(t, ctx.Err(), context.Canceled)
+}

internal/agent/tools/mcp/prompts.go 🔗

@@ -67,7 +67,7 @@ func RefreshPrompts(ctx context.Context, name string) {
 	updateState(name, StateConnected, nil, session, prev.Counts)
 }
 
-func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
+func getPrompts(ctx context.Context, c *ClientSession) ([]*Prompt, error) {
 	if c.InitializeResult().Capabilities.Prompts == nil {
 		return nil, nil
 	}

internal/agent/tools/mcp/resources.go 🔗

@@ -75,7 +75,7 @@ func RefreshResources(ctx context.Context, name string) {
 	updateState(name, StateConnected, nil, session, prev.Counts)
 }
 
-func getResources(ctx context.Context, c *mcp.ClientSession) ([]*Resource, error) {
+func getResources(ctx context.Context, c *ClientSession) ([]*Resource, error) {
 	if c.InitializeResult().Capabilities.Resources == nil {
 		return nil, nil
 	}

internal/agent/tools/mcp/tools.go 🔗

@@ -128,7 +128,7 @@ func RefreshTools(ctx context.Context, cfg *config.Config, name string) {
 	updateState(name, StateConnected, nil, session, prev.Counts)
 }
 
-func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
+func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
 	// Always call ListTools to get the actual available tools.
 	// The InitializeResult Capabilities.Tools field may be an empty object {},
 	// which is valid per MCP spec, but we still need to call ListTools to discover tools.