1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
28func parseLevel(level mcp.LoggingLevel) slog.Level {
29 switch level {
30 case "info":
31 return slog.LevelInfo
32 case "notice":
33 return slog.LevelInfo
34 case "warning":
35 return slog.LevelWarn
36 default:
37 return slog.LevelDebug
38 }
39}
40
41// ClientSession wraps an mcp.ClientSession with a context cancel function so
42// that the context created during session establishment is properly cleaned up
43// on close.
44type ClientSession struct {
45 *mcp.ClientSession
46 cancel context.CancelFunc
47}
48
49// Close cancels the session context and then closes the underlying session.
50func (s *ClientSession) Close() error {
51 s.cancel()
52 return s.ClientSession.Close()
53}
54
55var (
56 sessions = csync.NewMap[string, *ClientSession]()
57 states = csync.NewMap[string, ClientInfo]()
58 broker = pubsub.NewBroker[Event]()
59 initOnce sync.Once
60 initDone = make(chan struct{})
61)
62
// State is the lifecycle state of an MCP client.
type State int

// The client states. Values count up from zero in this order.
const (
	StateDisabled  State = iota // disabled in the configuration
	StateStarting               // connection attempt in progress
	StateConnected              // session established
	StateError                  // something failed; see ClientInfo.Error
)

// String returns the lowercase name of s, or "unknown" for a value outside
// the defined states.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
87
// EventType identifies which kind of change an Event reports.
type EventType uint

