init_test.go

  1package mcp
  2
  3import (
  4	"context"
  5	"maps"
  6	"os"
  7	"testing"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/env"
 11	"github.com/modelcontextprotocol/go-sdk/mcp"
 12	"github.com/stretchr/testify/require"
 13	"go.uber.org/goleak"
 14)
 15
 16// shellResolverWithPath builds a shell resolver whose env carries PATH
 17// plus any caller-supplied overrides. Without PATH, $(cat), $(echo),
 18// etc. can't find their binaries in a test process where the shell env
 19// is otherwise empty.
 20func shellResolverWithPath(t *testing.T, overrides map[string]string) config.VariableResolver {
 21	t.Helper()
 22	m := map[string]string{"PATH": os.Getenv("PATH")}
 23	maps.Copy(m, overrides)
 24	return config.NewShellVariableResolver(env.NewFromMap(m))
 25}
 26
 27func TestMCPSession_CancelOnClose(t *testing.T) {
 28	defer goleak.VerifyNone(t)
 29
 30	serverTransport, clientTransport := mcp.NewInMemoryTransports()
 31
 32	server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
 33	serverSession, err := server.Connect(context.Background(), serverTransport, nil)
 34	require.NoError(t, err)
 35	defer serverSession.Close()
 36
 37	ctx, cancel := context.WithCancel(context.Background())
 38
 39	client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
 40	clientSession, err := client.Connect(ctx, clientTransport, nil)
 41	require.NoError(t, err)
 42
 43	sess := &ClientSession{clientSession, cancel}
 44
 45	// Verify the context is not cancelled before close.
 46	require.NoError(t, ctx.Err())
 47
 48	err = sess.Close()
 49	require.NoError(t, err)
 50
 51	// After Close, the context must be cancelled.
 52	require.ErrorIs(t, ctx.Err(), context.Canceled)
 53}
 54
 55// TestCreateTransport_URLResolution pins that m.URL goes through the
 56// same resolver seam as command, args, env, and headers. Covers both
 57// the HTTP and SSE branches, success and failure, so a regression in
 58// ResolvedURL wiring is caught at the transport layer rather than only
 59// at the config layer.
 60func TestCreateTransport_URLResolution(t *testing.T) {
 61	t.Parallel()
 62
 63	shell := config.NewShellVariableResolver(env.NewFromMap(map[string]string{
 64		"MCP_HOST": "mcp.example.com",
 65	}))
 66
 67	t.Run("http success expands $VAR", func(t *testing.T) {
 68		t.Parallel()
 69		m := config.MCPConfig{
 70			Type: config.MCPHttp,
 71			URL:  "https://$MCP_HOST/api",
 72		}
 73		tr, err := createTransport(t.Context(), m, shell)
 74		require.NoError(t, err)
 75		require.NotNil(t, tr)
 76		sct, ok := tr.(*mcp.StreamableClientTransport)
 77		require.True(t, ok, "expected StreamableClientTransport, got %T", tr)
 78		require.Equal(t, "https://mcp.example.com/api", sct.Endpoint)
 79	})
 80
 81	t.Run("sse success expands $(cmd)", func(t *testing.T) {
 82		t.Parallel()
 83		m := config.MCPConfig{
 84			Type: config.MCPSSE,
 85			URL:  "https://$(echo mcp.example.com)/events",
 86		}
 87		tr, err := createTransport(t.Context(), m, shell)
 88		require.NoError(t, err)
 89		sse, ok := tr.(*mcp.SSEClientTransport)
 90		require.True(t, ok, "expected SSEClientTransport, got %T", tr)
 91		require.Equal(t, "https://mcp.example.com/events", sse.Endpoint)
 92	})
 93
 94	t.Run("http failing $(cmd) surfaces error, no transport created", func(t *testing.T) {
 95		t.Parallel()
 96		// Under lenient nounset, unset $VAR expands to "" silently,
 97		// so the only way a URL resolution *errors* is a failing
 98		// $(cmd). Mirror the SSE subtest so both transports share
 99		// coverage for the url-resolve-failure path.
100		m := config.MCPConfig{
101			Type: config.MCPHttp,
102			URL:  "https://$(false)/api",
103		}
104		tr, err := createTransport(t.Context(), m, shellResolverWithPath(t, nil))
105		require.Error(t, err)
106		require.Nil(t, tr)
107		require.Contains(t, err.Error(), "url:")
108		require.Contains(t, err.Error(), "$(false)")
109	})
110
111	t.Run("http unset var expands empty", func(t *testing.T) {
112		t.Parallel()
113		// Pinning test for the new lenient-nounset default: an
114		// unset bare $VAR in the URL is *not* an error. It
115		// expands to "" and, here, leaves a syntactically weird
116		// but non-empty URL that the existing non-empty guard
117		// still lets through. Guards against a future regression
118		// that flips strict-by-default back on.
119		m := config.MCPConfig{
120			Type: config.MCPHttp,
121			URL:  "https://$MCP_MISSING_HOST/api",
122		}
123		tr, err := createTransport(t.Context(), m, shell)
124		require.NoError(t, err)
125		sct, ok := tr.(*mcp.StreamableClientTransport)
126		require.True(t, ok)
127		require.Equal(t, "https:///api", sct.Endpoint)
128	})
129
130	t.Run("sse failing $(cmd) surfaces error, no transport created", func(t *testing.T) {
131		t.Parallel()
132		m := config.MCPConfig{
133			Type: config.MCPSSE,
134			URL:  "https://$(false)/events",
135		}
136		tr, err := createTransport(t.Context(), m, shell)
137		require.Error(t, err)
138		require.Nil(t, tr)
139		require.Contains(t, err.Error(), "url:")
140		require.Contains(t, err.Error(), "$(false)")
141	})
142
143	t.Run("http empty-after-resolve still fails the non-empty guard", func(t *testing.T) {
144		t.Parallel()
145		// ${MCP_EMPTY:-} resolves to the empty string (no error),
146		// then the existing TrimSpace guard in createTransport must
147		// reject it so we never spawn a transport against "".
148		m := config.MCPConfig{
149			Type: config.MCPHttp,
150			URL:  "${MCP_EMPTY:-}",
151		}
152		tr, err := createTransport(t.Context(), m, shell)
153		require.Error(t, err)
154		require.Nil(t, tr)
155		require.Contains(t, err.Error(), "non-empty 'url'")
156	})
157
158	t.Run("identity resolver round-trips template verbatim", func(t *testing.T) {
159		t.Parallel()
160		// Client mode forwards the template to the server; no local
161		// expansion, no error on unset vars.
162		tmpl := "https://$MCP_MISSING_HOST/api"
163		m := config.MCPConfig{Type: config.MCPHttp, URL: tmpl}
164		tr, err := createTransport(t.Context(), m, config.IdentityResolver())
165		require.NoError(t, err)
166		sct, ok := tr.(*mcp.StreamableClientTransport)
167		require.True(t, ok)
168		require.Equal(t, tmpl, sct.Endpoint)
169	})
170}
171
172// TestCreateTransport_StdioResolution pins that command, args, and env
173// for stdio MCPs go through the same resolver seam as the other
174// transports. Covers both success (expansion produced the expected
175// exec.Cmd) and failure (any one field erroring prevents transport
176// creation).
177func TestCreateTransport_StdioResolution(t *testing.T) {
178	t.Parallel()
179
180	t.Run("success expands command, args, and env", func(t *testing.T) {
181		t.Parallel()
182		r := shellResolverWithPath(t, map[string]string{
183			"MY_TOKEN": "hunter2",
184		})
185		m := config.MCPConfig{
186			Type:    config.MCPStdio,
187			Command: "forgejo-mcp",
188			Args:    []string{"--token", "$MY_TOKEN", "--host", "$(echo example.com)"},
189			Env: map[string]string{
190				"SECRET":    "$(echo shh)",
191				"PLAIN":     "literal",
192				"REFERENCE": "$MY_TOKEN",
193			},
194		}
195		tr, err := createTransport(t.Context(), m, r)
196		require.NoError(t, err)
197		require.NotNil(t, tr)
198
199		ct, ok := tr.(*mcp.CommandTransport)
200		require.True(t, ok, "expected CommandTransport, got %T", tr)
201
202		// exec.Cmd.Args[0] is the command name; the rest are positional
203		// args as passed.
204		require.Equal(t, []string{"forgejo-mcp", "--token", "hunter2", "--host", "example.com"}, ct.Command.Args)
205
206		// Env is os.Environ() + resolved entries (sorted). Check the
207		// resolved entries are present with their expanded values.
208		require.Contains(t, ct.Command.Env, "SECRET=shh")
209		require.Contains(t, ct.Command.Env, "PLAIN=literal")
210		require.Contains(t, ct.Command.Env, "REFERENCE=hunter2")
211	})
212
213	t.Run("env resolution failure surfaces error, no transport created", func(t *testing.T) {
214		t.Parallel()
215		r := shellResolverWithPath(t, nil)
216		m := config.MCPConfig{
217			Type:    config.MCPStdio,
218			Command: "forgejo-mcp",
219			Env:     map[string]string{"TOKEN": "$(false)"},
220		}
221		tr, err := createTransport(t.Context(), m, r)
222		require.Error(t, err)
223		require.Nil(t, tr)
224		require.Contains(t, err.Error(), "env TOKEN")
225	})
226
227	t.Run("failing env command is a hard error", func(t *testing.T) {
228		t.Parallel()
229		// Under lenient nounset a bare $UNSET expands to ""
230		// silently — see the pinning subtest below. The remaining
231		// failure mode for env resolution is a $(cmd) that exits
232		// non-zero, which must still error out and prevent exec so
233		// we never hand a broken credential to the child process.
234		r := shellResolverWithPath(t, nil)
235		m := config.MCPConfig{
236			Type:    config.MCPStdio,
237			Command: "forgejo-mcp",
238			Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$(exit 5)"},
239		}
240		tr, err := createTransport(t.Context(), m, r)
241		require.Error(t, err)
242		require.Nil(t, tr)
243		require.Contains(t, err.Error(), "env FORGEJO_ACCESS_TOKEN")
244	})
245
246	t.Run("unset env var expands empty", func(t *testing.T) {
247		t.Parallel()
248		// Pinning test for the lenient-nounset default: a bare
249		// $UNSET in an env value expands to "" without error, and
250		// the empty entry is kept on the resulting exec.Cmd (env
251		// entries, unlike headers, are not dropped — see design
252		// decision #18). Guards against a regression that flips
253		// strict-by-default back on and silently breaks users
254		// with configs like FORGEJO_ACCESS_TOKEN=$FORGEJO_TOKEN.
255		r := shellResolverWithPath(t, nil)
256		m := config.MCPConfig{
257			Type:    config.MCPStdio,
258			Command: "forgejo-mcp",
259			Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$FORGEJO_TOKEN_UNSET"},
260		}
261		tr, err := createTransport(t.Context(), m, r)
262		require.NoError(t, err)
263		ct, ok := tr.(*mcp.CommandTransport)
264		require.True(t, ok)
265		require.Contains(t, ct.Command.Env, "FORGEJO_ACCESS_TOKEN=")
266	})
267
268	t.Run("args resolution failure surfaces error, no transport created", func(t *testing.T) {
269		t.Parallel()
270		r := shellResolverWithPath(t, nil)
271		m := config.MCPConfig{
272			Type:    config.MCPStdio,
273			Command: "forgejo-mcp",
274			Args:    []string{"--token", "$(false)"},
275		}
276		tr, err := createTransport(t.Context(), m, r)
277		require.Error(t, err)
278		require.Nil(t, tr)
279		require.Contains(t, err.Error(), "arg 1")
280	})
281
282	t.Run("command resolution failure surfaces error, no transport created", func(t *testing.T) {
283		t.Parallel()
284		r := shellResolverWithPath(t, nil)
285		m := config.MCPConfig{
286			Type:    config.MCPStdio,
287			Command: "$(false)",
288		}
289		tr, err := createTransport(t.Context(), m, r)
290		require.Error(t, err)
291		require.Nil(t, tr)
292		require.Contains(t, err.Error(), "invalid mcp command")
293	})
294
295	t.Run("identity resolver round-trips templates verbatim", func(t *testing.T) {
296		t.Parallel()
297		// Client mode: no local expansion, no error on unset vars.
298		m := config.MCPConfig{
299			Type:    config.MCPStdio,
300			Command: "forgejo-mcp",
301			Args:    []string{"--token", "$MCP_MISSING"},
302			Env:     map[string]string{"TOKEN": "$(vault read -f token)"},
303		}
304		tr, err := createTransport(t.Context(), m, config.IdentityResolver())
305		require.NoError(t, err)
306		ct, ok := tr.(*mcp.CommandTransport)
307		require.True(t, ok)
308		require.Equal(t, []string{"forgejo-mcp", "--token", "$MCP_MISSING"}, ct.Command.Args)
309		require.Contains(t, ct.Command.Env, "TOKEN=$(vault read -f token)")
310	})
311}
312
313// TestCreateTransport_HeadersResolution pins that a single failing
314// header aborts HTTP/SSE transport creation and that the successful
315// resolver passes every expanded header through to the round tripper.
316func TestCreateTransport_HeadersResolution(t *testing.T) {
317	t.Parallel()
318
319	t.Run("http headers success expands $(cmd)", func(t *testing.T) {
320		t.Parallel()
321		r := shellResolverWithPath(t, map[string]string{
322			"GITHUB_TOKEN": "gh-secret",
323		})
324		m := config.MCPConfig{
325			Type: config.MCPHttp,
326			URL:  "https://mcp.example.com/api",
327			Headers: map[string]string{
328				"Authorization": "$(echo Bearer $GITHUB_TOKEN)",
329				"X-Static":      "kept",
330			},
331		}
332		tr, err := createTransport(t.Context(), m, r)
333		require.NoError(t, err)
334
335		sct, ok := tr.(*mcp.StreamableClientTransport)
336		require.True(t, ok)
337		rt, ok := sct.HTTPClient.Transport.(*headerRoundTripper)
338		require.True(t, ok, "expected headerRoundTripper, got %T", sct.HTTPClient.Transport)
339		require.Equal(t, map[string]string{
340			"Authorization": "Bearer gh-secret",
341			"X-Static":      "kept",
342		}, rt.headers)
343	})
344
345	t.Run("http failing header surfaces error, no transport", func(t *testing.T) {
346		t.Parallel()
347		r := shellResolverWithPath(t, nil)
348		m := config.MCPConfig{
349			Type:    config.MCPHttp,
350			URL:     "https://mcp.example.com/api",
351			Headers: map[string]string{"Authorization": "$(false)"},
352		}
353		tr, err := createTransport(t.Context(), m, r)
354		require.Error(t, err)
355		require.Nil(t, tr)
356		require.Contains(t, err.Error(), "header Authorization")
357	})
358
359	t.Run("sse failing header surfaces error, no transport", func(t *testing.T) {
360		t.Parallel()
361		// Under lenient nounset a bare $MISSING expands to "",
362		// which ResolvedHeaders drops — no error. The failing
363		// $(cmd) path is the remaining way this can fail loudly;
364		// cover it on the SSE branch to mirror the HTTP subtest.
365		r := shellResolverWithPath(t, nil)
366		m := config.MCPConfig{
367			Type:    config.MCPSSE,
368			URL:     "https://mcp.example.com/events",
369			Headers: map[string]string{"Authorization": "$(false)"},
370		}
371		tr, err := createTransport(t.Context(), m, r)
372		require.Error(t, err)
373		require.Nil(t, tr)
374		require.Contains(t, err.Error(), "header Authorization")
375	})
376
377	t.Run("sse unset var header drops silently", func(t *testing.T) {
378		t.Parallel()
379		// Pinning test for design decision #18 + lenient nounset:
380		// a header whose value resolves to "" (here because the
381		// bare $VAR is unset) is omitted from the round tripper
382		// rather than sent as "X-Header:". Guards against a
383		// regression that either re-introduces strict-by-default
384		// or stops dropping empty headers.
385		r := shellResolverWithPath(t, nil)
386		m := config.MCPConfig{
387			Type:    config.MCPSSE,
388			URL:     "https://mcp.example.com/events",
389			Headers: map[string]string{"Authorization": "$MISSING_TOKEN"},
390		}
391		tr, err := createTransport(t.Context(), m, r)
392		require.NoError(t, err)
393		sse, ok := tr.(*mcp.SSEClientTransport)
394		require.True(t, ok)
395		rt, ok := sse.HTTPClient.Transport.(*headerRoundTripper)
396		require.True(t, ok)
397		require.NotContains(t, rt.headers, "Authorization")
398	})
399}
400
401// TestCreateSession_ResolutionFailureUpdatesState pins the user-visible
402// half of the regression fix: when any of command/args/env/headers/url
403// fails to resolve, createSession must publish StateError to the state
404// map so crush_info and the TUI's MCP status card can render a real
405// error instead of the MCP silently sitting in "starting" or being
406// spawned with an empty credential.
407//
408// These subtests cannot run in parallel: `states` is a package-level
409// csync.Map and each assertion reads the entry written by the call
410// under test. They do use unique MCP names per subtest to keep them
411// independent regardless of ordering.
412func TestCreateSession_ResolutionFailureUpdatesState(t *testing.T) {
413	r := shellResolverWithPath(t, nil)
414
415	tests := []struct {
416		name            string
417		mcpName         string
418		cfg             config.MCPConfig
419		wantErrContains string
420	}{
421		{
422			name:    "stdio env failure",
423			mcpName: "test-stdio-env-fail",
424			cfg: config.MCPConfig{
425				Type:    config.MCPStdio,
426				Command: "echo",
427				Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$(false)"},
428			},
429			wantErrContains: "env FORGEJO_ACCESS_TOKEN",
430		},
431		{
432			// Args that reference an unset bare $VAR no longer
433			// error out under lenient nounset; the only remaining
434			// failure mode for arg resolution is a failing $(cmd).
435			name:    "stdio args failure",
436			mcpName: "test-stdio-args-fail",
437			cfg: config.MCPConfig{
438				Type:    config.MCPStdio,
439				Command: "echo",
440				Args:    []string{"--token", "$(false)"},
441			},
442			wantErrContains: "arg 1",
443		},
444		{
445			// Likewise for URL: bare $UNSET expands to ""
446			// silently, so we need a failing $(cmd) to exercise
447			// the "url:" wrap from ResolvedURL.
448			name:    "http url failure",
449			mcpName: "test-http-url-fail",
450			cfg: config.MCPConfig{
451				Type: config.MCPHttp,
452				URL:  "https://$(false)/api",
453			},
454			wantErrContains: "url:",
455		},
456		{
457			// A URL whose shell expansion yields the empty
458			// string (here via ${VAR:-}) is not a ResolvedURL
459			// error, but the non-empty guard in createTransport
460			// must still reject it so the state card renders an
461			// error instead of spawning a transport against "".
462			name:    "http empty-resolved url",
463			mcpName: "test-http-url-empty",
464			cfg: config.MCPConfig{
465				Type: config.MCPHttp,
466				URL:  "${MCP_URL_EMPTY:-}",
467			},
468			wantErrContains: "non-empty 'url'",
469		},
470		{
471			name:    "http header failure",
472			mcpName: "test-http-header-fail",
473			cfg: config.MCPConfig{
474				Type:    config.MCPHttp,
475				URL:     "https://mcp.example.com/api",
476				Headers: map[string]string{"Authorization": "$(false)"},
477			},
478			wantErrContains: "header Authorization",
479		},
480		{
481			name:    "sse url failure",
482			mcpName: "test-sse-url-fail",
483			cfg: config.MCPConfig{
484				Type: config.MCPSSE,
485				URL:  "https://$(false)/events",
486			},
487			wantErrContains: "url:",
488		},
489		{
490			// Bare $MISSING in a header resolves to "" silently
491			// and is then dropped (design decision #18). The
492			// "header Authorization" wrap only surfaces on a
493			// $(cmd) failure; that is what this subtest now
494			// pins for the SSE path.
495			name:    "sse header failure",
496			mcpName: "test-sse-header-fail",
497			cfg: config.MCPConfig{
498				Type:    config.MCPSSE,
499				URL:     "https://mcp.example.com/events",
500				Headers: map[string]string{"Authorization": "$(false)"},
501			},
502			wantErrContains: "header Authorization",
503		},
504	}
505
506	for _, tc := range tests {
507		t.Run(tc.name, func(t *testing.T) {
508			// Guarantee a clean slate on the shared state map so a
509			// stale entry from another test can't satisfy the
510			// assertion.
511			states.Del(tc.mcpName)
512			t.Cleanup(func() { states.Del(tc.mcpName) })
513
514			sess, err := createSession(t.Context(), tc.mcpName, tc.cfg, r)
515			require.Error(t, err)
516			require.Nil(t, sess)
517			require.Contains(t, err.Error(), tc.wantErrContains)
518
519			info, ok := GetState(tc.mcpName)
520			require.True(t, ok, "state entry must be written for %q", tc.mcpName)
521			require.Equal(t, StateError, info.State, "expected StateError, got %s", info.State)
522			require.Error(t, info.Error, "state must carry the failure error")
523			require.Contains(t, info.Error.Error(), tc.wantErrContains)
524			require.Nil(t, info.Client, "no client session on failure")
525		})
526	}
527}