1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "syscall"
18 "time"
19
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/version"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
29func parseLevel(level mcp.LoggingLevel) slog.Level {
30 switch level {
31 case "info":
32 return slog.LevelInfo
33 case "notice":
34 return slog.LevelInfo
35 case "warning":
36 return slog.LevelWarn
37 default:
38 return slog.LevelDebug
39 }
40}
41
// ClientSession wraps an mcp.ClientSession with a context cancel function so
// that the context created during session establishment is properly cleaned up
// on close.
type ClientSession struct {
	// ClientSession is the underlying SDK session. It is embedded, so its
	// methods are callable directly on the wrapper.
	*mcp.ClientSession
	// cancel cancels the context the session was established with.
	cancel context.CancelFunc
}

// Close cancels the session context and then closes the underlying session.
// The returned error is the one reported by the underlying session's Close.
func (s *ClientSession) Close() error {
	// The context is cancelled before the underlying Close is invoked.
	s.cancel()
	return s.ClientSession.Close()
}
55
var (
	// sessions holds the current session of each MCP, keyed by MCP name.
	sessions = csync.NewMap[string, *ClientSession]()
	// states holds the most recently recorded ClientInfo of each MCP, keyed
	// by MCP name.
	states   = csync.NewMap[string, ClientInfo]()
	// broker publishes MCP events; SubscribeEvents subscribes to it.
	broker   = pubsub.NewBroker[Event]()
	// initOnce guards the close of initDone so it happens at most once.
	initOnce sync.Once
	// initDone is closed when Initialize has finished; WaitForInit waits on
	// it.
	initDone = make(chan struct{})
)
63
// State represents the current state of an MCP client.
type State int

// The states an MCP client can be recorded in.
const (
	StateDisabled State = iota
	StateStarting
	StateConnected
	StateError
)

