// init_test.go

  1package mcp
  2
  3import (
  4	"context"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/modelcontextprotocol/go-sdk/mcp"
  9	"github.com/stretchr/testify/require"
 10	"go.uber.org/goleak"
 11)
 12
 13func TestMCPSession_CancelOnClose(t *testing.T) {
 14	defer goleak.VerifyNone(t)
 15
 16	serverTransport, clientTransport := mcp.NewInMemoryTransports()
 17
 18	server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
 19	serverSession, err := server.Connect(context.Background(), serverTransport, nil)
 20	require.NoError(t, err)
 21	defer serverSession.Close()
 22
 23	ctx, cancel := context.WithCancel(context.Background())
 24
 25	client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
 26	clientSession, err := client.Connect(ctx, clientTransport, nil)
 27	require.NoError(t, err)
 28
 29	sess := &ClientSession{clientSession, cancel}
 30
 31	// Verify the context is not cancelled before close.
 32	require.NoError(t, ctx.Err())
 33
 34	err = sess.Close()
 35	require.NoError(t, err)
 36
 37	// After Close, the context must be cancelled.
 38	require.ErrorIs(t, ctx.Err(), context.Canceled)
 39}
 40
 41func TestInitClient_PopulatesResources(t *testing.T) {
 42	defer goleak.VerifyNone(t,
 43		goleak.IgnoreAnyFunction("net/http.(*http2Transport).newClientConn"),
 44		goleak.IgnoreAnyFunction("internal/poll.runtime_pollWait"),
 45		goleak.IgnoreAnyFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
 46	)
 47
 48	const name = "test-resources"
 49
 50	serverTransport, clientTransport := mcp.NewInMemoryTransports()
 51	server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
 52	server.AddResource(
 53		&mcp.Resource{URI: "file:///readme.md", Name: "readme"},
 54		func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
 55			return &mcp.ReadResourceResult{
 56				Contents: []*mcp.ResourceContents{{URI: "file:///readme.md"}},
 57			}, nil
 58		},
 59	)
 60	server.AddResource(
 61		&mcp.Resource{URI: "file:///license", Name: "license"},
 62		func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
 63			return &mcp.ReadResourceResult{
 64				Contents: []*mcp.ResourceContents{{URI: "file:///license"}},
 65			}, nil
 66		},
 67	)
 68
 69	serverSession, err := server.Connect(context.Background(), serverTransport, nil)
 70	require.NoError(t, err)
 71	defer serverSession.Close()
 72
 73	ctx, cancel := context.WithCancel(context.Background())
 74	defer cancel()
 75
 76	client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
 77	clientSession, err := client.Connect(ctx, clientTransport, nil)
 78	require.NoError(t, err)
 79	session := &ClientSession{clientSession, cancel}
 80
 81	cfg, err := config.Init(t.TempDir(), "", false)
 82	require.NoError(t, err)
 83
 84	// Clean up any prior state for this name.
 85	t.Cleanup(func() {
 86		allTools.Del(name)
 87		allPrompts.Del(name)
 88		allResources.Del(name)
 89		sessions.Del(name)
 90		states.Del(name)
 91	})
 92
 93	toolCount := updateTools(cfg, name, nil)
 94	updatePrompts(name, nil)
 95	resourceCount := updateResources(name, nil)
 96	require.Equal(t, 0, toolCount)
 97	require.Equal(t, 0, resourceCount)
 98
 99	// Simulate what initClient does after creating a session.
100	tools, err := getTools(ctx, session)
101	require.NoError(t, err)
102
103	prompts, err := getPrompts(ctx, session)
104	require.NoError(t, err)
105
106	resources, err := getResources(ctx, session)
107	require.NoError(t, err)
108	require.Len(t, resources, 2)
109
110	toolCount = updateTools(cfg, name, tools)
111	updatePrompts(name, prompts)
112	resourceCount = updateResources(name, resources)
113	sessions.Set(name, session)
114
115	updateState(name, StateConnected, nil, session, Counts{
116		Tools:     toolCount,
117		Prompts:   len(prompts),
118		Resources: resourceCount,
119	})
120
121	// Verify resources are stored and counts are correct.
122	storedResources, ok := allResources.Get(name)
123	require.True(t, ok)
124	require.Len(t, storedResources, 2)
125
126	state, ok := states.Get(name)
127	require.True(t, ok)
128	require.Equal(t, StateConnected, state.State)
129	require.Equal(t, 2, state.Counts.Resources)
130}