1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
28func parseLevel(level mcp.LoggingLevel) slog.Level {
29 switch level {
30 case "info":
31 return slog.LevelInfo
32 case "notice":
33 return slog.LevelInfo
34 case "warning":
35 return slog.LevelWarn
36 default:
37 return slog.LevelDebug
38 }
39}
40
41// ClientSession wraps an mcp.ClientSession with a context cancel function so
42// that the context created during session establishment is properly cleaned up
43// on close.
44type ClientSession struct {
45 *mcp.ClientSession
46 cancel context.CancelFunc
47}
48
49// Close cancels the session context and then closes the underlying session.
50func (s *ClientSession) Close() error {
51 s.cancel()
52 return s.ClientSession.Close()
53}
54
var (
	// sessions holds the live session for each MCP name. Entries are stored
	// once a client is connected and deleted when the client is disabled or
	// enters StateError.
	sessions = csync.NewMap[string, *ClientSession]()
	// states holds the most recent ClientInfo recorded by updateState for
	// each MCP name.
	states = csync.NewMap[string, ClientInfo]()
	// broker publishes MCP events to the channels handed out by
	// SubscribeEvents.
	broker = pubsub.NewBroker[Event]()
	// initOnce ensures initDone is closed at most once, even if Initialize
	// runs more than once.
	initOnce sync.Once
	// initDone is closed when Initialize has finished; WaitForInit blocks on
	// it.
	initDone = make(chan struct{})
)
62
// State describes the lifecycle stage an MCP client is currently in.
type State int

