1package mcp
2
3import (
4 "context"
5 "maps"
6 "os"
7 "testing"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/env"
11 "github.com/modelcontextprotocol/go-sdk/mcp"
12 "github.com/stretchr/testify/require"
13 "go.uber.org/goleak"
14)
15
// shellResolverWithPath returns a shell-backed config.VariableResolver
// whose environment holds the test process's PATH plus every entry in
// overrides. Overrides are applied last, so they win on key collisions
// (PATH included). A nil overrides map is accepted.
//
// The resolver's env is built from scratch rather than inherited, so
// without an explicit PATH any command substitution that needs an
// external binary could not locate it.
func shellResolverWithPath(t *testing.T, overrides map[string]string) config.VariableResolver {
	t.Helper()

	vars := make(map[string]string, len(overrides)+1)
	vars["PATH"] = os.Getenv("PATH")
	maps.Copy(vars, overrides)

	shellEnv := env.NewFromMap(vars)
	return config.NewShellVariableResolver(shellEnv)
}
26
// TestMCPSession_CancelOnClose checks that closing a ClientSession
// leaves the context its MCP client session was connected with in the
// cancelled state. goleak additionally verifies that no goroutines
// outlive the test.
func TestMCPSession_CancelOnClose(t *testing.T) {
	defer goleak.VerifyNone(t)

	srvTransport, cliTransport := mcp.NewInMemoryTransports()

	srv := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
	srvSession, err := srv.Connect(context.Background(), srvTransport, nil)
	require.NoError(t, err)
	defer srvSession.Close()

	sessCtx, cancelSess := context.WithCancel(context.Background())

	cli := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
	cliSession, err := cli.Connect(sessCtx, cliTransport, nil)
	require.NoError(t, err)

	wrapped := &ClientSession{cliSession, cancelSess}

	// Precondition: nothing has cancelled the context yet.
	require.NoError(t, sessCtx.Err())

	require.NoError(t, wrapped.Close())

	// Once Close has returned, the context must report cancellation.
	require.ErrorIs(t, sessCtx.Err(), context.Canceled)
}
54
// TestCreateTransport_URLResolution pins that m.URL goes through the
// same resolver seam as command, args, env, and headers. Covers both
// the HTTP and SSE branches, success and failure, so a regression in
// ResolvedURL wiring is caught at the transport layer rather than only
// at the config layer.
func TestCreateTransport_URLResolution(t *testing.T) {
	t.Parallel()

	// Build the resolver through shellResolverWithPath like the other
	// createTransport tests. The resolver env is otherwise empty, so the
	// $(...) substitutions below would only succeed if the command
	// happened to be a shell builtin. MCP_MISSING_HOST and MCP_EMPTY are
	// intentionally left unset.
	shell := shellResolverWithPath(t, map[string]string{
		"MCP_HOST": "mcp.example.com",
	})

	t.Run("http success expands $VAR", func(t *testing.T) {
		t.Parallel()
		m := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "https://$MCP_HOST/api",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.NoError(t, err)
		require.NotNil(t, tr)
		sct, ok := tr.(*mcp.StreamableClientTransport)
		require.True(t, ok, "expected StreamableClientTransport, got %T", tr)
		require.Equal(t, "https://mcp.example.com/api", sct.Endpoint)
	})

	t.Run("sse success expands $(cmd)", func(t *testing.T) {
		t.Parallel()
		m := config.MCPConfig{
			Type: config.MCPSSE,
			URL:  "https://$(echo mcp.example.com)/events",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.NoError(t, err)
		sse, ok := tr.(*mcp.SSEClientTransport)
		require.True(t, ok, "expected SSEClientTransport, got %T", tr)
		require.Equal(t, "https://mcp.example.com/events", sse.Endpoint)
	})

	t.Run("http unset var surfaces error, no transport created", func(t *testing.T) {
		t.Parallel()
		m := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "https://$MCP_MISSING_HOST/api",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.Error(t, err)
		require.Nil(t, tr)
		require.Contains(t, err.Error(), "url:")
		require.Contains(t, err.Error(), "$MCP_MISSING_HOST")
	})

	t.Run("sse failing $(cmd) surfaces error, no transport created", func(t *testing.T) {
		t.Parallel()
		// $(false) exits non-zero, so resolution must fail.
		m := config.MCPConfig{
			Type: config.MCPSSE,
			URL:  "https://$(false)/events",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.Error(t, err)
		require.Nil(t, tr)
		require.Contains(t, err.Error(), "url:")
		require.Contains(t, err.Error(), "$(false)")
	})

	t.Run("http empty-after-resolve still fails the non-empty guard", func(t *testing.T) {
		t.Parallel()
		// ${MCP_EMPTY:-} resolves to the empty string (no error),
		// then the existing TrimSpace guard in createTransport must
		// reject it so we never spawn a transport against "".
		m := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "${MCP_EMPTY:-}",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.Error(t, err)
		require.Nil(t, tr)
		require.Contains(t, err.Error(), "non-empty 'url'")
	})

	t.Run("identity resolver round-trips template verbatim", func(t *testing.T) {
		t.Parallel()
		// Client mode forwards the template to the server; no local
		// expansion, no error on unset vars.
		tmpl := "https://$MCP_MISSING_HOST/api"
		m := config.MCPConfig{Type: config.MCPHttp, URL: tmpl}
		tr, err := createTransport(t.Context(), m, config.IdentityResolver())
		require.NoError(t, err)
		sct, ok := tr.(*mcp.StreamableClientTransport)
		require.True(t, ok)
		require.Equal(t, tmpl, sct.Endpoint)
	})
}
148
// TestCreateTransport_StdioResolution pins that the command, args, and
// env of a stdio MCP are routed through the same resolver seam as the
// other transports. It checks that a successful expansion yields the
// expected exec.Cmd, and that a resolution failure in any single field
// aborts transport creation.
func TestCreateTransport_StdioResolution(t *testing.T) {
	t.Parallel()

	t.Run("success expands command, args, and env", func(t *testing.T) {
		t.Parallel()
		resolver := shellResolverWithPath(t, map[string]string{"MY_TOKEN": "hunter2"})
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Args:    []string{"--token", "$MY_TOKEN", "--host", "$(echo example.com)"},
			Env: map[string]string{
				"SECRET":    "$(echo shh)",
				"PLAIN":     "literal",
				"REFERENCE": "$MY_TOKEN",
			},
		}

		tr, err := createTransport(t.Context(), cfg, resolver)
		require.NoError(t, err)
		require.NotNil(t, tr)

		cmdTr, ok := tr.(*mcp.CommandTransport)
		require.True(t, ok, "expected CommandTransport, got %T", tr)

		// exec.Cmd.Args starts with the program name, followed by the
		// positional args after expansion.
		wantArgs := []string{"forgejo-mcp", "--token", "hunter2", "--host", "example.com"}
		require.Equal(t, wantArgs, cmdTr.Command.Args)

		// Command.Env may contain more than the configured entries, so
		// assert membership of each expanded entry, not equality.
		for _, kv := range []string{"SECRET=shh", "PLAIN=literal", "REFERENCE=hunter2"} {
			require.Contains(t, cmdTr.Command.Env, kv)
		}
	})

	// Each case has exactly one field that fails to resolve. The
	// resulting error must name that field, and no transport may be
	// returned.
	failures := []struct {
		name    string
		cfg     config.MCPConfig
		wantSub string
	}{
		{
			name: "env resolution failure surfaces error, no transport created",
			cfg: config.MCPConfig{
				Type:    config.MCPStdio,
				Command: "forgejo-mcp",
				Env:     map[string]string{"TOKEN": "$(false)"},
			},
			wantSub: "env TOKEN",
		},
		{
			// Unset vars must error out before exec rather than
			// silently expanding to "" and handing the child process an
			// empty credential. $FORGJO_TOKEN (note the spelling) is
			// absent from the resolver env, mimicking a typo'd variable
			// name in user config.
			name: "unset env var is a hard error, not silent empty",
			cfg: config.MCPConfig{
				Type:    config.MCPStdio,
				Command: "forgejo-mcp",
				Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$FORGJO_TOKEN"},
			},
			wantSub: "env FORGEJO_ACCESS_TOKEN",
		},
		{
			name: "args resolution failure surfaces error, no transport created",
			cfg: config.MCPConfig{
				Type:    config.MCPStdio,
				Command: "forgejo-mcp",
				Args:    []string{"--token", "$(false)"},
			},
			wantSub: "arg 1",
		},
		{
			name: "command resolution failure surfaces error, no transport created",
			cfg: config.MCPConfig{
				Type:    config.MCPStdio,
				Command: "$(false)",
			},
			wantSub: "invalid mcp command",
		},
	}

	for _, tc := range failures {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			resolver := shellResolverWithPath(t, nil)
			tr, err := createTransport(t.Context(), tc.cfg, resolver)
			require.Error(t, err)
			require.Nil(t, tr)
			require.Contains(t, err.Error(), tc.wantSub)
		})
	}

	t.Run("identity resolver round-trips templates verbatim", func(t *testing.T) {
		t.Parallel()
		// Client mode: templates pass through untouched, and unset
		// vars are not an error.
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Args:    []string{"--token", "$MCP_MISSING"},
			Env:     map[string]string{"TOKEN": "$(vault read -f token)"},
		}

		tr, err := createTransport(t.Context(), cfg, config.IdentityResolver())
		require.NoError(t, err)

		cmdTr, ok := tr.(*mcp.CommandTransport)
		require.True(t, ok)
		require.Equal(t, []string{"forgejo-mcp", "--token", "$MCP_MISSING"}, cmdTr.Command.Args)
		require.Contains(t, cmdTr.Command.Env, "TOKEN=$(vault read -f token)")
	})
}
265
// TestCreateTransport_HeadersResolution pins two behaviors of header
// resolution:
//   - On success, every expanded header reaches the headerRoundTripper
//     installed on the HTTP transport.
//   - On failure, a single unresolvable header aborts HTTP or SSE
//     transport creation.
func TestCreateTransport_HeadersResolution(t *testing.T) {
	t.Parallel()

	t.Run("http headers success expands $(cmd)", func(t *testing.T) {
		t.Parallel()
		resolver := shellResolverWithPath(t, map[string]string{"GITHUB_TOKEN": "gh-secret"})
		cfg := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "https://mcp.example.com/api",
			Headers: map[string]string{
				"Authorization": "$(echo Bearer $GITHUB_TOKEN)",
				"X-Static":      "kept",
			},
		}

		tr, err := createTransport(t.Context(), cfg, resolver)
		require.NoError(t, err)

		streamable, ok := tr.(*mcp.StreamableClientTransport)
		require.True(t, ok)
		hrt, ok := streamable.HTTPClient.Transport.(*headerRoundTripper)
		require.True(t, ok, "expected headerRoundTripper, got %T", streamable.HTTPClient.Transport)

		want := map[string]string{
			"Authorization": "Bearer gh-secret",
			"X-Static":      "kept",
		}
		require.Equal(t, want, hrt.headers)
	})

	// Each failing case must produce an error mentioning every
	// substring in wantSubs, and no transport.
	failures := []struct {
		name     string
		cfg      config.MCPConfig
		wantSubs []string
	}{
		{
			name: "http failing header surfaces error, no transport",
			cfg: config.MCPConfig{
				Type:    config.MCPHttp,
				URL:     "https://mcp.example.com/api",
				Headers: map[string]string{"Authorization": "$(false)"},
			},
			wantSubs: []string{"header Authorization"},
		},
		{
			name: "sse failing header surfaces error, no transport",
			cfg: config.MCPConfig{
				Type:    config.MCPSSE,
				URL:     "https://mcp.example.com/events",
				Headers: map[string]string{"Authorization": "Bearer $MISSING_TOKEN"},
			},
			wantSubs: []string{"header Authorization", "$MISSING_TOKEN"},
		},
	}

	for _, tc := range failures {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			resolver := shellResolverWithPath(t, nil)
			tr, err := createTransport(t.Context(), tc.cfg, resolver)
			require.Error(t, err)
			require.Nil(t, tr)
			for _, sub := range tc.wantSubs {
				require.Contains(t, err.Error(), sub)
			}
		})
	}
}
327
328// TestCreateSession_ResolutionFailureUpdatesState pins the user-visible
329// half of the regression fix: when any of command/args/env/headers/url
330// fails to resolve, createSession must publish StateError to the state
331// map so crush_info and the TUI's MCP status card can render a real
332// error instead of the MCP silently sitting in "starting" or being
333// spawned with an empty credential.
334//
335// These subtests cannot run in parallel: `states` is a package-level
336// csync.Map and each assertion reads the entry written by the call
337// under test. They do use unique MCP names per subtest to keep them
338// independent regardless of ordering.
339func TestCreateSession_ResolutionFailureUpdatesState(t *testing.T) {
340 r := shellResolverWithPath(t, nil)
341
342 tests := []struct {
343 name string
344 mcpName string
345 cfg config.MCPConfig
346 wantErrContains string
347 }{
348 {
349 name: "stdio env failure",
350 mcpName: "test-stdio-env-fail",
351 cfg: config.MCPConfig{
352 Type: config.MCPStdio,
353 Command: "echo",
354 Env: map[string]string{"FORGEJO_ACCESS_TOKEN": "$(false)"},
355 },
356 wantErrContains: "env FORGEJO_ACCESS_TOKEN",
357 },
358 {
359 name: "stdio args failure",
360 mcpName: "test-stdio-args-fail",
361 cfg: config.MCPConfig{
362 Type: config.MCPStdio,
363 Command: "echo",
364 Args: []string{"--token", "$MCP_UNSET_TOKEN"},
365 },
366 wantErrContains: "arg 1",
367 },
368 {
369 name: "http url failure",
370 mcpName: "test-http-url-fail",
371 cfg: config.MCPConfig{
372 Type: config.MCPHttp,
373 URL: "https://$MCP_MISSING_HOST/api",
374 },
375 wantErrContains: "url:",
376 },
377 {
378 name: "http header failure",
379 mcpName: "test-http-header-fail",
380 cfg: config.MCPConfig{
381 Type: config.MCPHttp,
382 URL: "https://mcp.example.com/api",
383 Headers: map[string]string{"Authorization": "$(false)"},
384 },
385 wantErrContains: "header Authorization",
386 },
387 {
388 name: "sse url failure",
389 mcpName: "test-sse-url-fail",
390 cfg: config.MCPConfig{
391 Type: config.MCPSSE,
392 URL: "https://$(false)/events",
393 },
394 wantErrContains: "url:",
395 },
396 {
397 name: "sse header failure",
398 mcpName: "test-sse-header-fail",
399 cfg: config.MCPConfig{
400 Type: config.MCPSSE,
401 URL: "https://mcp.example.com/events",
402 Headers: map[string]string{"Authorization": "Bearer $MISSING_SSE_TOKEN"},
403 },
404 wantErrContains: "header Authorization",
405 },
406 }
407
408 for _, tc := range tests {
409 t.Run(tc.name, func(t *testing.T) {
410 // Guarantee a clean slate on the shared state map so a
411 // stale entry from another test can't satisfy the
412 // assertion.
413 states.Del(tc.mcpName)
414 t.Cleanup(func() { states.Del(tc.mcpName) })
415
416 sess, err := createSession(t.Context(), tc.mcpName, tc.cfg, r)
417 require.Error(t, err)
418 require.Nil(t, sess)
419 require.Contains(t, err.Error(), tc.wantErrContains)
420
421 info, ok := GetState(tc.mcpName)
422 require.True(t, ok, "state entry must be written for %q", tc.mcpName)
423 require.Equal(t, StateError, info.State, "expected StateError, got %s", info.State)
424 require.Error(t, info.Error, "state must carry the failure error")
425 require.Contains(t, info.Error.Error(), tc.wantErrContains)
426 require.Nil(t, info.Client, "no client session on failure")
427 })
428 }
429}