init_test.go

  1package mcp
  2
  3import (
  4	"context"
  5	"testing"
  6
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/env"
  9	"github.com/modelcontextprotocol/go-sdk/mcp"
 10	"github.com/stretchr/testify/require"
 11	"go.uber.org/goleak"
 12)
 13
 14func TestMCPSession_CancelOnClose(t *testing.T) {
 15	defer goleak.VerifyNone(t)
 16
 17	serverTransport, clientTransport := mcp.NewInMemoryTransports()
 18
 19	server := mcp.NewServer(&mcp.Implementation{Name: "test-server"}, nil)
 20	serverSession, err := server.Connect(context.Background(), serverTransport, nil)
 21	require.NoError(t, err)
 22	defer serverSession.Close()
 23
 24	ctx, cancel := context.WithCancel(context.Background())
 25
 26	client := mcp.NewClient(&mcp.Implementation{Name: "crush-test"}, nil)
 27	clientSession, err := client.Connect(ctx, clientTransport, nil)
 28	require.NoError(t, err)
 29
 30	sess := &ClientSession{clientSession, cancel}
 31
 32	// Verify the context is not cancelled before close.
 33	require.NoError(t, ctx.Err())
 34
 35	err = sess.Close()
 36	require.NoError(t, err)
 37
 38	// After Close, the context must be cancelled.
 39	require.ErrorIs(t, ctx.Err(), context.Canceled)
 40}
 41
 42// TestCreateTransport_URLResolution pins that m.URL goes through the
 43// same resolver seam as command, args, env, and headers. Covers both
 44// the HTTP and SSE branches, success and failure, so a regression in
 45// ResolvedURL wiring is caught at the transport layer rather than only
 46// at the config layer.
 47func TestCreateTransport_URLResolution(t *testing.T) {
 48	t.Parallel()
 49
 50	shell := config.NewShellVariableResolver(env.NewFromMap(map[string]string{
 51		"MCP_HOST": "mcp.example.com",
 52	}))
 53
 54	t.Run("http success expands $VAR", func(t *testing.T) {
 55		t.Parallel()
 56		m := config.MCPConfig{
 57			Type: config.MCPHttp,
 58			URL:  "https://$MCP_HOST/api",
 59		}
 60		tr, err := createTransport(t.Context(), m, shell)
 61		require.NoError(t, err)
 62		require.NotNil(t, tr)
 63		sct, ok := tr.(*mcp.StreamableClientTransport)
 64		require.True(t, ok, "expected StreamableClientTransport, got %T", tr)
 65		require.Equal(t, "https://mcp.example.com/api", sct.Endpoint)
 66	})
 67
 68	t.Run("sse success expands $(cmd)", func(t *testing.T) {
 69		t.Parallel()
 70		m := config.MCPConfig{
 71			Type: config.MCPSSE,
 72			URL:  "https://$(echo mcp.example.com)/events",
 73		}
 74		tr, err := createTransport(t.Context(), m, shell)
 75		require.NoError(t, err)
 76		sse, ok := tr.(*mcp.SSEClientTransport)
 77		require.True(t, ok, "expected SSEClientTransport, got %T", tr)
 78		require.Equal(t, "https://mcp.example.com/events", sse.Endpoint)
 79	})
 80
 81	t.Run("http unset var surfaces error, no transport created", func(t *testing.T) {
 82		t.Parallel()
 83		m := config.MCPConfig{
 84			Type: config.MCPHttp,
 85			URL:  "https://$MCP_MISSING_HOST/api",
 86		}
 87		tr, err := createTransport(t.Context(), m, shell)
 88		require.Error(t, err)
 89		require.Nil(t, tr)
 90		require.Contains(t, err.Error(), "url:")
 91		require.Contains(t, err.Error(), "$MCP_MISSING_HOST")
 92	})
 93
 94	t.Run("sse failing $(cmd) surfaces error, no transport created", func(t *testing.T) {
 95		t.Parallel()
 96		m := config.MCPConfig{
 97			Type: config.MCPSSE,
 98			URL:  "https://$(false)/events",
 99		}
100		tr, err := createTransport(t.Context(), m, shell)
101		require.Error(t, err)
102		require.Nil(t, tr)
103		require.Contains(t, err.Error(), "url:")
104		require.Contains(t, err.Error(), "$(false)")
105	})
106
107	t.Run("http empty-after-resolve still fails the non-empty guard", func(t *testing.T) {
108		t.Parallel()
109		// ${MCP_EMPTY:-} resolves to the empty string (no error),
110		// then the existing TrimSpace guard in createTransport must
111		// reject it so we never spawn a transport against "".
112		m := config.MCPConfig{
113			Type: config.MCPHttp,
114			URL:  "${MCP_EMPTY:-}",
115		}
116		tr, err := createTransport(t.Context(), m, shell)
117		require.Error(t, err)
118		require.Nil(t, tr)
119		require.Contains(t, err.Error(), "non-empty 'url'")
120	})
121
122	t.Run("identity resolver round-trips template verbatim", func(t *testing.T) {
123		t.Parallel()
124		// Client mode forwards the template to the server; no local
125		// expansion, no error on unset vars.
126		tmpl := "https://$MCP_MISSING_HOST/api"
127		m := config.MCPConfig{Type: config.MCPHttp, URL: tmpl}
128		tr, err := createTransport(t.Context(), m, config.IdentityResolver())
129		require.NoError(t, err)
130		sct, ok := tr.(*mcp.StreamableClientTransport)
131		require.True(t, ok)
132		require.Equal(t, tmpl, sct.Endpoint)
133	})
134}