init_test.go

  1package mcp
  2
  3import (
  4	"context"
  5	"maps"
  6	"os"
  7	"testing"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/env"
 11	"github.com/modelcontextprotocol/go-sdk/mcp"
 12	"github.com/stretchr/testify/require"
 13	"go.uber.org/goleak"
 14)
 15
 16// shellResolverWithPath builds a shell resolver whose env carries PATH
 17// plus any caller-supplied overrides. Without PATH, $(cat), $(echo),
 18// etc. can't find their binaries in a test process where the shell env
 19// is otherwise empty.
 20func shellResolverWithPath(t *testing.T, overrides map[string]string) config.VariableResolver {
 21	t.Helper()
 22	m := map[string]string{"PATH": os.Getenv("PATH")}
 23	maps.Copy(m, overrides)
 24	return config.NewShellVariableResolver(env.NewFromMap(m))
 25}
 26
 27func TestMCPSession_CancelOnClose(t *testing.T) {
 28	defer goleak.VerifyNone(t)
 29
 30	serverTransport, clientTransport := mcp.NewInMemoryTransports()
 31
 32	server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
 33	serverSession, err := server.Connect(context.Background(), serverTransport, nil)
 34	require.NoError(t, err)
 35	defer serverSession.Close()
 36
 37	ctx, cancel := context.WithCancel(context.Background())
 38
 39	client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
 40	clientSession, err := client.Connect(ctx, clientTransport, nil)
 41	require.NoError(t, err)
 42
 43	sess := &ClientSession{clientSession, cancel}
 44
 45	// Verify the context is not cancelled before close.
 46	require.NoError(t, ctx.Err())
 47
 48	err = sess.Close()
 49	require.NoError(t, err)
 50
 51	// After Close, the context must be cancelled.
 52	require.ErrorIs(t, ctx.Err(), context.Canceled)
 53}
 54
 55// TestCreateTransport_URLResolution pins that m.URL goes through the
 56// same resolver seam as command, args, env, and headers. Covers both
 57// the HTTP and SSE branches, success and failure, so a regression in
 58// ResolvedURL wiring is caught at the transport layer rather than only
 59// at the config layer.
 60func TestCreateTransport_URLResolution(t *testing.T) {
 61	t.Parallel()
 62
 63	shell := config.NewShellVariableResolver(env.NewFromMap(map[string]string{
 64		"MCP_HOST": "mcp.example.com",
 65	}))
 66
 67	t.Run("http success expands $VAR", func(t *testing.T) {
 68		t.Parallel()
 69		m := config.MCPConfig{
 70			Type: config.MCPHttp,
 71			URL:  "https://$MCP_HOST/api",
 72		}
 73		tr, err := createTransport(t.Context(), m, shell)
 74		require.NoError(t, err)
 75		require.NotNil(t, tr)
 76		sct, ok := tr.(*mcp.StreamableClientTransport)
 77		require.True(t, ok, "expected StreamableClientTransport, got %T", tr)
 78		require.Equal(t, "https://mcp.example.com/api", sct.Endpoint)
 79	})
 80
 81	t.Run("sse success expands $(cmd)", func(t *testing.T) {
 82		t.Parallel()
 83		m := config.MCPConfig{
 84			Type: config.MCPSSE,
 85			URL:  "https://$(echo mcp.example.com)/events",
 86		}
 87		tr, err := createTransport(t.Context(), m, shell)
 88		require.NoError(t, err)
 89		sse, ok := tr.(*mcp.SSEClientTransport)
 90		require.True(t, ok, "expected SSEClientTransport, got %T", tr)
 91		require.Equal(t, "https://mcp.example.com/events", sse.Endpoint)
 92	})
 93
 94	t.Run("http unset var surfaces error, no transport created", func(t *testing.T) {
 95		t.Parallel()
 96		m := config.MCPConfig{
 97			Type: config.MCPHttp,
 98			URL:  "https://$MCP_MISSING_HOST/api",
 99		}
100		tr, err := createTransport(t.Context(), m, shell)
101		require.Error(t, err)
102		require.Nil(t, tr)
103		require.Contains(t, err.Error(), "url:")
104		require.Contains(t, err.Error(), "$MCP_MISSING_HOST")
105	})
106
107	t.Run("sse failing $(cmd) surfaces error, no transport created", func(t *testing.T) {
108		t.Parallel()
109		m := config.MCPConfig{
110			Type: config.MCPSSE,
111			URL:  "https://$(false)/events",
112		}
113		tr, err := createTransport(t.Context(), m, shell)
114		require.Error(t, err)
115		require.Nil(t, tr)
116		require.Contains(t, err.Error(), "url:")
117		require.Contains(t, err.Error(), "$(false)")
118	})
119
120	t.Run("http empty-after-resolve still fails the non-empty guard", func(t *testing.T) {
121		t.Parallel()
122		// ${MCP_EMPTY:-} resolves to the empty string (no error),
123		// then the existing TrimSpace guard in createTransport must
124		// reject it so we never spawn a transport against "".
125		m := config.MCPConfig{
126			Type: config.MCPHttp,
127			URL:  "${MCP_EMPTY:-}",
128		}
129		tr, err := createTransport(t.Context(), m, shell)
130		require.Error(t, err)
131		require.Nil(t, tr)
132		require.Contains(t, err.Error(), "non-empty 'url'")
133	})
134
135	t.Run("identity resolver round-trips template verbatim", func(t *testing.T) {
136		t.Parallel()
137		// Client mode forwards the template to the server; no local
138		// expansion, no error on unset vars.
139		tmpl := "https://$MCP_MISSING_HOST/api"
140		m := config.MCPConfig{Type: config.MCPHttp, URL: tmpl}
141		tr, err := createTransport(t.Context(), m, config.IdentityResolver())
142		require.NoError(t, err)
143		sct, ok := tr.(*mcp.StreamableClientTransport)
144		require.True(t, ok)
145		require.Equal(t, tmpl, sct.Endpoint)
146	})
147}
148
149// TestCreateTransport_StdioResolution pins that command, args, and env
150// for stdio MCPs go through the same resolver seam as the other
151// transports. Covers both success (expansion produced the expected
152// exec.Cmd) and failure (any one field erroring prevents transport
153// creation).
154func TestCreateTransport_StdioResolution(t *testing.T) {
155	t.Parallel()
156
157	t.Run("success expands command, args, and env", func(t *testing.T) {
158		t.Parallel()
159		r := shellResolverWithPath(t, map[string]string{
160			"MY_TOKEN": "hunter2",
161		})
162		m := config.MCPConfig{
163			Type:    config.MCPStdio,
164			Command: "forgejo-mcp",
165			Args:    []string{"--token", "$MY_TOKEN", "--host", "$(echo example.com)"},
166			Env: map[string]string{
167				"SECRET":    "$(echo shh)",
168				"PLAIN":     "literal",
169				"REFERENCE": "$MY_TOKEN",
170			},
171		}
172		tr, err := createTransport(t.Context(), m, r)
173		require.NoError(t, err)
174		require.NotNil(t, tr)
175
176		ct, ok := tr.(*mcp.CommandTransport)
177		require.True(t, ok, "expected CommandTransport, got %T", tr)
178
179		// exec.Cmd.Args[0] is the command name; the rest are positional
180		// args as passed.
181		require.Equal(t, []string{"forgejo-mcp", "--token", "hunter2", "--host", "example.com"}, ct.Command.Args)
182
183		// Env is os.Environ() + resolved entries (sorted). Check the
184		// resolved entries are present with their expanded values.
185		require.Contains(t, ct.Command.Env, "SECRET=shh")
186		require.Contains(t, ct.Command.Env, "PLAIN=literal")
187		require.Contains(t, ct.Command.Env, "REFERENCE=hunter2")
188	})
189
190	t.Run("env resolution failure surfaces error, no transport created", func(t *testing.T) {
191		t.Parallel()
192		r := shellResolverWithPath(t, nil)
193		m := config.MCPConfig{
194			Type:    config.MCPStdio,
195			Command: "forgejo-mcp",
196			Env:     map[string]string{"TOKEN": "$(false)"},
197		}
198		tr, err := createTransport(t.Context(), m, r)
199		require.Error(t, err)
200		require.Nil(t, tr)
201		require.Contains(t, err.Error(), "env TOKEN")
202	})
203
204	t.Run("unset env var is a hard error, not silent empty", func(t *testing.T) {
205		t.Parallel()
206		// The regression at the heart of PLAN.md: unset vars used to
207		// silently substitute "" and hand an empty credential to the
208		// child process. Now they must error out before exec.
209		r := shellResolverWithPath(t, nil)
210		m := config.MCPConfig{
211			Type:    config.MCPStdio,
212			Command: "forgejo-mcp",
213			Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$FORGJO_TOKEN"},
214		}
215		tr, err := createTransport(t.Context(), m, r)
216		require.Error(t, err)
217		require.Nil(t, tr)
218		require.Contains(t, err.Error(), "env FORGEJO_ACCESS_TOKEN")
219	})
220
221	t.Run("args resolution failure surfaces error, no transport created", func(t *testing.T) {
222		t.Parallel()
223		r := shellResolverWithPath(t, nil)
224		m := config.MCPConfig{
225			Type:    config.MCPStdio,
226			Command: "forgejo-mcp",
227			Args:    []string{"--token", "$(false)"},
228		}
229		tr, err := createTransport(t.Context(), m, r)
230		require.Error(t, err)
231		require.Nil(t, tr)
232		require.Contains(t, err.Error(), "arg 1")
233	})
234
235	t.Run("command resolution failure surfaces error, no transport created", func(t *testing.T) {
236		t.Parallel()
237		r := shellResolverWithPath(t, nil)
238		m := config.MCPConfig{
239			Type:    config.MCPStdio,
240			Command: "$(false)",
241		}
242		tr, err := createTransport(t.Context(), m, r)
243		require.Error(t, err)
244		require.Nil(t, tr)
245		require.Contains(t, err.Error(), "invalid mcp command")
246	})
247
248	t.Run("identity resolver round-trips templates verbatim", func(t *testing.T) {
249		t.Parallel()
250		// Client mode: no local expansion, no error on unset vars.
251		m := config.MCPConfig{
252			Type:    config.MCPStdio,
253			Command: "forgejo-mcp",
254			Args:    []string{"--token", "$MCP_MISSING"},
255			Env:     map[string]string{"TOKEN": "$(vault read -f token)"},
256		}
257		tr, err := createTransport(t.Context(), m, config.IdentityResolver())
258		require.NoError(t, err)
259		ct, ok := tr.(*mcp.CommandTransport)
260		require.True(t, ok)
261		require.Equal(t, []string{"forgejo-mcp", "--token", "$MCP_MISSING"}, ct.Command.Args)
262		require.Contains(t, ct.Command.Env, "TOKEN=$(vault read -f token)")
263	})
264}
265
266// TestCreateTransport_HeadersResolution pins that a single failing
267// header aborts HTTP/SSE transport creation and that the successful
268// resolver passes every expanded header through to the round tripper.
269func TestCreateTransport_HeadersResolution(t *testing.T) {
270	t.Parallel()
271
272	t.Run("http headers success expands $(cmd)", func(t *testing.T) {
273		t.Parallel()
274		r := shellResolverWithPath(t, map[string]string{
275			"GITHUB_TOKEN": "gh-secret",
276		})
277		m := config.MCPConfig{
278			Type: config.MCPHttp,
279			URL:  "https://mcp.example.com/api",
280			Headers: map[string]string{
281				"Authorization": "$(echo Bearer $GITHUB_TOKEN)",
282				"X-Static":      "kept",
283			},
284		}
285		tr, err := createTransport(t.Context(), m, r)
286		require.NoError(t, err)
287
288		sct, ok := tr.(*mcp.StreamableClientTransport)
289		require.True(t, ok)
290		rt, ok := sct.HTTPClient.Transport.(*headerRoundTripper)
291		require.True(t, ok, "expected headerRoundTripper, got %T", sct.HTTPClient.Transport)
292		require.Equal(t, map[string]string{
293			"Authorization": "Bearer gh-secret",
294			"X-Static":      "kept",
295		}, rt.headers)
296	})
297
298	t.Run("http failing header surfaces error, no transport", func(t *testing.T) {
299		t.Parallel()
300		r := shellResolverWithPath(t, nil)
301		m := config.MCPConfig{
302			Type:    config.MCPHttp,
303			URL:     "https://mcp.example.com/api",
304			Headers: map[string]string{"Authorization": "$(false)"},
305		}
306		tr, err := createTransport(t.Context(), m, r)
307		require.Error(t, err)
308		require.Nil(t, tr)
309		require.Contains(t, err.Error(), "header Authorization")
310	})
311
312	t.Run("sse failing header surfaces error, no transport", func(t *testing.T) {
313		t.Parallel()
314		r := shellResolverWithPath(t, nil)
315		m := config.MCPConfig{
316			Type:    config.MCPSSE,
317			URL:     "https://mcp.example.com/events",
318			Headers: map[string]string{"Authorization": "Bearer $MISSING_TOKEN"},
319		}
320		tr, err := createTransport(t.Context(), m, r)
321		require.Error(t, err)
322		require.Nil(t, tr)
323		require.Contains(t, err.Error(), "header Authorization")
324		require.Contains(t, err.Error(), "$MISSING_TOKEN")
325	})
326}
327
328// TestCreateSession_ResolutionFailureUpdatesState pins the user-visible
329// half of the regression fix: when any of command/args/env/headers/url
330// fails to resolve, createSession must publish StateError to the state
331// map so crush_info and the TUI's MCP status card can render a real
332// error instead of the MCP silently sitting in "starting" or being
333// spawned with an empty credential.
334//
335// These subtests cannot run in parallel: `states` is a package-level
336// csync.Map and each assertion reads the entry written by the call
337// under test. They do use unique MCP names per subtest to keep them
338// independent regardless of ordering.
339func TestCreateSession_ResolutionFailureUpdatesState(t *testing.T) {
340	r := shellResolverWithPath(t, nil)
341
342	tests := []struct {
343		name            string
344		mcpName         string
345		cfg             config.MCPConfig
346		wantErrContains string
347	}{
348		{
349			name:    "stdio env failure",
350			mcpName: "test-stdio-env-fail",
351			cfg: config.MCPConfig{
352				Type:    config.MCPStdio,
353				Command: "echo",
354				Env:     map[string]string{"FORGEJO_ACCESS_TOKEN": "$(false)"},
355			},
356			wantErrContains: "env FORGEJO_ACCESS_TOKEN",
357		},
358		{
359			name:    "stdio args failure",
360			mcpName: "test-stdio-args-fail",
361			cfg: config.MCPConfig{
362				Type:    config.MCPStdio,
363				Command: "echo",
364				Args:    []string{"--token", "$MCP_UNSET_TOKEN"},
365			},
366			wantErrContains: "arg 1",
367		},
368		{
369			name:    "http url failure",
370			mcpName: "test-http-url-fail",
371			cfg: config.MCPConfig{
372				Type: config.MCPHttp,
373				URL:  "https://$MCP_MISSING_HOST/api",
374			},
375			wantErrContains: "url:",
376		},
377		{
378			name:    "http header failure",
379			mcpName: "test-http-header-fail",
380			cfg: config.MCPConfig{
381				Type:    config.MCPHttp,
382				URL:     "https://mcp.example.com/api",
383				Headers: map[string]string{"Authorization": "$(false)"},
384			},
385			wantErrContains: "header Authorization",
386		},
387		{
388			name:    "sse url failure",
389			mcpName: "test-sse-url-fail",
390			cfg: config.MCPConfig{
391				Type: config.MCPSSE,
392				URL:  "https://$(false)/events",
393			},
394			wantErrContains: "url:",
395		},
396		{
397			name:    "sse header failure",
398			mcpName: "test-sse-header-fail",
399			cfg: config.MCPConfig{
400				Type:    config.MCPSSE,
401				URL:     "https://mcp.example.com/events",
402				Headers: map[string]string{"Authorization": "Bearer $MISSING_SSE_TOKEN"},
403			},
404			wantErrContains: "header Authorization",
405		},
406	}
407
408	for _, tc := range tests {
409		t.Run(tc.name, func(t *testing.T) {
410			// Guarantee a clean slate on the shared state map so a
411			// stale entry from another test can't satisfy the
412			// assertion.
413			states.Del(tc.mcpName)
414			t.Cleanup(func() { states.Del(tc.mcpName) })
415
416			sess, err := createSession(t.Context(), tc.mcpName, tc.cfg, r)
417			require.Error(t, err)
418			require.Nil(t, sess)
419			require.Contains(t, err.Error(), tc.wantErrContains)
420
421			info, ok := GetState(tc.mcpName)
422			require.True(t, ok, "state entry must be written for %q", tc.mcpName)
423			require.Equal(t, StateError, info.State, "expected StateError, got %s", info.State)
424			require.Error(t, info.Error, "state must carry the failure error")
425			require.Contains(t, info.Error.Error(), tc.wantErrContains)
426			require.Nil(t, info.Client, "no client session on failure")
427		})
428	}
429}