1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
28func parseLevel(level mcp.LoggingLevel) slog.Level {
29 switch level {
30 case "info":
31 return slog.LevelInfo
32 case "notice":
33 return slog.LevelInfo
34 case "warning":
35 return slog.LevelWarn
36 default:
37 return slog.LevelDebug
38 }
39}
40
// Package-level registry shared by every function in this file.
var (
	// sessions holds the live client session of each MCP, keyed by MCP name.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states holds the most recently recorded ClientInfo of each MCP, keyed by MCP name.
	states = csync.NewMap[string, ClientInfo]()
	// broker carries MCP events to the channels handed out by SubscribeEvents.
	broker = pubsub.NewBroker[Event]()
	// initOnce guards the closing of initDone so it happens at most once.
	initOnce sync.Once
	// initDone is closed at the end of Initialize; WaitForInit blocks on it.
	initDone = make(chan struct{})
)
48
// State describes where an MCP client currently is in its lifecycle.
type State int

// The lifecycle states an MCP client can be in.
const (
	StateDisabled State = iota
	StateStarting
	StateConnected
	StateError
)

// String returns the lower-case display name of s, or "unknown" when s is not
// one of the declared states.
func (s State) String() string {
	// Indexed by state value; anything outside the table is unknown.
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
73
// EventType identifies what kind of change an MCP Event reports.
type EventType uint

const (
	// EventStateChanged is published by updateState each time a client's
	// state is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a server's tool-list-changed
	// handler fires.
	EventToolsListChanged
	// EventPromptsListChanged is published when a server's
	// prompt-list-changed handler fires.
	EventPromptsListChanged
	// EventResourcesListChanged is published when a server's
	// resource-list-changed handler fires.
	EventResourcesListChanged
)
83
// Event is the payload published on the MCP event broker.
//
// State-change events fill in every field; the list-changed events published
// from createSession set only Type and Name.
type Event struct {
	// Type says what kind of change this event reports.
	Type EventType
	// Name is the name of the MCP the event is about.
	Name string
	// State is the state that was recorded for the client.
	State State
	// Error is the error recorded together with the state, if any.
	Error error
	// Counts holds the tool/prompt/resource counts recorded with the state.
	Counts Counts
}
92
// Counts holds the number of tools, prompts and resources recorded for an MCP
// client.
type Counts struct {
	// Tools is the number of tools.
	Tools int
	// Prompts is the number of prompts.
	Prompts int
	// Resources is the number of resources.
	Resources int
}
99
// ClientInfo is the snapshot that updateState stores for an MCP client each
// time its state changes.
type ClientInfo struct {
	// Name is the MCP's name, the same key used in the states map.
	Name string
	// State is the client's recorded lifecycle state.
	State State
	// Error is the error recorded together with the state, if any.
	Error error
	// Client is the session passed to updateState; callers in this file pass
	// nil for every state except StateConnected.
	Client *mcp.ClientSession
	// Counts holds the tool/prompt/resource counts recorded with the state.
	Counts Counts
	// ConnectedAt is set to the current time when the state is
	// StateConnected and is left zero otherwise.
	ConnectedAt time.Time
}
109
110// SubscribeEvents returns a channel for MCP events
111func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
112 return broker.Subscribe(ctx)
113}
114
115// GetStates returns the current state of all MCP clients
116func GetStates() map[string]ClientInfo {
117 return states.Copy()
118}
119
120// GetState returns the state of a specific MCP client
121func GetState(name string) (ClientInfo, bool) {
122 return states.Get(name)
123}
124
125// Close closes all MCP clients. This should be called during application shutdown.
126func Close(ctx context.Context) error {
127 var wg sync.WaitGroup
128 for name, session := range sessions.Seq2() {
129 wg.Go(func() {
130 done := make(chan error, 1)
131 go func() {
132 done <- session.Close()
133 }()
134 select {
135 case err := <-done:
136 if err != nil &&
137 !errors.Is(err, io.EOF) &&
138 !errors.Is(err, context.Canceled) &&
139 err.Error() != "signal: killed" {
140 slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
141 }
142 case <-ctx.Done():
143 }
144 })
145 }
146 wg.Wait()
147 broker.Shutdown()
148 return nil
149}
150
151// Initialize initializes MCP clients based on the provided configuration.
152func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
153 slog.Info("Initializing MCP clients")
154 var wg sync.WaitGroup
155 // Initialize states for all configured MCPs
156 for name, m := range cfg.MCP {
157 if m.Disabled {
158 updateState(name, StateDisabled, nil, nil, Counts{})
159 slog.Debug("Skipping disabled MCP", "name", name)
160 continue
161 }
162
163 // Set initial starting state
164 updateState(name, StateStarting, nil, nil, Counts{})
165
166 wg.Add(1)
167 go func(name string, m config.MCPConfig) {
168 defer func() {
169 wg.Done()
170 if r := recover(); r != nil {
171 var err error
172 switch v := r.(type) {
173 case error:
174 err = v
175 case string:
176 err = fmt.Errorf("panic: %s", v)
177 default:
178 err = fmt.Errorf("panic: %v", v)
179 }
180 updateState(name, StateError, err, nil, Counts{})
181 slog.Error("Panic in MCP client initialization", "error", err, "name", name)
182 }
183 }()
184
185 // createSession handles its own timeout internally.
186 session, err := createSession(ctx, name, m, cfg.Resolver())
187 if err != nil {
188 return
189 }
190
191 tools, err := getTools(ctx, session)
192 if err != nil {
193 slog.Error("Error listing tools", "error", err)
194 updateState(name, StateError, err, nil, Counts{})
195 session.Close()
196 return
197 }
198
199 prompts, err := getPrompts(ctx, session)
200 if err != nil {
201 slog.Error("Error listing prompts", "error", err)
202 updateState(name, StateError, err, nil, Counts{})
203 session.Close()
204 return
205 }
206
207 resources, err := getResources(ctx, session)
208 if err != nil {
209 slog.Error("Error listing resources", "error", err)
210 updateState(name, StateError, err, nil, Counts{})
211 session.Close()
212 return
213 }
214
215 toolCount := updateTools(cfg, name, tools)
216 updatePrompts(name, prompts)
217 resourceCount := updateResources(name, resources)
218 sessions.Set(name, session)
219
220 updateState(name, StateConnected, nil, session, Counts{
221 Tools: toolCount,
222 Prompts: len(prompts),
223 Resources: resourceCount,
224 })
225 }(name, m)
226 }
227 wg.Wait()
228 initOnce.Do(func() { close(initDone) })
229}
230
231// WaitForInit blocks until MCP initialization is complete.
232// If Initialize was never called, this returns immediately.
233func WaitForInit(ctx context.Context) error {
234 select {
235 case <-initDone:
236 return nil
237 case <-ctx.Done():
238 return ctx.Err()
239 }
240}
241
242func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mcp.ClientSession, error) {
243 sess, ok := sessions.Get(name)
244 if !ok {
245 return nil, fmt.Errorf("mcp '%s' not available", name)
246 }
247
248 m := cfg.MCP[name]
249 state, _ := states.Get(name)
250
251 timeout := mcpTimeout(m)
252 pingCtx, cancel := context.WithTimeout(ctx, timeout)
253 defer cancel()
254 err := sess.Ping(pingCtx, nil)
255 if err == nil {
256 return sess, nil
257 }
258 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
259
260 sess, err = createSession(ctx, name, m, cfg.Resolver())
261 if err != nil {
262 return nil, err
263 }
264
265 updateState(name, StateConnected, nil, sess, state.Counts)
266 sessions.Set(name, sess)
267 return sess, nil
268}
269
270// updateState updates the state of an MCP client and publishes an event
271func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
272 info := ClientInfo{
273 Name: name,
274 State: state,
275 Error: err,
276 Client: client,
277 Counts: counts,
278 }
279 switch state {
280 case StateConnected:
281 info.ConnectedAt = time.Now()
282 case StateError:
283 sessions.Del(name)
284 }
285 states.Set(name, info)
286
287 // Publish state change event
288 broker.Publish(pubsub.UpdatedEvent, Event{
289 Type: EventStateChanged,
290 Name: name,
291 State: state,
292 Error: err,
293 Counts: counts,
294 })
295}
296
297func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
298 timeout := mcpTimeout(m)
299 mcpCtx, cancel := context.WithCancel(ctx)
300 cancelTimer := time.AfterFunc(timeout, cancel)
301
302 transport, err := createTransport(mcpCtx, m, resolver)
303 if err != nil {
304 updateState(name, StateError, err, nil, Counts{})
305 slog.Error("Error creating MCP client", "error", err, "name", name)
306 cancel()
307 cancelTimer.Stop()
308 return nil, err
309 }
310
311 client := mcp.NewClient(
312 &mcp.Implementation{
313 Name: "crush",
314 Version: version.Version,
315 Title: "Crush",
316 },
317 &mcp.ClientOptions{
318 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
319 broker.Publish(pubsub.UpdatedEvent, Event{
320 Type: EventToolsListChanged,
321 Name: name,
322 })
323 },
324 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
325 broker.Publish(pubsub.UpdatedEvent, Event{
326 Type: EventPromptsListChanged,
327 Name: name,
328 })
329 },
330 ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
331 broker.Publish(pubsub.UpdatedEvent, Event{
332 Type: EventResourcesListChanged,
333 Name: name,
334 })
335 },
336 LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
337 level := parseLevel(req.Params.Level)
338 slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
339 },
340 },
341 )
342
343 session, err := client.Connect(mcpCtx, transport, nil)
344 if err != nil {
345 err = maybeStdioErr(err, transport)
346 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
347 slog.Error("MCP client failed to initialize", "error", err, "name", name)
348 cancel()
349 cancelTimer.Stop()
350 return nil, err
351 }
352
353 cancelTimer.Stop()
354 slog.Debug("MCP client initialized", "name", name)
355 return session, nil
356}
357
358// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
359// to parse, and the cli will then close it, causing the EOF error.
360// so, if we got an EOF err, and the transport is STDIO, we try to exec it
361// again with a timeout and collect the output so we can add details to the
362// error.
363// this happens particularly when starting things with npx, e.g. if node can't
364// be found or some other error like that.
365func maybeStdioErr(err error, transport mcp.Transport) error {
366 if !errors.Is(err, io.EOF) {
367 return err
368 }
369 ct, ok := transport.(*mcp.CommandTransport)
370 if !ok {
371 return err
372 }
373 if err2 := stdioCheck(ct.Command); err2 != nil {
374 err = errors.Join(err, err2)
375 }
376 return err
377}
378
// maybeTimeoutErr rewrites context-expiry errors into a message that names
// the configured timeout, and returns every other error (including nil)
// unchanged.
//
// Two shapes of expiry are recognized: context.Canceled, which is what
// createSession produces because its timeout is a timer that calls cancel,
// and context.DeadlineExceeded, which is what the context.WithTimeout ping in
// getOrRenewClient produces.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
385
386func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
387 switch m.Type {
388 case config.MCPStdio:
389 command, err := resolver.ResolveValue(m.Command)
390 if err != nil {
391 return nil, fmt.Errorf("invalid mcp command: %w", err)
392 }
393 if strings.TrimSpace(command) == "" {
394 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
395 }
396 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
397 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
398 return &mcp.CommandTransport{
399 Command: cmd,
400 }, nil
401 case config.MCPHttp:
402 if strings.TrimSpace(m.URL) == "" {
403 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
404 }
405 client := &http.Client{
406 Transport: &headerRoundTripper{
407 headers: m.ResolvedHeaders(),
408 },
409 }
410 return &mcp.StreamableClientTransport{
411 Endpoint: m.URL,
412 HTTPClient: client,
413 }, nil
414 case config.MCPSSE:
415 if strings.TrimSpace(m.URL) == "" {
416 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
417 }
418 client := &http.Client{
419 Transport: &headerRoundTripper{
420 headers: m.ResolvedHeaders(),
421 },
422 }
423 return &mcp.SSEClientTransport{
424 Endpoint: m.URL,
425 HTTPClient: client,
426 }, nil
427 default:
428 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
429 }
430}
431
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to the values to set on each request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, replacing any existing values for those header names.
//
// The headers are applied to a clone of req rather than to req itself: the
// http.RoundTripper contract forbids modifying the caller's request, and a
// request that is retried or reused must not carry headers it was not given.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		// Nothing to add, so the original request can be sent untouched.
		return http.DefaultTransport.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
442
443func mcpTimeout(m config.MCPConfig) time.Duration {
444 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
445}
446
// stdioCheck re-runs the command described by old for at most five seconds
// and reports what went wrong with it.
//
// It returns nil when old is nil, when the command exits successfully, or
// when the five seconds elapse first. Otherwise it returns the run error
// wrapped together with the command's combined stdout and stderr, so the
// caller can show why the process failed to start.
func stdioCheck(old *exec.Cmd) error {
	if old == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// old.Args already carries the command name as Args[0], and
	// exec.CommandContext prepends the name itself. Passing old.Args whole
	// would run "cmd cmd arg..." instead of "cmd arg...".
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}

	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}