@@ -14,6 +14,7 @@ import (
"os/exec"
"strings"
"sync"
+ "syscall"
"time"
"github.com/charmbracelet/crush/internal/config"
@@ -136,6 +137,26 @@ func GetState(name string) (ClientInfo, bool) {
return states.Get(name)
}
+// isIgnorableCloseErr returns true for errors that are expected during MCP
+// session shutdown and can be safely suppressed.
+func isIgnorableCloseErr(err error) bool {
+ return err == nil ||
+ errors.Is(err, io.EOF) ||
+ errors.Is(err, context.Canceled) ||
+ isKilledErr(err)
+}
+
+// isKilledErr returns true if the error is an exec.ExitError caused by
+// SIGKILL.
+func isKilledErr(err error) bool {
+ var exitErr *exec.ExitError
+ if !errors.As(err, &exitErr) {
+ return false
+ }
+ ws, ok := exitErr.Sys().(syscall.WaitStatus)
+ return ok && ws.Signaled() && ws.Signal() == syscall.SIGKILL
+}
+
// Close closes all MCP clients. This should be called during application shutdown.
func Close(ctx context.Context) error {
var wg sync.WaitGroup
@@ -147,10 +168,7 @@ func Close(ctx context.Context) error {
}()
select {
case err := <-done:
- if err != nil &&
- !errors.Is(err, io.EOF) &&
- !errors.Is(err, context.Canceled) &&
- err.Error() != "signal: killed" {
+ if !isIgnorableCloseErr(err) {
slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
}
case <-ctx.Done():
@@ -195,7 +213,7 @@ func Initialize(ctx context.Context, permissions permission.Service, cfg *config
}()
if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
- slog.Debug("failed to initialize mcp client", "name", name, "error", err)
+ slog.Debug("Failed to initialize MCP client", "name", name, "error", err)
}
}(name, m)
}
@@ -223,7 +241,7 @@ func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore)
if m.Disabled {
updateState(name, StateDisabled, nil, nil, Counts{})
- slog.Debug("skipping disabled mcp", "name", name)
+ slog.Debug("Skipping disabled MCP", "name", name)
return nil
}
@@ -232,10 +250,8 @@ func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore)
// initClient initializes a single MCP client with the given configuration.
func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
- // Set initial starting state.
updateState(name, StateStarting, nil, nil, Counts{})
- // createSession handles its own timeout internally.
session, err := createSession(ctx, name, m, resolver)
if err != nil {
return err
@@ -243,25 +259,25 @@ func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m con
tools, err := getTools(ctx, session)
if err != nil {
- slog.Error("Error listing tools", "error", err)
+ slog.Error("Error listing tools", "name", name, "error", err)
updateState(name, StateError, err, nil, Counts{})
- session.Close()
+ closeSessionOnInitError(name, session)
return err
}
prompts, err := getPrompts(ctx, session)
if err != nil {
- slog.Error("Error listing prompts", "error", err)
+ slog.Error("Error listing prompts", "name", name, "error", err)
updateState(name, StateError, err, nil, Counts{})
- session.Close()
+ closeSessionOnInitError(name, session)
return err
}
resources, err := getResources(ctx, session)
if err != nil {
- slog.Error("Error listing resources", "error", err)
+ slog.Error("Error listing resources", "name", name, "error", err)
updateState(name, StateError, err, nil, Counts{})
- session.Close()
+ closeSessionOnInitError(name, session)
return err
}
@@ -279,28 +295,52 @@ func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m con
return nil
}
+// closeSessionOnInitError closes a session that failed during initialization,
+// suppressing expected shutdown errors. Uses a fixed timeout to avoid blocking
+// indefinitely if the parent context has no deadline.
+//
+// On timeout the Close goroutine may outlive this function, but since
+// session.Close cancels the session context internally, it will unblock
+// shortly after.
+func closeSessionOnInitError(name string, session *ClientSession) {
+ const closeTimeout = 5 * time.Second
+ ctx, cancel := context.WithTimeout(context.Background(), closeTimeout)
+ defer cancel()
+
+ done := make(chan error, 1)
+ go func() {
+ done <- session.Close()
+ }()
+
+ select {
+ case err := <-done:
+ if !isIgnorableCloseErr(err) {
+ slog.Warn("Failed to close MCP session after init error", "name", name, "error", err)
+ }
+ case <-ctx.Done():
+ slog.Warn("Timed out waiting to close MCP session after init error", "name", name, "error", ctx.Err())
+ }
+}
+
// DisableSingle disables and closes a single MCP client by name.
-func DisableSingle(cfg *config.ConfigStore, name string) error {
+func DisableSingle(cfg *config.ConfigStore, name string) {
session, ok := sessions.Get(name)
if ok {
- if err := session.Close(); err != nil &&
- !errors.Is(err, io.EOF) &&
- !errors.Is(err, context.Canceled) &&
- err.Error() != "signal: killed" {
- slog.Warn("error closing mcp session", "name", name, "error", err)
+ if err := session.Close(); !isIgnorableCloseErr(err) {
+ slog.Warn("Error closing MCP session", "name", name, "error", err)
}
sessions.Del(name)
}
- // Clear tools and prompts for this MCP.
+ // Clear tools, prompts, and resources for this MCP.
updateTools(cfg, name, nil)
updatePrompts(name, nil)
+ updateResources(name, nil)
// Update state to disabled.
updateState(name, StateDisabled, nil, nil, Counts{})
- slog.Info("Disabled mcp client", "name", name)
- return nil
+ slog.Info("Disabled MCP client", "name", name)
}
func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
@@ -4,6 +4,7 @@ import (
"context"
"testing"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
@@ -36,3 +37,94 @@ func TestMCPSession_CancelOnClose(t *testing.T) {
// After Close, the context must be cancelled.
require.ErrorIs(t, ctx.Err(), context.Canceled)
}
+
+func TestInitClient_PopulatesResources(t *testing.T) {
+ defer goleak.VerifyNone(t,
+ goleak.IgnoreAnyFunction("net/http.(*http2Transport).newClientConn"),
+ goleak.IgnoreAnyFunction("internal/poll.runtime_pollWait"),
+ goleak.IgnoreAnyFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
+ )
+
+ const name = "test-resources"
+
+ serverTransport, clientTransport := mcp.NewInMemoryTransports()
+ server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
+ server.AddResource(
+ &mcp.Resource{URI: "file:///readme.md", Name: "readme"},
+ func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
+ return &mcp.ReadResourceResult{
+ Contents: []*mcp.ResourceContents{{URI: "file:///readme.md"}},
+ }, nil
+ },
+ )
+ server.AddResource(
+ &mcp.Resource{URI: "file:///license", Name: "license"},
+ func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
+ return &mcp.ReadResourceResult{
+ Contents: []*mcp.ResourceContents{{URI: "file:///license"}},
+ }, nil
+ },
+ )
+
+ serverSession, err := server.Connect(context.Background(), serverTransport, nil)
+ require.NoError(t, err)
+ defer serverSession.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
+ clientSession, err := client.Connect(ctx, clientTransport, nil)
+ require.NoError(t, err)
+ session := &ClientSession{clientSession, cancel}
+
+ cfg, err := config.Init(t.TempDir(), "", false)
+ require.NoError(t, err)
+
+ // Clean up any prior state for this name.
+ t.Cleanup(func() {
+ allTools.Del(name)
+ allPrompts.Del(name)
+ allResources.Del(name)
+ sessions.Del(name)
+ states.Del(name)
+ })
+
+ toolCount := updateTools(cfg, name, nil)
+ updatePrompts(name, nil)
+ resourceCount := updateResources(name, nil)
+ require.Equal(t, 0, toolCount)
+ require.Equal(t, 0, resourceCount)
+
+ // Simulate what initClient does after creating a session.
+ tools, err := getTools(ctx, session)
+ require.NoError(t, err)
+
+ prompts, err := getPrompts(ctx, session)
+ require.NoError(t, err)
+
+ resources, err := getResources(ctx, session)
+ require.NoError(t, err)
+ require.Len(t, resources, 2)
+
+ toolCount = updateTools(cfg, name, tools)
+ updatePrompts(name, prompts)
+ resourceCount = updateResources(name, resources)
+ sessions.Set(name, session)
+
+ updateState(name, StateConnected, nil, session, Counts{
+ Tools: toolCount,
+ Prompts: len(prompts),
+ Resources: resourceCount,
+ })
+
+ // Verify resources are stored and counts are correct.
+ storedResources, ok := allResources.Get(name)
+ require.True(t, ok)
+ require.Len(t, storedResources, 2)
+
+ state, ok := states.Get(name)
+ require.True(t, ok)
+ require.Equal(t, StateConnected, state.State)
+ require.Equal(t, 2, state.Counts.Resources)
+}