// The lifecycle stages of an MCP client, in declaration order starting at 0.
const (
	StateDisabled State = iota
	StateStarting
	StateConnected
	StateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
87
// EventType represents the type of MCP event.
type EventType uint

const (
	// EventStateChanged is published by updateState every time a client's
	// state is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published from the client's
	// ToolListChangedHandler.
	EventToolsListChanged
	// EventPromptsListChanged is published from the client's
	// PromptListChangedHandler.
	EventPromptsListChanged
	// EventResourcesListChanged is published from the client's
	// ResourceListChangedHandler.
	EventResourcesListChanged
)
97
// Event represents an event in the MCP system.
//
// Events of type EventStateChanged carry State, Error and Counts as passed to
// updateState. The list-changed events published in this file set only Type
// and Name and leave the remaining fields at their zero values.
type Event struct {
	// Type identifies what kind of change the event reports.
	Type EventType
	// Name is the MCP name the event refers to.
	Name string
	// State is the state recorded for the client.
	State State
	// Error is the error recorded alongside the state, if any.
	Error error
	// Counts is the set of counts recorded alongside the state.
	Counts Counts
}
106
// Counts holds the number of tools, prompts and resources available from an
// MCP client.
type Counts struct {
	// Tools is the tool count reported by updateTools when a client connects.
	Tools int
	// Prompts is the number of prompts listed when a client connects.
	Prompts int
	// Resources is the number of resources.
	// NOTE(review): never populated in this file — presumably set elsewhere
	// in the package; confirm.
	Resources int
}
113
// ClientInfo holds information about an MCP client's state. Values are built
// by updateState and stored per MCP name.
type ClientInfo struct {
	// Name is the MCP name this record belongs to.
	Name string
	// State is the most recently recorded state.
	State State
	// Error is the error recorded with the state; nil when none was given.
	Error error
	// Client is the session passed to updateState; callers in this file pass
	// a non-nil session only together with StateConnected.
	Client *ClientSession
	// Counts is the set of counts recorded with the state.
	Counts Counts
	// ConnectedAt is set to the current time when the state is
	// StateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
123
// SubscribeEvents returns a channel for MCP events.
//
// The subscription is created on the package-level broker with ctx.
// NOTE(review): presumably the subscription ends when ctx is done — confirm
// against pubsub.Broker.Subscribe.
func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
	return broker.Subscribe(ctx)
}
128
// GetStates returns the current state of all MCP clients, keyed by MCP name.
//
// The result comes from states.Copy.
// NOTE(review): presumably a snapshot that callers may modify freely —
// confirm against csync.Map.Copy.
func GetStates() map[string]ClientInfo {
	return states.Copy()
}
133
// GetState returns the state of a specific MCP client.
//
// The boolean is the second result of states.Get for name; it reports whether
// a state has been recorded for that MCP.
func GetState(name string) (ClientInfo, bool) {
	return states.Get(name)
}
138
139// Close closes all MCP clients. This should be called during application shutdown.
140func Close(ctx context.Context) error {
141 var wg sync.WaitGroup
142 for name, session := range sessions.Seq2() {
143 wg.Go(func() {
144 done := make(chan error, 1)
145 go func() {
146 done <- session.Close()
147 }()
148 select {
149 case err := <-done:
150 if err != nil &&
151 !errors.Is(err, io.EOF) &&
152 !errors.Is(err, context.Canceled) &&
153 err.Error() != "signal: killed" {
154 slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
155 }
156 case <-ctx.Done():
157 }
158 })
159 }
160 wg.Wait()
161 broker.Shutdown()
162 return nil
163}
164
165// Initialize initializes MCP clients based on the provided configuration.
166func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) {
167 slog.Info("Initializing MCP clients")
168 var wg sync.WaitGroup
169 // Initialize states for all configured MCPs
170 for name, m := range cfg.Config().MCP {
171 if m.Disabled {
172 updateState(name, StateDisabled, nil, nil, Counts{})
173 slog.Debug("Skipping disabled MCP", "name", name)
174 continue
175 }
176
177 // Set initial starting state
178 wg.Add(1)
179 go func(name string, m config.MCPConfig) {
180 defer func() {
181 wg.Done()
182 if r := recover(); r != nil {
183 var err error
184 switch v := r.(type) {
185 case error:
186 err = v
187 case string:
188 err = fmt.Errorf("panic: %s", v)
189 default:
190 err = fmt.Errorf("panic: %v", v)
191 }
192 updateState(name, StateError, err, nil, Counts{})
193 slog.Error("Panic in MCP client initialization", "error", err, "name", name)
194 }
195 }()
196
197 if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
198 slog.Debug("failed to initialize mcp client", "name", name, "error", err)
199 }
200 }(name, m)
201 }
202 wg.Wait()
203 initOnce.Do(func() { close(initDone) })
204}
205
206// WaitForInit blocks until MCP initialization is complete.
207// If Initialize was never called, this returns immediately.
208func WaitForInit(ctx context.Context) error {
209 select {
210 case <-initDone:
211 return nil
212 case <-ctx.Done():
213 return ctx.Err()
214 }
215}
216
217// InitializeSingle initializes a single MCP client by name.
218func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) error {
219 m, exists := cfg.Config().MCP[name]
220 if !exists {
221 return fmt.Errorf("mcp '%s' not found in configuration", name)
222 }
223
224 if m.Disabled {
225 updateState(name, StateDisabled, nil, nil, Counts{})
226 slog.Debug("skipping disabled mcp", "name", name)
227 return nil
228 }
229
230 return initClient(ctx, cfg, name, m, cfg.Resolver())
231}
232
233// initClient initializes a single MCP client with the given configuration.
234func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
235 // Set initial starting state.
236 updateState(name, StateStarting, nil, nil, Counts{})
237
238 // createSession handles its own timeout internally.
239 session, err := createSession(ctx, name, m, resolver)
240 if err != nil {
241 return err
242 }
243
244 tools, err := getTools(ctx, session)
245 if err != nil {
246 slog.Error("Error listing tools", "error", err)
247 updateState(name, StateError, err, nil, Counts{})
248 session.Close()
249 return err
250 }
251
252 prompts, err := getPrompts(ctx, session)
253 if err != nil {
254 slog.Error("Error listing prompts", "error", err)
255 updateState(name, StateError, err, nil, Counts{})
256 session.Close()
257 return err
258 }
259
260 toolCount := updateTools(cfg, name, tools)
261 updatePrompts(name, prompts)
262 sessions.Set(name, session)
263
264 updateState(name, StateConnected, nil, session, Counts{
265 Tools: toolCount,
266 Prompts: len(prompts),
267 })
268
269 return nil
270}
271
272// DisableSingle disables and closes a single MCP client by name.
273func DisableSingle(cfg *config.ConfigStore, name string) error {
274 session, ok := sessions.Get(name)
275 if ok {
276 if err := session.Close(); err != nil &&
277 !errors.Is(err, io.EOF) &&
278 !errors.Is(err, context.Canceled) &&
279 err.Error() != "signal: killed" {
280 slog.Warn("error closing mcp session", "name", name, "error", err)
281 }
282 sessions.Del(name)
283 }
284
285 // Clear tools and prompts for this MCP.
286 updateTools(cfg, name, nil)
287 updatePrompts(name, nil)
288
289 // Update state to disabled.
290 updateState(name, StateDisabled, nil, nil, Counts{})
291
292 slog.Info("Disabled mcp client", "name", name)
293 return nil
294}
295
296func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
297 sess, ok := sessions.Get(name)
298 if !ok {
299 return nil, fmt.Errorf("mcp '%s' not available", name)
300 }
301
302 m := cfg.Config().MCP[name]
303 state, _ := states.Get(name)
304
305 timeout := mcpTimeout(m)
306 pingCtx, cancel := context.WithTimeout(ctx, timeout)
307 defer cancel()
308 err := sess.Ping(pingCtx, nil)
309 if err == nil {
310 return sess, nil
311 }
312 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
313
314 sess, err = createSession(ctx, name, m, cfg.Resolver())
315 if err != nil {
316 return nil, err
317 }
318
319 updateState(name, StateConnected, nil, sess, state.Counts)
320 sessions.Set(name, sess)
321 return sess, nil
322}
323
324// updateState updates the state of an MCP client and publishes an event
325func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
326 info := ClientInfo{
327 Name: name,
328 State: state,
329 Error: err,
330 Client: client,
331 Counts: counts,
332 }
333 switch state {
334 case StateConnected:
335 info.ConnectedAt = time.Now()
336 case StateError:
337 sessions.Del(name)
338 }
339 states.Set(name, info)
340
341 // Publish state change event
342 broker.Publish(pubsub.UpdatedEvent, Event{
343 Type: EventStateChanged,
344 Name: name,
345 State: state,
346 Error: err,
347 Counts: counts,
348 })
349}
350
351func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
352 timeout := mcpTimeout(m)
353 mcpCtx, cancel := context.WithCancel(ctx)
354 cancelTimer := time.AfterFunc(timeout, cancel)
355
356 transport, err := createTransport(mcpCtx, m, resolver)
357 if err != nil {
358 updateState(name, StateError, err, nil, Counts{})
359 slog.Error("Error creating MCP client", "error", err, "name", name)
360 cancel()
361 cancelTimer.Stop()
362 return nil, err
363 }
364
365 client := mcp.NewClient(
366 &mcp.Implementation{
367 Name: "crush",
368 Version: version.Version,
369 Title: "Crush",
370 },
371 &mcp.ClientOptions{
372 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
373 broker.Publish(pubsub.UpdatedEvent, Event{
374 Type: EventToolsListChanged,
375 Name: name,
376 })
377 },
378 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
379 broker.Publish(pubsub.UpdatedEvent, Event{
380 Type: EventPromptsListChanged,
381 Name: name,
382 })
383 },
384 ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
385 broker.Publish(pubsub.UpdatedEvent, Event{
386 Type: EventResourcesListChanged,
387 Name: name,
388 })
389 },
390 LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
391 level := parseLevel(req.Params.Level)
392 slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
393 },
394 },
395 )
396
397 session, err := client.Connect(mcpCtx, transport, nil)
398 if err != nil {
399 err = maybeStdioErr(err, transport)
400 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
401 slog.Error("MCP client failed to initialize", "error", err, "name", name)
402 cancel()
403 cancelTimer.Stop()
404 return nil, err
405 }
406
407 cancelTimer.Stop()
408 slog.Debug("MCP client initialized", "name", name)
409 return &ClientSession{session, cancel}, nil
410}
411
412// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
413// to parse, and the cli will then close it, causing the EOF error.
414// so, if we got an EOF err, and the transport is STDIO, we try to exec it
415// again with a timeout and collect the output so we can add details to the
416// error.
417// this happens particularly when starting things with npx, e.g. if node can't
418// be found or some other error like that.
419func maybeStdioErr(err error, transport mcp.Transport) error {
420 if !errors.Is(err, io.EOF) {
421 return err
422 }
423 ct, ok := transport.(*mcp.CommandTransport)
424 if !ok {
425 return err
426 }
427 if err2 := stdioCheck(ct.Command); err2 != nil {
428 err = errors.Join(err, err2)
429 }
430 return err
431}
432
// maybeTimeoutErr replaces context-expiry errors with a readable timeout
// message that mentions the configured timeout.
//
// Both context.Canceled and context.DeadlineExceeded are treated as a timeout.
// Session creation enforces its timeout by cancelling the context, which
// yields context.Canceled. The ping in getOrRenewClient uses
// context.WithTimeout, which yields context.DeadlineExceeded. Any other error,
// including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
439
440func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
441 switch m.Type {
442 case config.MCPStdio:
443 command, err := resolver.ResolveValue(m.Command)
444 if err != nil {
445 return nil, fmt.Errorf("invalid mcp command: %w", err)
446 }
447 if strings.TrimSpace(command) == "" {
448 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
449 }
450 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
451 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
452 return &mcp.CommandTransport{
453 Command: cmd,
454 }, nil
455 case config.MCPHttp:
456 if strings.TrimSpace(m.URL) == "" {
457 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
458 }
459 client := &http.Client{
460 Transport: &headerRoundTripper{
461 headers: m.ResolvedHeaders(),
462 },
463 }
464 return &mcp.StreamableClientTransport{
465 Endpoint: m.URL,
466 HTTPClient: client,
467 }, nil
468 case config.MCPSSE:
469 if strings.TrimSpace(m.URL) == "" {
470 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
471 }
472 client := &http.Client{
473 Transport: &headerRoundTripper{
474 headers: m.ResolvedHeaders(),
475 },
476 }
477 return &mcp.SSEClientTransport{
478 Endpoint: m.URL,
479 HTTPClient: client,
480 }, nil
481 default:
482 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
483 }
484}
485
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every request before sending it through http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, overriding any existing values for the same keys.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone and req itself is left untouched. When
// no headers are configured the request is forwarded as is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
496
497func mcpTimeout(m config.MCPConfig) time.Duration {
498 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
499}
500
// stdioCheck runs the command described by old again, with a five second
// timeout, and reports how it failed.
//
// The command is rebuilt from old's path, arguments, environment and working
// directory. old.Args already contains the command name as its first element,
// so only the remaining elements are passed on as arguments. Passing all of
// old.Args would hand the program its own name as an extra first argument.
//
// It returns nil when old is nil, when the command exits successfully, or when
// it is still running at the timeout (a server that stays up is considered
// healthy). Otherwise it returns the command's error wrapped together with its
// combined stdout and stderr.
func stdioCheck(old *exec.Cmd) error {
	if old == nil {
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}