From d391ea85ed2f43d8e1290a2c64269c0782a6185a Mon Sep 17 00:00:00 2001 From: M1xA Date: Tue, 10 Feb 2026 14:56:25 +0200 Subject: [PATCH] fix(mcp): cancel context on MCP session close to prevent leak (#2157) * fix(mcp): cancel context on MCP session close to prevent leak * refactor: rename unexported mcpSession to exported ClientSession --- internal/agent/tools/mcp/init.go | 26 +++++++++++++----- internal/agent/tools/mcp/init_test.go | 38 +++++++++++++++++++++++++++ internal/agent/tools/mcp/prompts.go | 2 +- internal/agent/tools/mcp/resources.go | 2 +- internal/agent/tools/mcp/tools.go | 2 +- 5 files changed, 61 insertions(+), 9 deletions(-) create mode 100644 internal/agent/tools/mcp/init_test.go diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go index e8397915f434072387d92fd59c8842a278709426..f8cfe0ce84bf7b1987496607d42753b8ca72263f 100644 --- a/internal/agent/tools/mcp/init.go +++ b/internal/agent/tools/mcp/init.go @@ -38,8 +38,22 @@ func parseLevel(level mcp.LoggingLevel) slog.Level { } } +// ClientSession wraps an mcp.ClientSession with a context cancel function so +// that the context created during session establishment is properly cleaned up +// on close. +type ClientSession struct { + *mcp.ClientSession + cancel context.CancelFunc +} + +// Close cancels the session context and then closes the underlying session. +func (s *ClientSession) Close() error { + s.cancel() + return s.ClientSession.Close() +} + var ( - sessions = csync.NewMap[string, *mcp.ClientSession]() + sessions = csync.NewMap[string, *ClientSession]() states = csync.NewMap[string, ClientInfo]() broker = pubsub.NewBroker[Event]() initOnce sync.Once @@ -102,7 +116,7 @@ type ClientInfo struct { Name string State State Error error - Client *mcp.ClientSession + Client *ClientSession Counts Counts ConnectedAt time.Time } @@ -239,7 +253,7 @@ func WaitForInit(ctx context.Context) error { } } -func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mcp.ClientSession, error) { +func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) { sess, ok := sessions.Get(name) if !ok { return nil, fmt.Errorf("mcp '%s' not available", name) @@ -268,7 +282,7 @@ func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mc } // updateState updates the state of an MCP client and publishes an event -func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) { +func updateState(name string, state State, err error, client *ClientSession, counts Counts) { info := ClientInfo{ Name: name, State: state, @@ -294,7 +308,7 @@ func updateState(name string, state State, err error, client *mcp.ClientSession, }) } -func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) { +func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) { timeout := mcpTimeout(m) mcpCtx, cancel := context.WithCancel(ctx) cancelTimer := time.AfterFunc(timeout, cancel) @@ -352,7 +366,7 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve cancelTimer.Stop() slog.Debug("MCP client initialized", "name", name) - return session, nil + return &ClientSession{session, cancel}, nil } // maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail diff --git a/internal/agent/tools/mcp/init_test.go b/internal/agent/tools/mcp/init_test.go new file mode 100644 index 0000000000000000000000000000000000000000..94958593750852d30ff96734ada23671252e508e --- /dev/null +++ b/internal/agent/tools/mcp/init_test.go @@ -0,0 +1,38 @@ +package mcp + +import ( + "context" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" +) + +func TestMCPSession_CancelOnClose(t *testing.T) { + defer goleak.VerifyNone(t) + + serverTransport, clientTransport := mcp.NewInMemoryTransports() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil) + serverSession, err := server.Connect(context.Background(), serverTransport, nil) + require.NoError(t, err) + defer serverSession.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil) + clientSession, err := client.Connect(ctx, clientTransport, nil) + require.NoError(t, err) + + sess := &ClientSession{clientSession, cancel} + + // Verify the context is not cancelled before close. + require.NoError(t, ctx.Err()) + + err = sess.Close() + require.NoError(t, err) + + // After Close, the context must be cancelled. + require.ErrorIs(t, ctx.Err(), context.Canceled) +} diff --git a/internal/agent/tools/mcp/prompts.go b/internal/agent/tools/mcp/prompts.go index 76338b4a8e349c9177ecaa216be217e241ec402d..2b39d5dc2db43aff418c3dd7561edbcebd6af865 100644 --- a/internal/agent/tools/mcp/prompts.go +++ b/internal/agent/tools/mcp/prompts.go @@ -67,7 +67,7 @@ func RefreshPrompts(ctx context.Context, name string) { updateState(name, StateConnected, nil, session, prev.Counts) } -func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) { +func getPrompts(ctx context.Context, c *ClientSession) ([]*Prompt, error) { if c.InitializeResult().Capabilities.Prompts == nil { return nil, nil } diff --git a/internal/agent/tools/mcp/resources.go b/internal/agent/tools/mcp/resources.go index 92f6c83836181a8441d35431f900f5c68334a9eb..912651f0eb4d5c8cf3999cc1fb7f6027cd9bcd52 100644 --- a/internal/agent/tools/mcp/resources.go +++ b/internal/agent/tools/mcp/resources.go @@ -75,7 +75,7 @@ func RefreshResources(ctx context.Context, name string) { updateState(name, StateConnected, nil, session, prev.Counts) } -func getResources(ctx context.Context, c *mcp.ClientSession) ([]*Resource, error) { +func getResources(ctx context.Context, c *ClientSession) ([]*Resource, error) { if c.InitializeResult().Capabilities.Resources == nil { return nil, nil } diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go index da4b463bbc850ea8bfa0c3400defecf05507951d..b6e208f7ccb3363bee0a0b60ef56c103ad9cd41b 100644 --- a/internal/agent/tools/mcp/tools.go +++ b/internal/agent/tools/mcp/tools.go @@ -128,7 +128,7 @@ func RefreshTools(ctx context.Context, cfg *config.Config, name string) { updateState(name, StateConnected, nil, session, prev.Counts) } -func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) { +func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) { // Always call ListTools to get the actual available tools. // The InitializeResult Capabilities.Tools field may be an empty object {}, // which is valid per MCP spec, but we still need to call ListTools to discover tools.