1package mcp
2
3import (
4 "context"
5 "testing"
6
7 "github.com/charmbracelet/crush/internal/config"
8 "github.com/modelcontextprotocol/go-sdk/mcp"
9 "github.com/stretchr/testify/require"
10 "go.uber.org/goleak"
11)
12
13func TestMCPSession_CancelOnClose(t *testing.T) {
14 defer goleak.VerifyNone(t)
15
16 serverTransport, clientTransport := mcp.NewInMemoryTransports()
17
18 server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
19 serverSession, err := server.Connect(context.Background(), serverTransport, nil)
20 require.NoError(t, err)
21 defer serverSession.Close()
22
23 ctx, cancel := context.WithCancel(context.Background())
24
25 client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
26 clientSession, err := client.Connect(ctx, clientTransport, nil)
27 require.NoError(t, err)
28
29 sess := &ClientSession{clientSession, cancel}
30
31 // Verify the context is not cancelled before close.
32 require.NoError(t, ctx.Err())
33
34 err = sess.Close()
35 require.NoError(t, err)
36
37 // After Close, the context must be cancelled.
38 require.ErrorIs(t, ctx.Err(), context.Canceled)
39}
40
41func TestInitClient_PopulatesResources(t *testing.T) {
42 defer goleak.VerifyNone(t,
43 goleak.IgnoreAnyFunction("net/http.(*http2Transport).newClientConn"),
44 goleak.IgnoreAnyFunction("internal/poll.runtime_pollWait"),
45 goleak.IgnoreAnyFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"),
46 )
47
48 const name = "test-resources"
49
50 serverTransport, clientTransport := mcp.NewInMemoryTransports()
51 server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
52 server.AddResource(
53 &mcp.Resource{URI: "file:///readme.md", Name: "readme"},
54 func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
55 return &mcp.ReadResourceResult{
56 Contents: []*mcp.ResourceContents{{URI: "file:///readme.md"}},
57 }, nil
58 },
59 )
60 server.AddResource(
61 &mcp.Resource{URI: "file:///license", Name: "license"},
62 func(context.Context, *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) {
63 return &mcp.ReadResourceResult{
64 Contents: []*mcp.ResourceContents{{URI: "file:///license"}},
65 }, nil
66 },
67 )
68
69 serverSession, err := server.Connect(context.Background(), serverTransport, nil)
70 require.NoError(t, err)
71 defer serverSession.Close()
72
73 ctx, cancel := context.WithCancel(context.Background())
74 defer cancel()
75
76 client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
77 clientSession, err := client.Connect(ctx, clientTransport, nil)
78 require.NoError(t, err)
79 session := &ClientSession{clientSession, cancel}
80
81 cfg, err := config.Init(t.TempDir(), "", false)
82 require.NoError(t, err)
83
84 // Clean up any prior state for this name.
85 t.Cleanup(func() {
86 allTools.Del(name)
87 allPrompts.Del(name)
88 allResources.Del(name)
89 sessions.Del(name)
90 states.Del(name)
91 })
92
93 toolCount := updateTools(cfg, name, nil)
94 updatePrompts(name, nil)
95 resourceCount := updateResources(name, nil)
96 require.Equal(t, 0, toolCount)
97 require.Equal(t, 0, resourceCount)
98
99 // Simulate what initClient does after creating a session.
100 tools, err := getTools(ctx, session)
101 require.NoError(t, err)
102
103 prompts, err := getPrompts(ctx, session)
104 require.NoError(t, err)
105
106 resources, err := getResources(ctx, session)
107 require.NoError(t, err)
108 require.Len(t, resources, 2)
109
110 toolCount = updateTools(cfg, name, tools)
111 updatePrompts(name, prompts)
112 resourceCount = updateResources(name, resources)
113 sessions.Set(name, session)
114
115 updateState(name, StateConnected, nil, session, Counts{
116 Tools: toolCount,
117 Prompts: len(prompts),
118 Resources: resourceCount,
119 })
120
121 // Verify resources are stored and counts are correct.
122 storedResources, ok := allResources.Get(name)
123 require.True(t, ok)
124 require.Len(t, storedResources, 2)
125
126 state, ok := states.Get(name)
127 require.True(t, ok)
128 require.Equal(t, StateConnected, state.State)
129 require.Equal(t, 2, state.Counts.Resources)
130}