1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
28var (
29 sessions = csync.NewMap[string, *mcp.ClientSession]()
30 states = csync.NewMap[string, ClientInfo]()
31 broker = pubsub.NewBroker[Event]()
32 initOnce sync.Once
33 initDone = make(chan struct{})
34)
35
// State is the lifecycle stage of a configured MCP client.
type State int

// Client lifecycle stages. The zero value is StateDisabled.
const (
	// StateDisabled marks a client that is turned off in the configuration.
	StateDisabled State = iota
	// StateStarting marks a client whose connection attempt is in progress.
	StateStarting
	// StateConnected marks a client with an initialized session.
	StateConnected
	// StateError marks a client whose connection, setup, or health check
	// failed.
	StateError
)
45
46func (s State) String() string {
47 switch s {
48 case StateDisabled:
49 return "disabled"
50 case StateStarting:
51 return "starting"
52 case StateConnected:
53 return "connected"
54 case StateError:
55 return "error"
56 default:
57 return "unknown"
58 }
59}
60
// EventType identifies which kind of change an Event reports.
type EventType uint

const (
	// EventStateChanged reports that an MCP client moved to a new State.
	EventStateChanged EventType = iota
	// EventToolsListChanged reports that a server announced a change to its
	// tool list.
	EventToolsListChanged
	// EventPromptsListChanged reports that a server announced a change to its
	// prompt list.
	EventPromptsListChanged
)
69
70// Event represents an event in the MCP system
71type Event struct {
72 Type EventType
73 Name string
74 State State
75 Error error
76 Counts Counts
77}
78
// Counts records how many tools and prompts an MCP server exposes.
type Counts struct {
	Tools   int // number of tools reported for the server
	Prompts int // number of prompts reported for the server
}
84
85// ClientInfo holds information about an MCP client's state
86type ClientInfo struct {
87 Name string
88 State State
89 Error error
90 Client *mcp.ClientSession
91 Counts Counts
92 ConnectedAt time.Time
93}
94
95// AddEventListener registers a callback for MCP events.
96func AddEventListener(fn func(pubsub.Event[Event])) {
97 broker.AddListener(fn)
98}
99
100// GetStates returns the current state of all MCP clients
101func GetStates() map[string]ClientInfo {
102 return states.Copy()
103}
104
105// GetState returns the state of a specific MCP client
106func GetState(name string) (ClientInfo, bool) {
107 return states.Get(name)
108}
109
110// Close closes all MCP clients. This should be called during application shutdown.
111func Close() error {
112 var wg sync.WaitGroup
113 done := make(chan struct{}, 1)
114 go func() {
115 for name, session := range sessions.Seq2() {
116 wg.Go(func() {
117 if err := session.Close(); err != nil &&
118 !errors.Is(err, io.EOF) &&
119 !errors.Is(err, context.Canceled) &&
120 err.Error() != "signal: killed" {
121 slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
122 }
123 })
124 }
125 wg.Wait()
126 done <- struct{}{}
127 }()
128 select {
129 case <-done:
130 case <-time.After(5 * time.Second):
131 }
132 return nil
133}
134
135// Initialize initializes MCP clients based on the provided configuration.
136func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
137 var wg sync.WaitGroup
138 // Initialize states for all configured MCPs
139 for name, m := range cfg.MCP {
140 if m.Disabled {
141 updateState(ctx, name, StateDisabled, nil, nil, Counts{})
142 slog.Debug("skipping disabled mcp", "name", name)
143 continue
144 }
145
146 // Set initial starting state
147 updateState(ctx, name, StateStarting, nil, nil, Counts{})
148
149 wg.Add(1)
150 go func(name string, m config.MCPConfig) {
151 defer func() {
152 wg.Done()
153 if r := recover(); r != nil {
154 var err error
155 switch v := r.(type) {
156 case error:
157 err = v
158 case string:
159 err = fmt.Errorf("panic: %s", v)
160 default:
161 err = fmt.Errorf("panic: %v", v)
162 }
163 updateState(ctx, name, StateError, err, nil, Counts{})
164 slog.Error("panic in mcp client initialization", "error", err, "name", name)
165 }
166 }()
167
168 // createSession handles its own timeout internally.
169 session, err := createSession(ctx, name, m, cfg.Resolver())
170 if err != nil {
171 return
172 }
173
174 tools, err := getTools(ctx, session)
175 if err != nil {
176 slog.Error("error listing tools", "error", err)
177 updateState(ctx, name, StateError, err, nil, Counts{})
178 session.Close()
179 return
180 }
181
182 prompts, err := getPrompts(ctx, session)
183 if err != nil {
184 slog.Error("error listing prompts", "error", err)
185 updateState(ctx, name, StateError, err, nil, Counts{})
186 session.Close()
187 return
188 }
189
190 toolCount := updateTools(name, tools)
191 updatePrompts(name, prompts)
192 sessions.Set(name, session)
193
194 updateState(ctx, name, StateConnected, nil, session, Counts{
195 Tools: toolCount,
196 Prompts: len(prompts),
197 })
198 }(name, m)
199 }
200 wg.Wait()
201 initOnce.Do(func() { close(initDone) })
202}
203
204// WaitForInit blocks until MCP initialization is complete.
205// If Initialize was never called, this returns immediately.
206func WaitForInit(ctx context.Context) error {
207 select {
208 case <-initDone:
209 return nil
210 case <-ctx.Done():
211 return ctx.Err()
212 }
213}
214
215func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
216 sess, ok := sessions.Get(name)
217 if !ok {
218 return nil, fmt.Errorf("mcp '%s' not available", name)
219 }
220
221 cfg := config.Get()
222 m := cfg.MCP[name]
223 state, _ := states.Get(name)
224
225 timeout := mcpTimeout(m)
226 pingCtx, cancel := context.WithTimeout(ctx, timeout)
227 defer cancel()
228 err := sess.Ping(pingCtx, nil)
229 if err == nil {
230 return sess, nil
231 }
232 updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
233
234 sess, err = createSession(ctx, name, m, cfg.Resolver())
235 if err != nil {
236 return nil, err
237 }
238
239 updateState(ctx, name, StateConnected, nil, sess, state.Counts)
240 sessions.Set(name, sess)
241 return sess, nil
242}
243
244// updateState updates the state of an MCP client and publishes an event
245func updateState(ctx context.Context, name string, state State, err error, client *mcp.ClientSession, counts Counts) {
246 info := ClientInfo{
247 Name: name,
248 State: state,
249 Error: err,
250 Client: client,
251 Counts: counts,
252 }
253 switch state {
254 case StateConnected:
255 info.ConnectedAt = time.Now()
256 case StateError:
257 sessions.Del(name)
258 }
259 states.Set(name, info)
260
261 // Publish state change event
262 broker.Publish(ctx, pubsub.UpdatedEvent, Event{
263 Type: EventStateChanged,
264 Name: name,
265 State: state,
266 Error: err,
267 Counts: counts,
268 })
269}
270
271func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
272 timeout := mcpTimeout(m)
273 mcpCtx, cancel := context.WithCancel(ctx)
274 cancelTimer := time.AfterFunc(timeout, cancel)
275
276 transport, err := createTransport(mcpCtx, m, resolver)
277 if err != nil {
278 updateState(ctx, name, StateError, err, nil, Counts{})
279 slog.Error("error creating mcp client", "error", err, "name", name)
280 cancel()
281 cancelTimer.Stop()
282 return nil, err
283 }
284
285 client := mcp.NewClient(
286 &mcp.Implementation{
287 Name: "crush",
288 Version: version.Version,
289 Title: "Crush",
290 },
291 &mcp.ClientOptions{
292 ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) {
293 broker.Publish(ctx, pubsub.UpdatedEvent, Event{
294 Type: EventToolsListChanged,
295 Name: name,
296 })
297 },
298 PromptListChangedHandler: func(ctx context.Context, _ *mcp.PromptListChangedRequest) {
299 broker.Publish(ctx, pubsub.UpdatedEvent, Event{
300 Type: EventPromptsListChanged,
301 Name: name,
302 })
303 },
304 LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
305 slog.Info("MCP log", "name", name, "data", req.Params.Data)
306 },
307 },
308 )
309
310 session, err := client.Connect(mcpCtx, transport, nil)
311 if err != nil {
312 err = maybeStdioErr(err, transport)
313 updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
314 slog.Error("MCP client failed to initialize", "error", err, "name", name)
315 cancel()
316 cancelTimer.Stop()
317 return nil, err
318 }
319
320 cancelTimer.Stop()
321 slog.Info("MCP client initialized", "name", name)
322 return session, nil
323}
324
325// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
326// to parse, and the cli will then close it, causing the EOF error.
327// so, if we got an EOF err, and the transport is STDIO, we try to exec it
328// again with a timeout and collect the output so we can add details to the
329// error.
330// this happens particularly when starting things with npx, e.g. if node can't
331// be found or some other error like that.
332func maybeStdioErr(err error, transport mcp.Transport) error {
333 if !errors.Is(err, io.EOF) {
334 return err
335 }
336 ct, ok := transport.(*mcp.CommandTransport)
337 if !ok {
338 return err
339 }
340 if err2 := stdioCheck(ct.Command); err2 != nil {
341 err = errors.Join(err, err2)
342 }
343 return err
344}
345
// maybeTimeoutErr turns a context cancellation or deadline error into a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// Both sentinels are recognised because timeouts surface in two ways here:
//   - createSession cancels its context from a timer, giving context.Canceled;
//   - the health-check ping in getOrRenewClient uses context.WithTimeout,
//     giving context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
352
353func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
354 switch m.Type {
355 case config.MCPStdio:
356 command, err := resolver.ResolveValue(m.Command)
357 if err != nil {
358 return nil, fmt.Errorf("invalid mcp command: %w", err)
359 }
360 if strings.TrimSpace(command) == "" {
361 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
362 }
363 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
364 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
365 return &mcp.CommandTransport{
366 Command: cmd,
367 }, nil
368 case config.MCPHttp:
369 if strings.TrimSpace(m.URL) == "" {
370 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
371 }
372 client := &http.Client{
373 Transport: &headerRoundTripper{
374 headers: m.ResolvedHeaders(),
375 },
376 }
377 return &mcp.StreamableClientTransport{
378 Endpoint: m.URL,
379 HTTPClient: client,
380 }, nil
381 case config.MCPSSE:
382 if strings.TrimSpace(m.URL) == "" {
383 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
384 }
385 client := &http.Client{
386 Transport: &headerRoundTripper{
387 headers: m.ResolvedHeaders(),
388 },
389 }
390 return &mcp.SSEClientTransport{
391 Endpoint: m.URL,
392 HTTPClient: client,
393 }, nil
394 default:
395 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
396 }
397}
398
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request and sends it with http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers applied. The configured headers replace
// any existing values under the same names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone rather than on req itself.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	out := req.Clone(req.Context())
	if out.Header == nil {
		// Clone keeps a nil Header nil; Set on a nil map would panic.
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
409
410func mcpTimeout(m config.MCPConfig) time.Duration {
411 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
412}
413
// stdioCheck re-runs the command of a failed stdio MCP server once to capture
// what it prints, so the caller can attach that output to the error.
//
// The re-run uses old's path, argv, environment and working directory, gets
// no stdin, and is limited to five seconds. It returns nil if the command
// exits successfully or is still running when the five seconds are up.
// Otherwise it returns the run error wrapped with the combined stdout/stderr
// output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args is the complete argv, including argv[0]. Reuse it verbatim.
	// Passing old.Args as extra arguments would prepend another argv[0], so
	// the command would receive its own name as its first argument.
	// Exec falls back to {Path} when Args is empty.
	cmd.Args = old.Args
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, out)
}