const (
	// EventStateChanged is published whenever a client's State is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a server reports that its tool
	// list changed.
	EventToolsListChanged
	// EventPromptsListChanged is published when a server reports that its
	// prompt list changed.
	EventPromptsListChanged
	// EventResourcesListChanged is published when a server reports that its
	// resource list changed.
	EventResourcesListChanged
)
97
98// Event represents an event in the MCP system
99type Event struct {
100 Type EventType
101 Name string
102 State State
103 Error error
104 Counts Counts
105}
106
// Counts holds how many tools, prompts and resources were registered for an
// MCP client.
type Counts struct {
	Tools     int
	Prompts   int
	Resources int
}
113
114// ClientInfo holds information about an MCP client's state
115type ClientInfo struct {
116 Name string
117 State State
118 Error error
119 Client *ClientSession
120 Counts Counts
121 ConnectedAt time.Time
122}
123
124// SubscribeEvents returns a channel for MCP events
125func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
126 return broker.Subscribe(ctx)
127}
128
129// GetStates returns the current state of all MCP clients
130func GetStates() map[string]ClientInfo {
131 return states.Copy()
132}
133
134// GetState returns the state of a specific MCP client
135func GetState(name string) (ClientInfo, bool) {
136 return states.Get(name)
137}
138
139// Close closes all MCP clients. This should be called during application shutdown.
140func Close(ctx context.Context) error {
141 var wg sync.WaitGroup
142 for name, session := range sessions.Seq2() {
143 wg.Go(func() {
144 done := make(chan error, 1)
145 go func() {
146 done <- session.Close()
147 }()
148 select {
149 case err := <-done:
150 if err != nil &&
151 !errors.Is(err, io.EOF) &&
152 !errors.Is(err, context.Canceled) &&
153 err.Error() != "signal: killed" {
154 slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
155 }
156 case <-ctx.Done():
157 }
158 })
159 }
160 wg.Wait()
161 broker.Shutdown()
162 return nil
163}
164
165// Initialize initializes MCP clients based on the provided configuration.
166func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
167 slog.Info("Initializing MCP clients")
168 var wg sync.WaitGroup
169 // Initialize states for all configured MCPs
170 for name, m := range cfg.MCP {
171 if m.Disabled {
172 updateState(name, StateDisabled, nil, nil, Counts{})
173 slog.Debug("Skipping disabled MCP", "name", name)
174 continue
175 }
176
177 // Set initial starting state
178 updateState(name, StateStarting, nil, nil, Counts{})
179
180 wg.Add(1)
181 go func(name string, m config.MCPConfig) {
182 defer func() {
183 wg.Done()
184 if r := recover(); r != nil {
185 var err error
186 switch v := r.(type) {
187 case error:
188 err = v
189 case string:
190 err = fmt.Errorf("panic: %s", v)
191 default:
192 err = fmt.Errorf("panic: %v", v)
193 }
194 updateState(name, StateError, err, nil, Counts{})
195 slog.Error("Panic in MCP client initialization", "error", err, "name", name)
196 }
197 }()
198
199 // createSession handles its own timeout internally.
200 session, err := createSession(ctx, name, m, cfg.Resolver())
201 if err != nil {
202 return
203 }
204
205 tools, err := getTools(ctx, session)
206 if err != nil {
207 slog.Error("Error listing tools", "error", err)
208 updateState(name, StateError, err, nil, Counts{})
209 session.Close()
210 return
211 }
212
213 prompts, err := getPrompts(ctx, session)
214 if err != nil {
215 slog.Error("Error listing prompts", "error", err)
216 updateState(name, StateError, err, nil, Counts{})
217 session.Close()
218 return
219 }
220
221 resources, err := getResources(ctx, session)
222 if err != nil {
223 slog.Error("Error listing resources", "error", err)
224 updateState(name, StateError, err, nil, Counts{})
225 session.Close()
226 return
227 }
228
229 toolCount := updateTools(cfg, name, tools)
230 updatePrompts(name, prompts)
231 resourceCount := updateResources(name, resources)
232 sessions.Set(name, session)
233
234 updateState(name, StateConnected, nil, session, Counts{
235 Tools: toolCount,
236 Prompts: len(prompts),
237 Resources: resourceCount,
238 })
239 }(name, m)
240 }
241 wg.Wait()
242 initOnce.Do(func() { close(initDone) })
243}
244
245// WaitForInit blocks until MCP initialization is complete.
246// If Initialize was never called, this returns immediately.
247func WaitForInit(ctx context.Context) error {
248 select {
249 case <-initDone:
250 return nil
251 case <-ctx.Done():
252 return ctx.Err()
253 }
254}
255
256func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) {
257 sess, ok := sessions.Get(name)
258 if !ok {
259 return nil, fmt.Errorf("mcp '%s' not available", name)
260 }
261
262 m := cfg.MCP[name]
263 state, _ := states.Get(name)
264
265 timeout := mcpTimeout(m)
266 pingCtx, cancel := context.WithTimeout(ctx, timeout)
267 defer cancel()
268 err := sess.Ping(pingCtx, nil)
269 if err == nil {
270 return sess, nil
271 }
272 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
273
274 sess, err = createSession(ctx, name, m, cfg.Resolver())
275 if err != nil {
276 return nil, err
277 }
278
279 updateState(name, StateConnected, nil, sess, state.Counts)
280 sessions.Set(name, sess)
281 return sess, nil
282}
283
284// updateState updates the state of an MCP client and publishes an event
285func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
286 info := ClientInfo{
287 Name: name,
288 State: state,
289 Error: err,
290 Client: client,
291 Counts: counts,
292 }
293 switch state {
294 case StateConnected:
295 info.ConnectedAt = time.Now()
296 case StateError:
297 sessions.Del(name)
298 }
299 states.Set(name, info)
300
301 // Publish state change event
302 broker.Publish(pubsub.UpdatedEvent, Event{
303 Type: EventStateChanged,
304 Name: name,
305 State: state,
306 Error: err,
307 Counts: counts,
308 })
309}
310
311func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
312 timeout := mcpTimeout(m)
313 mcpCtx, cancel := context.WithCancel(ctx)
314 cancelTimer := time.AfterFunc(timeout, cancel)
315
316 transport, err := createTransport(mcpCtx, m, resolver)
317 if err != nil {
318 updateState(name, StateError, err, nil, Counts{})
319 slog.Error("Error creating MCP client", "error", err, "name", name)
320 cancel()
321 cancelTimer.Stop()
322 return nil, err
323 }
324
325 client := mcp.NewClient(
326 &mcp.Implementation{
327 Name: "crush",
328 Version: version.Version,
329 Title: "Crush",
330 },
331 &mcp.ClientOptions{
332 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
333 broker.Publish(pubsub.UpdatedEvent, Event{
334 Type: EventToolsListChanged,
335 Name: name,
336 })
337 },
338 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
339 broker.Publish(pubsub.UpdatedEvent, Event{
340 Type: EventPromptsListChanged,
341 Name: name,
342 })
343 },
344 ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
345 broker.Publish(pubsub.UpdatedEvent, Event{
346 Type: EventResourcesListChanged,
347 Name: name,
348 })
349 },
350 LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
351 level := parseLevel(req.Params.Level)
352 slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
353 },
354 },
355 )
356
357 session, err := client.Connect(mcpCtx, transport, nil)
358 if err != nil {
359 err = maybeStdioErr(err, transport)
360 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
361 slog.Error("MCP client failed to initialize", "error", err, "name", name)
362 cancel()
363 cancelTimer.Stop()
364 return nil, err
365 }
366
367 cancelTimer.Stop()
368 slog.Debug("MCP client initialized", "name", name)
369 return &ClientSession{session, cancel}, nil
370}
371
372// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
373// to parse, and the cli will then close it, causing the EOF error.
374// so, if we got an EOF err, and the transport is STDIO, we try to exec it
375// again with a timeout and collect the output so we can add details to the
376// error.
377// this happens particularly when starting things with npx, e.g. if node can't
378// be found or some other error like that.
379func maybeStdioErr(err error, transport mcp.Transport) error {
380 if !errors.Is(err, io.EOF) {
381 return err
382 }
383 ct, ok := transport.(*mcp.CommandTransport)
384 if !ok {
385 return err
386 }
387 if err2 := stdioCheck(ct.Command); err2 != nil {
388 err = errors.Join(err, err2)
389 }
390 return err
391}
392
// maybeTimeoutErr replaces a context cancellation or deadline error with a
// readable "timed out after <timeout>" error. Both count as timeouts:
// createSession's timer cancels its context (context.Canceled), while a
// context.WithTimeout, such as the ping in getOrRenewClient, yields
// context.DeadlineExceeded. Any other error, and nil, is returned unchanged.
//
// NOTE(review): a cancellation of the caller's own context is also reported
// as a timeout.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
399
400func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
401 switch m.Type {
402 case config.MCPStdio:
403 command, err := resolver.ResolveValue(m.Command)
404 if err != nil {
405 return nil, fmt.Errorf("invalid mcp command: %w", err)
406 }
407 if strings.TrimSpace(command) == "" {
408 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
409 }
410 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
411 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
412 return &mcp.CommandTransport{
413 Command: cmd,
414 }, nil
415 case config.MCPHttp:
416 if strings.TrimSpace(m.URL) == "" {
417 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
418 }
419 client := &http.Client{
420 Transport: &headerRoundTripper{
421 headers: m.ResolvedHeaders(),
422 },
423 }
424 return &mcp.StreamableClientTransport{
425 Endpoint: m.URL,
426 HTTPClient: client,
427 }, nil
428 case config.MCPSSE:
429 if strings.TrimSpace(m.URL) == "" {
430 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
431 }
432 client := &http.Client{
433 Transport: &headerRoundTripper{
434 headers: m.ResolvedHeaders(),
435 },
436 }
437 return &mcp.SSEClientTransport{
438 Endpoint: m.URL,
439 HTTPClient: client,
440 }, nil
441 default:
442 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
443 }
444}
445
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request and then delegates to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with each configured
// header set, overwriting any existing value for that key. The headers are
// applied to a clone of req because the http.RoundTripper contract forbids
// modifying the caller's request.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			// Header.Clone keeps a nil header nil; Set would panic on it.
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
456
457func mcpTimeout(m config.MCPConfig) time.Duration {
458 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
459}
460
// stdioCheck re-runs the command described by old, limited to 5 seconds, and
// captures its combined stdout and stderr. This helps explain why a stdio MCP
// server failed to start.
//
// It returns nil if the command exits successfully or is still running when
// the limit expires. Otherwise it returns the run error annotated with the
// command's trimmed output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args already begins with the program name (Args[0]), so copy it
	// verbatim. Passing it as extra arguments would repeat the name as the
	// first argument and run a different command line. If old.Args is empty,
	// exec falls back to {Path}.
	cmd.Args = old.Args
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(out)))
}