1package mcp
2
3import (
4 "context"
5 "maps"
6 "os"
7 "testing"
8
9 "github.com/charmbracelet/crush/internal/config"
10 "github.com/charmbracelet/crush/internal/env"
11 "github.com/modelcontextprotocol/go-sdk/mcp"
12 "github.com/stretchr/testify/require"
13 "go.uber.org/goleak"
14)
15
// shellResolverWithPath returns a shell-backed config.VariableResolver
// whose environment contains the current process PATH plus every entry
// in overrides (an override for PATH replaces the inherited value).
// PATH is required so that $(cat), $(echo), and similar substitutions can
// locate their binaries; otherwise the resolver's environment would be
// empty in a test process.
func shellResolverWithPath(t *testing.T, overrides map[string]string) config.VariableResolver {
	t.Helper()
	vars := make(map[string]string, len(overrides)+1)
	vars["PATH"] = os.Getenv("PATH")
	// Applied after PATH so callers may deliberately override it.
	maps.Copy(vars, overrides)
	return config.NewShellVariableResolver(env.NewFromMap(vars))
}
26
// TestMCPSession_CancelOnClose checks that closing a ClientSession also
// cancels the context that was used to connect it: the cancel func
// supplied to ClientSession must fire when Close is called. goleak
// verifies no goroutines are left running once both sessions are closed.
func TestMCPSession_CancelOnClose(t *testing.T) {
	defer goleak.VerifyNone(t)

	srvTransport, cliTransport := mcp.NewInMemoryTransports()

	srv := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
	srvSession, err := srv.Connect(context.Background(), srvTransport, nil)
	require.NoError(t, err)
	defer srvSession.Close()

	connCtx, cancelConn := context.WithCancel(context.Background())

	cli := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
	cliSession, err := cli.Connect(connCtx, cliTransport, nil)
	require.NoError(t, err)

	sess := &ClientSession{cliSession, cancelConn}

	// Sanity check: nothing has cancelled the connection context yet.
	require.NoError(t, connCtx.Err())

	require.NoError(t, sess.Close())

	// Close must have propagated cancellation to the connection context.
	require.ErrorIs(t, connCtx.Err(), context.Canceled)
}
54
// TestCreateTransport_URLResolution pins that m.URL goes through the
// same resolver seam as command, args, env, and headers. It covers both
// the HTTP and SSE branches, success and failure, so a regression in
// ResolvedURL wiring is caught at the transport layer rather than only
// at the config layer.
//
// Every shell-mode subtest shares one resolver built with
// shellResolverWithPath, so $(cmd) substitutions run with the process
// PATH, just as in the stdio and header tests. This also means the HTTP
// and SSE "failing $(cmd)" subtests exercise an identical resolver setup.
func TestCreateTransport_URLResolution(t *testing.T) {
	t.Parallel()

	shell := shellResolverWithPath(t, map[string]string{
		"MCP_HOST": "mcp.example.com",
	})

	t.Run("http success expands $VAR", func(t *testing.T) {
		t.Parallel()
		m := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "https://$MCP_HOST/api",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.NoError(t, err)
		require.NotNil(t, tr)
		sct, ok := tr.(*mcp.StreamableClientTransport)
		require.True(t, ok, "expected StreamableClientTransport, got %T", tr)
		require.Equal(t, "https://mcp.example.com/api", sct.Endpoint)
	})

	t.Run("sse success expands $(cmd)", func(t *testing.T) {
		t.Parallel()
		m := config.MCPConfig{
			Type: config.MCPSSE,
			URL:  "https://$(echo mcp.example.com)/events",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.NoError(t, err)
		sse, ok := tr.(*mcp.SSEClientTransport)
		require.True(t, ok, "expected SSEClientTransport, got %T", tr)
		require.Equal(t, "https://mcp.example.com/events", sse.Endpoint)
	})

	t.Run("http failing $(cmd) surfaces error, no transport created", func(t *testing.T) {
		t.Parallel()
		// Under lenient nounset an unset $VAR expands to "" silently,
		// so the only way URL resolution *errors* is a failing $(cmd).
		// This mirrors the SSE subtest (same resolver, same input
		// shape) so both transports cover the url-resolve-failure path.
		m := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "https://$(false)/api",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.Error(t, err)
		require.Nil(t, tr)
		require.Contains(t, err.Error(), "url:")
		require.Contains(t, err.Error(), "$(false)")
	})

	t.Run("http unset var expands empty", func(t *testing.T) {
		t.Parallel()
		// Pinning test for the lenient-nounset default: an unset bare
		// $VAR in the URL is *not* an error. It expands to "" and here
		// leaves a syntactically odd but non-empty URL that the
		// non-empty guard still lets through. Guards against a future
		// regression that turns strict-by-default back on.
		m := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "https://$MCP_MISSING_HOST/api",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.NoError(t, err)
		sct, ok := tr.(*mcp.StreamableClientTransport)
		require.True(t, ok)
		require.Equal(t, "https:///api", sct.Endpoint)
	})

	t.Run("sse failing $(cmd) surfaces error, no transport created", func(t *testing.T) {
		t.Parallel()
		m := config.MCPConfig{
			Type: config.MCPSSE,
			URL:  "https://$(false)/events",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.Error(t, err)
		require.Nil(t, tr)
		require.Contains(t, err.Error(), "url:")
		require.Contains(t, err.Error(), "$(false)")
	})

	t.Run("http empty-after-resolve still fails the non-empty guard", func(t *testing.T) {
		t.Parallel()
		// ${MCP_EMPTY:-} resolves to the empty string without error;
		// the non-empty guard in createTransport must then reject it so
		// a transport is never created against "".
		m := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "${MCP_EMPTY:-}",
		}
		tr, err := createTransport(t.Context(), m, shell)
		require.Error(t, err)
		require.Nil(t, tr)
		require.Contains(t, err.Error(), "non-empty 'url'")
	})

	t.Run("identity resolver round-trips template verbatim", func(t *testing.T) {
		t.Parallel()
		// Client mode forwards the template to the server: no local
		// expansion, and no error on unset vars.
		tmpl := "https://$MCP_MISSING_HOST/api"
		m := config.MCPConfig{Type: config.MCPHttp, URL: tmpl}
		tr, err := createTransport(t.Context(), m, config.IdentityResolver())
		require.NoError(t, err)
		sct, ok := tr.(*mcp.StreamableClientTransport)
		require.True(t, ok)
		require.Equal(t, tmpl, sct.Endpoint)
	})
}
171
// TestCreateTransport_StdioResolution pins that a stdio MCP's command,
// args, and env are expanded through the same resolver seam as the other
// transports. The success case inspects the resulting exec.Cmd. The
// failure cases check that any single field failing to resolve prevents
// transport creation and that the error names that field.
func TestCreateTransport_StdioResolution(t *testing.T) {
	t.Parallel()

	// requireResolveFailure runs createTransport with cfg and resolver
	// and asserts that it returns no transport and an error containing
	// wantSubstr.
	requireResolveFailure := func(t *testing.T, cfg config.MCPConfig, resolver config.VariableResolver, wantSubstr string) {
		t.Helper()
		transport, err := createTransport(t.Context(), cfg, resolver)
		require.Error(t, err)
		require.Nil(t, transport)
		require.Contains(t, err.Error(), wantSubstr)
	}

	t.Run("success expands command, args, and env", func(t *testing.T) {
		t.Parallel()
		resolver := shellResolverWithPath(t, map[string]string{
			"MY_TOKEN": "hunter2",
		})
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Args:    []string{"--token", "$MY_TOKEN", "--host", "$(echo example.com)"},
			Env: map[string]string{
				"SECRET":    "$(echo shh)",
				"PLAIN":     "literal",
				"REFERENCE": "$MY_TOKEN",
			},
		}
		transport, err := createTransport(t.Context(), cfg, resolver)
		require.NoError(t, err)
		require.NotNil(t, transport)

		cmdTransport, ok := transport.(*mcp.CommandTransport)
		require.True(t, ok, "expected CommandTransport, got %T", transport)

		// exec.Cmd.Args begins with the program name, followed by the
		// resolved positional arguments in their configured order.
		require.Equal(t,
			[]string{"forgejo-mcp", "--token", "hunter2", "--host", "example.com"},
			cmdTransport.Command.Args)

		// Env may carry more variables than the configured ones, so
		// only check that each resolved entry is present.
		for _, entry := range []string{"SECRET=shh", "PLAIN=literal", "REFERENCE=hunter2"} {
			require.Contains(t, cmdTransport.Command.Env, entry)
		}
	})

	t.Run("env resolution failure surfaces error, no transport created", func(t *testing.T) {
		t.Parallel()
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Env:     map[string]string{"TOKEN": "$(false)"},
		}
		requireResolveFailure(t, cfg, shellResolverWithPath(t, nil), "env TOKEN")
	})

	t.Run("failing env command is a hard error", func(t *testing.T) {
		t.Parallel()
		// A bare unset $VAR expands to "" under lenient nounset (pinned
		// by "unset env var expands empty"). The remaining failure mode
		// for env values is a $(cmd) that exits non-zero. That must
		// still abort before exec so a broken credential never reaches
		// the child process.
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$(exit 5)"},
		}
		requireResolveFailure(t, cfg, shellResolverWithPath(t, nil), "env FORGEJO_ACCESS_TOKEN")
	})

	t.Run("unset env var expands empty", func(t *testing.T) {
		t.Parallel()
		// Pinning test for the lenient-nounset default: a bare unset
		// $VAR in an env value expands to "" without error. The empty
		// entry is kept on the exec.Cmd (env entries, unlike headers,
		// are not dropped; see design decision #18). Guards against a
		// regression to strict-by-default that would break configs such
		// as FORGEJO_ACCESS_TOKEN=$FORGEJO_TOKEN.
		resolver := shellResolverWithPath(t, nil)
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$FORGEJO_TOKEN_UNSET"},
		}
		transport, err := createTransport(t.Context(), cfg, resolver)
		require.NoError(t, err)
		cmdTransport, ok := transport.(*mcp.CommandTransport)
		require.True(t, ok)
		require.Contains(t, cmdTransport.Command.Env, "FORGEJO_ACCESS_TOKEN=")
	})

	t.Run("args resolution failure surfaces error, no transport created", func(t *testing.T) {
		t.Parallel()
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Args:    []string{"--token", "$(false)"},
		}
		requireResolveFailure(t, cfg, shellResolverWithPath(t, nil), "arg 1")
	})

	t.Run("command resolution failure surfaces error, no transport created", func(t *testing.T) {
		t.Parallel()
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "$(false)",
		}
		requireResolveFailure(t, cfg, shellResolverWithPath(t, nil), "invalid mcp command")
	})

	t.Run("identity resolver round-trips templates verbatim", func(t *testing.T) {
		t.Parallel()
		// Client mode: templates pass through untouched, and unset vars
		// are not an error.
		cfg := config.MCPConfig{
			Type:    config.MCPStdio,
			Command: "forgejo-mcp",
			Args:    []string{"--token", "$MCP_MISSING"},
			Env:     map[string]string{"TOKEN": "$(vault read -f token)"},
		}
		transport, err := createTransport(t.Context(), cfg, config.IdentityResolver())
		require.NoError(t, err)
		cmdTransport, ok := transport.(*mcp.CommandTransport)
		require.True(t, ok)
		require.Equal(t, []string{"forgejo-mcp", "--token", "$MCP_MISSING"}, cmdTransport.Command.Args)
		require.Contains(t, cmdTransport.Command.Env, "TOKEN=$(vault read -f token)")
	})
}
312
// TestCreateTransport_HeadersResolution pins header resolution for the
// HTTP and SSE transports:
//   - every successfully resolved header reaches the headerRoundTripper;
//   - a header whose value resolves to "" is omitted;
//   - a single header that fails to resolve aborts transport creation.
func TestCreateTransport_HeadersResolution(t *testing.T) {
	t.Parallel()

	// requireHeaderFailure asserts that createTransport rejects cfg, with
	// no transport returned, because its Authorization header could not
	// be resolved.
	requireHeaderFailure := func(t *testing.T, cfg config.MCPConfig) {
		t.Helper()
		transport, err := createTransport(t.Context(), cfg, shellResolverWithPath(t, nil))
		require.Error(t, err)
		require.Nil(t, transport)
		require.Contains(t, err.Error(), "header Authorization")
	}

	t.Run("http headers success expands $(cmd)", func(t *testing.T) {
		t.Parallel()
		resolver := shellResolverWithPath(t, map[string]string{
			"GITHUB_TOKEN": "gh-secret",
		})
		cfg := config.MCPConfig{
			Type: config.MCPHttp,
			URL:  "https://mcp.example.com/api",
			Headers: map[string]string{
				"Authorization": "$(echo Bearer $GITHUB_TOKEN)",
				"X-Static":      "kept",
			},
		}
		transport, err := createTransport(t.Context(), cfg, resolver)
		require.NoError(t, err)

		streamable, ok := transport.(*mcp.StreamableClientTransport)
		require.True(t, ok)
		roundTripper, ok := streamable.HTTPClient.Transport.(*headerRoundTripper)
		require.True(t, ok, "expected headerRoundTripper, got %T", streamable.HTTPClient.Transport)

		want := map[string]string{
			"Authorization": "Bearer gh-secret",
			"X-Static":      "kept",
		}
		require.Equal(t, want, roundTripper.headers)
	})

	t.Run("http failing header surfaces error, no transport", func(t *testing.T) {
		t.Parallel()
		requireHeaderFailure(t, config.MCPConfig{
			Type:    config.MCPHttp,
			URL:     "https://mcp.example.com/api",
			Headers: map[string]string{"Authorization": "$(false)"},
		})
	})

	t.Run("sse failing header surfaces error, no transport", func(t *testing.T) {
		t.Parallel()
		// Under lenient nounset, a bare unset $VAR resolves to "" and
		// the header is dropped without error. A failing $(cmd) is
		// therefore the way to make header resolution fail loudly; this
		// exercises it on the SSE branch, matching the HTTP subtest.
		requireHeaderFailure(t, config.MCPConfig{
			Type:    config.MCPSSE,
			URL:     "https://mcp.example.com/events",
			Headers: map[string]string{"Authorization": "$(false)"},
		})
	})

	t.Run("sse unset var header drops silently", func(t *testing.T) {
		t.Parallel()
		// Pinning test for design decision #18 plus lenient nounset: a
		// header whose value resolves to "" (here, an unset bare $VAR)
		// is left off the round tripper rather than sent with an empty
		// value. Guards against reintroducing strict-by-default or
		// against no longer dropping empty headers.
		resolver := shellResolverWithPath(t, nil)
		cfg := config.MCPConfig{
			Type:    config.MCPSSE,
			URL:     "https://mcp.example.com/events",
			Headers: map[string]string{"Authorization": "$MISSING_TOKEN"},
		}
		transport, err := createTransport(t.Context(), cfg, resolver)
		require.NoError(t, err)
		sse, ok := transport.(*mcp.SSEClientTransport)
		require.True(t, ok)
		roundTripper, ok := sse.HTTPClient.Transport.(*headerRoundTripper)
		require.True(t, ok)
		require.NotContains(t, roundTripper.headers, "Authorization")
	})
}
400
// TestCreateSession_ResolutionFailureUpdatesState pins the user-visible
// half of the resolution-failure fix. When command, args, env, headers,
// or url fails to resolve, createSession must record StateError in the
// package-level state map. Readers of that map (crush_info and the TUI's
// MCP status card) can then show a real error instead of leaving the MCP
// in "starting" or spawning it with an empty credential.
//
// The subtests deliberately do not run in parallel. `states` is a
// package-level csync.Map, and each case reads back the entry written by
// the call under test. Each case also uses its own MCP name and clears
// that entry before and after running, so the cases do not depend on
// execution order.
func TestCreateSession_ResolutionFailureUpdatesState(t *testing.T) {
	resolver := shellResolverWithPath(t, nil)

	cases := []struct {
		name    string
		server  string
		cfg     config.MCPConfig
		wantErr string
	}{
		{
			name:   "stdio env failure",
			server: "test-stdio-env-fail",
			cfg: config.MCPConfig{
				Type:    config.MCPStdio,
				Command: "echo",
				Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$(false)"},
			},
			wantErr: "env FORGEJO_ACCESS_TOKEN",
		},
		{
			// An unset bare $VAR in args is not an error under lenient
			// nounset; a failing $(cmd) is what makes arg resolution
			// fail.
			name:   "stdio args failure",
			server: "test-stdio-args-fail",
			cfg: config.MCPConfig{
				Type:    config.MCPStdio,
				Command: "echo",
				Args:    []string{"--token", "$(false)"},
			},
			wantErr: "arg 1",
		},
		{
			// Same for the URL: a failing $(cmd) is needed to reach
			// the "url:" error wrap from ResolvedURL.
			name:   "http url failure",
			server: "test-http-url-fail",
			cfg: config.MCPConfig{
				Type: config.MCPHttp,
				URL:  "https://$(false)/api",
			},
			wantErr: "url:",
		},
		{
			// ${VAR:-} expands to "" without a ResolvedURL error; the
			// non-empty guard in createTransport must still reject it
			// so the state records an error rather than a transport
			// being created against "".
			name:   "http empty-resolved url",
			server: "test-http-url-empty",
			cfg: config.MCPConfig{
				Type: config.MCPHttp,
				URL:  "${MCP_URL_EMPTY:-}",
			},
			wantErr: "non-empty 'url'",
		},
		{
			name:   "http header failure",
			server: "test-http-header-fail",
			cfg: config.MCPConfig{
				Type:    config.MCPHttp,
				URL:     "https://mcp.example.com/api",
				Headers: map[string]string{"Authorization": "$(false)"},
			},
			wantErr: "header Authorization",
		},
		{
			name:   "sse url failure",
			server: "test-sse-url-fail",
			cfg: config.MCPConfig{
				Type: config.MCPSSE,
				URL:  "https://$(false)/events",
			},
			wantErr: "url:",
		},
		{
			// An unset bare $VAR in a header resolves to "" and is
			// dropped (design decision #18), so the "header
			// Authorization" wrap is only reachable via a failing
			// $(cmd). This case covers it on the SSE path.
			name:   "sse header failure",
			server: "test-sse-header-fail",
			cfg: config.MCPConfig{
				Type:    config.MCPSSE,
				URL:     "https://mcp.example.com/events",
				Headers: map[string]string{"Authorization": "$(false)"},
			},
			wantErr: "header Authorization",
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Start from, and leave behind, a clean entry for this name
			// so a stale value from elsewhere cannot satisfy the checks.
			states.Del(c.server)
			t.Cleanup(func() { states.Del(c.server) })

			session, err := createSession(t.Context(), c.server, c.cfg, resolver)
			require.Error(t, err)
			require.Nil(t, session)
			require.Contains(t, err.Error(), c.wantErr)

			info, found := GetState(c.server)
			require.True(t, found, "state entry must be written for %q", c.server)
			require.Equal(t, StateError, info.State, "expected StateError, got %s", info.State)
			require.Error(t, info.Error, "state must carry the failure error")
			require.Contains(t, info.Error.Error(), c.wantErr)
			require.Nil(t, info.Client, "no client session on failure")
		})
	}
}