Detailed changes
@@ -38,8 +38,22 @@ func parseLevel(level mcp.LoggingLevel) slog.Level {
}
}
+// ClientSession wraps an mcp.ClientSession with a context cancel function so
+// that the context created during session establishment is properly cleaned up
+// on close.
+type ClientSession struct {
+ *mcp.ClientSession
+ cancel context.CancelFunc
+}
+
+// Close cancels the session context and then closes the underlying session.
+func (s *ClientSession) Close() error {
+ s.cancel()
+ return s.ClientSession.Close()
+}
+
var (
- sessions = csync.NewMap[string, *mcp.ClientSession]()
+ sessions = csync.NewMap[string, *ClientSession]()
states = csync.NewMap[string, ClientInfo]()
broker = pubsub.NewBroker[Event]()
initOnce sync.Once
@@ -102,7 +116,7 @@ type ClientInfo struct {
Name string
State State
Error error
- Client *mcp.ClientSession
+ Client *ClientSession
Counts Counts
ConnectedAt time.Time
}
@@ -239,7 +253,7 @@ func WaitForInit(ctx context.Context) error {
}
}
-func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mcp.ClientSession, error) {
+func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) {
sess, ok := sessions.Get(name)
if !ok {
return nil, fmt.Errorf("mcp '%s' not available", name)
@@ -268,7 +282,7 @@ func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mc
}
// updateState updates the state of an MCP client and publishes an event
-func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
+func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
info := ClientInfo{
Name: name,
State: state,
@@ -294,7 +308,7 @@ func updateState(name string, state State, err error, client *mcp.ClientSession,
})
}
-func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
+func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
timeout := mcpTimeout(m)
mcpCtx, cancel := context.WithCancel(ctx)
cancelTimer := time.AfterFunc(timeout, cancel)
@@ -352,7 +366,7 @@ func createSession(ctx context.Context, name string, m config.MCPConfig, resolve
cancelTimer.Stop()
slog.Debug("MCP client initialized", "name", name)
- return session, nil
+ return &ClientSession{session, cancel}, nil
}
// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
@@ -0,0 +1,38 @@
+package mcp
+
+import (
+ "context"
+ "testing"
+
+ "github.com/modelcontextprotocol/go-sdk/mcp"
+ "github.com/stretchr/testify/require"
+ "go.uber.org/goleak"
+)
+
+func TestMCPSession_CancelOnClose(t *testing.T) {
+ defer goleak.VerifyNone(t)
+
+ serverTransport, clientTransport := mcp.NewInMemoryTransports()
+
+ server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
+ serverSession, err := server.Connect(context.Background(), serverTransport, nil)
+ require.NoError(t, err)
+ defer serverSession.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
+ clientSession, err := client.Connect(ctx, clientTransport, nil)
+ require.NoError(t, err)
+
+ sess := &ClientSession{clientSession, cancel}
+
+ // Verify the context is not cancelled before close.
+ require.NoError(t, ctx.Err())
+
+ err = sess.Close()
+ require.NoError(t, err)
+
+ // After Close, the context must be cancelled.
+ require.ErrorIs(t, ctx.Err(), context.Canceled)
+}
@@ -67,7 +67,7 @@ func RefreshPrompts(ctx context.Context, name string) {
updateState(name, StateConnected, nil, session, prev.Counts)
}
-func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) {
+func getPrompts(ctx context.Context, c *ClientSession) ([]*Prompt, error) {
if c.InitializeResult().Capabilities.Prompts == nil {
return nil, nil
}
@@ -75,7 +75,7 @@ func RefreshResources(ctx context.Context, name string) {
updateState(name, StateConnected, nil, session, prev.Counts)
}
-func getResources(ctx context.Context, c *mcp.ClientSession) ([]*Resource, error) {
+func getResources(ctx context.Context, c *ClientSession) ([]*Resource, error) {
if c.InitializeResult().Capabilities.Resources == nil {
return nil, nil
}
@@ -128,7 +128,7 @@ func RefreshTools(ctx context.Context, cfg *config.Config, name string) {
updateState(name, StateConnected, nil, session, prev.Counts)
}
-func getTools(ctx context.Context, session *mcp.ClientSession) ([]*Tool, error) {
+func getTools(ctx context.Context, session *ClientSession) ([]*Tool, error) {
// Always call ListTools to get the actual available tools.
// The InitializeResult Capabilities.Tools field may be an empty object {},
// which is valid per MCP spec, but we still need to call ListTools to discover tools.