// String returns the lower-case name of the state. Any value outside the
// declared constants yields "unknown".
func (s State) String() string {
	// Indexed by the State constants themselves, so the table stays in sync
	// with the declaration order above.
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
88
// EventType represents the type of MCP event.
type EventType uint

const (
	// EventStateChanged is published by updateState each time a client's
	// state is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published from the client's tool-list-changed
	// handler.
	EventToolsListChanged
	// EventPromptsListChanged is published from the client's
	// prompt-list-changed handler.
	EventPromptsListChanged
	// EventResourcesListChanged is published from the client's
	// resource-list-changed handler.
	EventResourcesListChanged
)
98
// Event represents an event in the MCP system.
//
// State-change events published by updateState populate every field. The
// list-changed events published from the client handlers set only Type and
// Name.
type Event struct {
	// Type identifies what kind of event this is.
	Type   EventType
	// Name is the name of the MCP the event refers to.
	Name   string
	// State is the state that was just recorded for the MCP.
	State  State
	// Error is the error recorded together with the state, if any.
	Error  error
	// Counts carries the counts recorded together with the state.
	Counts Counts
}
107
// Counts holds the number of tools, prompts, and resources available from an
// MCP client.
type Counts struct {
	// Tools is the number of tools.
	Tools     int
	// Prompts is the number of prompts.
	Prompts   int
	// Resources is the number of resources.
	Resources int
}
114
// ClientInfo holds information about an MCP client's state. Values are built
// and stored by updateState.
type ClientInfo struct {
	// Name is the name of the MCP.
	Name        string
	// State is the recorded state.
	State       State
	// Error is the error recorded with the state; nil when none was given.
	Error       error
	// Client is the session passed to updateState; it may be nil.
	Client      *ClientSession
	// Counts are the tool/prompt/resource counts recorded with the state.
	Counts      Counts
	// ConnectedAt is set to the current time when the recorded state is
	// StateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
124
125// SubscribeEvents returns a channel for MCP events
126func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
127 return broker.Subscribe(ctx)
128}
129
130// GetStates returns the current state of all MCP clients
131func GetStates() map[string]ClientInfo {
132 return states.Copy()
133}
134
135// GetState returns the state of a specific MCP client
136func GetState(name string) (ClientInfo, bool) {
137 return states.Get(name)
138}
139
140// isIgnorableCloseErr returns true for errors that are expected during MCP
141// session shutdown and can be safely suppressed.
142func isIgnorableCloseErr(err error) bool {
143 return err == nil ||
144 errors.Is(err, io.EOF) ||
145 errors.Is(err, context.Canceled) ||
146 isKilledErr(err)
147}
148
// isKilledErr reports whether err is, or wraps, an *exec.ExitError whose
// wait status shows the process was terminated by SIGKILL.
func isKilledErr(err error) bool {
	var exitErr *exec.ExitError
	if errors.As(err, &exitErr) {
		if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
			return status.Signaled() && status.Signal() == syscall.SIGKILL
		}
	}
	return false
}
159
160// Close closes all MCP clients. This should be called during application shutdown.
161func Close(ctx context.Context) error {
162 var wg sync.WaitGroup
163 for name, session := range sessions.Seq2() {
164 wg.Go(func() {
165 done := make(chan error, 1)
166 go func() {
167 done <- session.Close()
168 }()
169 select {
170 case err := <-done:
171 if !isIgnorableCloseErr(err) {
172 slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
173 }
174 case <-ctx.Done():
175 }
176 })
177 }
178 wg.Wait()
179 broker.Shutdown()
180 return nil
181}
182
183// Initialize initializes MCP clients based on the provided configuration.
184func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) {
185 slog.Info("Initializing MCP clients")
186 var wg sync.WaitGroup
187 // Initialize states for all configured MCPs
188 for name, m := range cfg.Config().MCP {
189 if m.Disabled {
190 updateState(name, StateDisabled, nil, nil, Counts{})
191 slog.Debug("Skipping disabled MCP", "name", name)
192 continue
193 }
194
195 // Set initial starting state
196 wg.Add(1)
197 go func(name string, m config.MCPConfig) {
198 defer func() {
199 wg.Done()
200 if r := recover(); r != nil {
201 var err error
202 switch v := r.(type) {
203 case error:
204 err = v
205 case string:
206 err = fmt.Errorf("panic: %s", v)
207 default:
208 err = fmt.Errorf("panic: %v", v)
209 }
210 updateState(name, StateError, err, nil, Counts{})
211 slog.Error("Panic in MCP client initialization", "error", err, "name", name)
212 }
213 }()
214
215 if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
216 slog.Debug("Failed to initialize MCP client", "name", name, "error", err)
217 }
218 }(name, m)
219 }
220 wg.Wait()
221 initOnce.Do(func() { close(initDone) })
222}
223
224// WaitForInit blocks until MCP initialization is complete.
225// If Initialize was never called, this returns immediately.
226func WaitForInit(ctx context.Context) error {
227 select {
228 case <-initDone:
229 return nil
230 case <-ctx.Done():
231 return ctx.Err()
232 }
233}
234
235// InitializeSingle initializes a single MCP client by name.
236func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) error {
237 m, exists := cfg.Config().MCP[name]
238 if !exists {
239 return fmt.Errorf("mcp '%s' not found in configuration", name)
240 }
241
242 if m.Disabled {
243 updateState(name, StateDisabled, nil, nil, Counts{})
244 slog.Debug("Skipping disabled MCP", "name", name)
245 return nil
246 }
247
248 return initClient(ctx, cfg, name, m, cfg.Resolver())
249}
250
251// initClient initializes a single MCP client with the given configuration.
252func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
253 updateState(name, StateStarting, nil, nil, Counts{})
254
255 session, err := createSession(ctx, name, m, resolver)
256 if err != nil {
257 return err
258 }
259
260 tools, err := getTools(ctx, session)
261 if err != nil {
262 slog.Error("Error listing tools", "name", name, "error", err)
263 updateState(name, StateError, err, nil, Counts{})
264 closeSessionOnInitError(name, session)
265 return err
266 }
267
268 prompts, err := getPrompts(ctx, session)
269 if err != nil {
270 slog.Error("Error listing prompts", "name", name, "error", err)
271 updateState(name, StateError, err, nil, Counts{})
272 closeSessionOnInitError(name, session)
273 return err
274 }
275
276 resources, err := getResources(ctx, session)
277 if err != nil {
278 slog.Error("Error listing resources", "name", name, "error", err)
279 updateState(name, StateError, err, nil, Counts{})
280 closeSessionOnInitError(name, session)
281 return err
282 }
283
284 toolCount := updateTools(cfg, name, tools)
285 updatePrompts(name, prompts)
286 resourceCount := updateResources(name, resources)
287 sessions.Set(name, session)
288
289 updateState(name, StateConnected, nil, session, Counts{
290 Tools: toolCount,
291 Prompts: len(prompts),
292 Resources: resourceCount,
293 })
294
295 return nil
296}
297
298// closeSessionOnInitError closes a session that failed during initialization,
299// suppressing expected shutdown errors. Uses a fixed timeout to avoid blocking
300// indefinitely if the parent context has no deadline.
301//
302// On timeout the Close goroutine may outlive this function, but since
303// session.Close cancels the session context internally, it will unblock
304// shortly after.
305func closeSessionOnInitError(name string, session *ClientSession) {
306 const closeTimeout = 5 * time.Second
307 ctx, cancel := context.WithTimeout(context.Background(), closeTimeout)
308 defer cancel()
309
310 done := make(chan error, 1)
311 go func() {
312 done <- session.Close()
313 }()
314
315 select {
316 case err := <-done:
317 if !isIgnorableCloseErr(err) {
318 slog.Warn("Failed to close MCP session after init error", "name", name, "error", err)
319 }
320 case <-ctx.Done():
321 slog.Warn("Timed out waiting to close MCP session after init error", "name", name, "error", ctx.Err())
322 }
323}
324
325// DisableSingle disables and closes a single MCP client by name.
326func DisableSingle(cfg *config.ConfigStore, name string) {
327 session, ok := sessions.Get(name)
328 if ok {
329 if err := session.Close(); !isIgnorableCloseErr(err) {
330 slog.Warn("Error closing MCP session", "name", name, "error", err)
331 }
332 sessions.Del(name)
333 }
334
335 // Clear tools, prompts, and resources for this MCP.
336 updateTools(cfg, name, nil)
337 updatePrompts(name, nil)
338 updateResources(name, nil)
339
340 // Update state to disabled.
341 updateState(name, StateDisabled, nil, nil, Counts{})
342
343 slog.Info("Disabled MCP client", "name", name)
344}
345
346func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
347 sess, ok := sessions.Get(name)
348 if !ok {
349 return nil, fmt.Errorf("mcp '%s' not available", name)
350 }
351
352 m := cfg.Config().MCP[name]
353 state, _ := states.Get(name)
354
355 timeout := mcpTimeout(m)
356 pingCtx, cancel := context.WithTimeout(ctx, timeout)
357 defer cancel()
358 err := sess.Ping(pingCtx, nil)
359 if err == nil {
360 return sess, nil
361 }
362 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
363
364 sess, err = createSession(ctx, name, m, cfg.Resolver())
365 if err != nil {
366 return nil, err
367 }
368
369 updateState(name, StateConnected, nil, sess, state.Counts)
370 sessions.Set(name, sess)
371 return sess, nil
372}
373
374// updateState updates the state of an MCP client and publishes an event
375func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
376 info := ClientInfo{
377 Name: name,
378 State: state,
379 Error: err,
380 Client: client,
381 Counts: counts,
382 }
383 switch state {
384 case StateConnected:
385 info.ConnectedAt = time.Now()
386 case StateError:
387 sessions.Del(name)
388 }
389 states.Set(name, info)
390
391 // Publish state change event
392 broker.Publish(pubsub.UpdatedEvent, Event{
393 Type: EventStateChanged,
394 Name: name,
395 State: state,
396 Error: err,
397 Counts: counts,
398 })
399}
400
401func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
402 timeout := mcpTimeout(m)
403 mcpCtx, cancel := context.WithCancel(ctx)
404 cancelTimer := time.AfterFunc(timeout, cancel)
405
406 transport, err := createTransport(mcpCtx, m, resolver)
407 if err != nil {
408 updateState(name, StateError, err, nil, Counts{})
409 slog.Error("Error creating MCP client", "error", err, "name", name)
410 cancel()
411 cancelTimer.Stop()
412 return nil, err
413 }
414
415 client := mcp.NewClient(
416 &mcp.Implementation{
417 Name: "crush",
418 Version: version.Version,
419 Title: "Crush",
420 },
421 &mcp.ClientOptions{
422 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
423 broker.Publish(pubsub.UpdatedEvent, Event{
424 Type: EventToolsListChanged,
425 Name: name,
426 })
427 },
428 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
429 broker.Publish(pubsub.UpdatedEvent, Event{
430 Type: EventPromptsListChanged,
431 Name: name,
432 })
433 },
434 ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
435 broker.Publish(pubsub.UpdatedEvent, Event{
436 Type: EventResourcesListChanged,
437 Name: name,
438 })
439 },
440 LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
441 level := parseLevel(req.Params.Level)
442 slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
443 },
444 },
445 )
446
447 session, err := client.Connect(mcpCtx, transport, nil)
448 if err != nil {
449 err = maybeStdioErr(err, transport)
450 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
451 slog.Error("MCP client failed to initialize", "error", err, "name", name)
452 cancel()
453 cancelTimer.Stop()
454 return nil, err
455 }
456
457 cancelTimer.Stop()
458 slog.Debug("MCP client initialized", "name", name)
459 return &ClientSession{session, cancel}, nil
460}
461
462// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
463// to parse, and the cli will then close it, causing the EOF error.
464// so, if we got an EOF err, and the transport is STDIO, we try to exec it
465// again with a timeout and collect the output so we can add details to the
466// error.
467// this happens particularly when starting things with npx, e.g. if node can't
468// be found or some other error like that.
469func maybeStdioErr(err error, transport mcp.Transport) error {
470 if !errors.Is(err, io.EOF) {
471 return err
472 }
473 ct, ok := transport.(*mcp.CommandTransport)
474 if !ok {
475 return err
476 }
477 if err2 := stdioCheck(ct.Command); err2 != nil {
478 err = errors.Join(err, err2)
479 }
480 return err
481}
482
// maybeTimeoutErr replaces a context error with a readable timeout message.
//
// If err is, or wraps, context.Canceled or context.DeadlineExceeded, a new
// error reading "timed out after <timeout>" is returned. context.Canceled is
// what a context cancelled by a timer produces; context.DeadlineExceeded is
// what a context created with context.WithTimeout produces. Any other error,
// including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
489
490func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
491 switch m.Type {
492 case config.MCPStdio:
493 command, err := resolver.ResolveValue(m.Command)
494 if err != nil {
495 return nil, fmt.Errorf("invalid mcp command: %w", err)
496 }
497 if strings.TrimSpace(command) == "" {
498 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
499 }
500 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
501 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
502 return &mcp.CommandTransport{
503 Command: cmd,
504 }, nil
505 case config.MCPHttp:
506 if strings.TrimSpace(m.URL) == "" {
507 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
508 }
509 client := &http.Client{
510 Transport: &headerRoundTripper{
511 headers: m.ResolvedHeaders(),
512 },
513 }
514 return &mcp.StreamableClientTransport{
515 Endpoint: m.URL,
516 HTTPClient: client,
517 }, nil
518 case config.MCPSSE:
519 if strings.TrimSpace(m.URL) == "" {
520 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
521 }
522 client := &http.Client{
523 Transport: &headerRoundTripper{
524 headers: m.ResolvedHeaders(),
525 },
526 }
527 return &mcp.SSEClientTransport{
528 Endpoint: m.URL,
529 HTTPClient: client,
530 }, nil
531 default:
532 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
533 }
534}
535
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every outgoing request and then delegates to http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps each header name to the value set on outgoing requests.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, replacing any existing values for those header names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone and req itself is left untouched. When
// no headers are configured the request is forwarded as is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		// Cloning a request without headers yields a nil map, which cannot
		// be written to.
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
546
547func mcpTimeout(m config.MCPConfig) time.Duration {
548 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
549}
550
// stdioCheck re-runs the command described by old, for at most five seconds,
// to capture whatever it prints when it fails.
//
// The re-run uses the same path, arguments, environment, and working
// directory as old. It returns nil when the command exits successfully or is
// still running when the five seconds are up. Otherwise it returns the run
// error wrapped together with the command's combined stdout and stderr.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Cmd.Args holds the command name as Args[0]; exec.CommandContext adds
	// that element itself, so only the real arguments are passed on.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}

	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}