1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "os/exec"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/version"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
var (
	// sessions holds the live MCP client sessions keyed by MCP name. Entries
	// are stored after a successful connect and removed when a client is
	// disabled or moves to StateError.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states holds the most recently recorded ClientInfo for each MCP name.
	states = csync.NewMap[string, ClientInfo]()
	// broker distributes MCP events to subscribers obtained through
	// SubscribeEvents.
	broker = pubsub.NewBroker[Event]()
)
34
// State describes where an MCP client is in its lifecycle.
type State int

// Lifecycle states of an MCP client. The numeric values follow declaration
// order, starting at zero.
const (
	StateDisabled State = iota
	StateStarting
	StateConnected
	StateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// that is not one of the declared State constants.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
59
// EventType represents the type of MCP event.
type EventType uint

const (
	// EventStateChanged is published by updateState each time a client's
	// state is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published from the client's
	// ToolListChangedHandler.
	EventToolsListChanged
	// EventPromptsListChanged is published from the client's
	// PromptListChangedHandler.
	EventPromptsListChanged
)
68
// Event represents an event in the MCP system.
//
// State-change events published by updateState fill every field. The
// list-changed events published from createSession set only Type and Name.
type Event struct {
	// Type identifies what kind of event this is.
	Type EventType
	// Name is the name of the MCP client the event refers to.
	Name string
	// State is the client's state at the time of a state-change event.
	State State
	// Error is the error recorded alongside the state, if any.
	Error error
	// Counts is the tool/prompt count recorded alongside the state.
	Counts Counts
}
77
// Counts holds the number of tools and prompts available from an MCP client.
type Counts struct {
	// Tools is the number of tools listed for the client.
	Tools int
	// Prompts is the number of prompts listed for the client.
	Prompts int
}
83
// ClientInfo holds information about an MCP client's state, as recorded by
// updateState.
type ClientInfo struct {
	// Name is the MCP client's name.
	Name string
	// State is the most recently recorded lifecycle state.
	State State
	// Error is the error recorded with the state, if any.
	Error error
	// Client is the session passed to updateState. The callers in this file
	// pass a non-nil session only together with StateConnected.
	Client *mcp.ClientSession
	// Counts is the tool/prompt count recorded with the state.
	Counts Counts
	// ConnectedAt is set to the current time when the state is recorded as
	// StateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
93
94// SubscribeEvents returns a channel for MCP events
95func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
96 return broker.Subscribe(ctx)
97}
98
99// GetStates returns the current state of all MCP clients
100func GetStates() map[string]ClientInfo {
101 return maps.Collect(states.Seq2())
102}
103
104// GetState returns the state of a specific MCP client
105func GetState(name string) (ClientInfo, bool) {
106 return states.Get(name)
107}
108
109// Close closes all MCP clients. This should be called during application shutdown.
110func Close() error {
111 var errs []error
112 for name, session := range sessions.Seq2() {
113 if err := session.Close(); err != nil &&
114 !errors.Is(err, io.EOF) &&
115 !errors.Is(err, context.Canceled) &&
116 err.Error() != "signal: killed" {
117 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
118 }
119 }
120 broker.Shutdown()
121 return errors.Join(errs...)
122}
123
124// Initialize initializes MCP clients based on the provided configuration.
125func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
126 var wg sync.WaitGroup
127 // Initialize states for all configured MCPs
128 for name, m := range cfg.MCP {
129 if m.Disabled {
130 updateState(name, StateDisabled, nil, nil, Counts{})
131 slog.Debug("skipping disabled mcp", "name", name)
132 continue
133 }
134
135 // Set initial starting state
136 updateState(name, StateStarting, nil, nil, Counts{})
137
138 wg.Add(1)
139 go func(name string, m config.MCPConfig) {
140 defer func() {
141 wg.Done()
142 if r := recover(); r != nil {
143 var err error
144 switch v := r.(type) {
145 case error:
146 err = v
147 case string:
148 err = fmt.Errorf("panic: %s", v)
149 default:
150 err = fmt.Errorf("panic: %v", v)
151 }
152 updateState(name, StateError, err, nil, Counts{})
153 slog.Error("panic in mcp client initialization", "error", err, "name", name)
154 }
155 }()
156
157 // createSession handles its own timeout internally.
158 session, err := createSession(ctx, name, m, cfg.Resolver())
159 if err != nil {
160 return
161 }
162
163 tools, err := getTools(ctx, session)
164 if err != nil {
165 slog.Error("error listing tools", "error", err)
166 updateState(name, StateError, err, nil, Counts{})
167 session.Close()
168 return
169 }
170
171 prompts, err := getPrompts(ctx, session)
172 if err != nil {
173 slog.Error("error listing prompts", "error", err)
174 updateState(name, StateError, err, nil, Counts{})
175 session.Close()
176 return
177 }
178
179 updateTools(name, tools)
180 updatePrompts(name, prompts)
181 sessions.Set(name, session)
182
183 updateState(name, StateConnected, nil, session, Counts{
184 Tools: len(tools),
185 Prompts: len(prompts),
186 })
187 }(name, m)
188 }
189 wg.Wait()
190}
191
192// InitializeSingle initializes a single MCP client by name.
193func InitializeSingle(ctx context.Context, name string, cfg *config.Config) error {
194 m, exists := cfg.MCP[name]
195 if !exists {
196 return fmt.Errorf("mcp '%s' not found in configuration", name)
197 }
198
199 if m.Disabled {
200 updateState(name, StateDisabled, nil, nil, Counts{})
201 slog.Debug("skipping disabled mcp", "name", name)
202 return nil
203 }
204
205 // Set initial starting state.
206 updateState(name, StateStarting, nil, nil, Counts{})
207
208 // createSession handles its own timeout internally.
209 session, err := createSession(ctx, name, m, cfg.Resolver())
210 if err != nil {
211 return err
212 }
213
214 tools, err := getTools(ctx, session)
215 if err != nil {
216 slog.Error("error listing tools", "error", err)
217 updateState(name, StateError, err, nil, Counts{})
218 session.Close()
219 return err
220 }
221
222 prompts, err := getPrompts(ctx, session)
223 if err != nil {
224 slog.Error("error listing prompts", "error", err)
225 updateState(name, StateError, err, nil, Counts{})
226 session.Close()
227 return err
228 }
229
230 updateTools(name, tools)
231 updatePrompts(name, prompts)
232 sessions.Set(name, session)
233
234 updateState(name, StateConnected, nil, session, Counts{
235 Tools: len(tools),
236 Prompts: len(prompts),
237 })
238
239 return nil
240}
241
242// DisableSingle disables and closes a single MCP client by name.
243func DisableSingle(name string) error {
244 session, ok := sessions.Get(name)
245 if ok {
246 if err := session.Close(); err != nil &&
247 !errors.Is(err, io.EOF) &&
248 !errors.Is(err, context.Canceled) &&
249 err.Error() != "signal: killed" {
250 slog.Warn("error closing mcp session", "name", name, "error", err)
251 }
252 sessions.Del(name)
253 }
254
255 // Clear tools and prompts for this MCP.
256 updateTools(name, nil)
257 updatePrompts(name, nil)
258
259 // Update state to disabled.
260 updateState(name, StateDisabled, nil, nil, Counts{})
261
262 slog.Info("Disabled mcp client", "name", name)
263 return nil
264}
265
266func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
267 sess, ok := sessions.Get(name)
268 if !ok {
269 return nil, fmt.Errorf("mcp '%s' not available", name)
270 }
271
272 cfg := config.Get()
273 m := cfg.MCP[name]
274 state, _ := states.Get(name)
275
276 timeout := mcpTimeout(m)
277 pingCtx, cancel := context.WithTimeout(ctx, timeout)
278 defer cancel()
279 err := sess.Ping(pingCtx, nil)
280 if err == nil {
281 return sess, nil
282 }
283 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
284
285 sess, err = createSession(ctx, name, m, cfg.Resolver())
286 if err != nil {
287 return nil, err
288 }
289
290 updateState(name, StateConnected, nil, sess, state.Counts)
291 sessions.Set(name, sess)
292 return sess, nil
293}
294
295// updateState updates the state of an MCP client and publishes an event
296func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
297 info := ClientInfo{
298 Name: name,
299 State: state,
300 Error: err,
301 Client: client,
302 Counts: counts,
303 }
304 switch state {
305 case StateConnected:
306 info.ConnectedAt = time.Now()
307 case StateError:
308 sessions.Del(name)
309 }
310 states.Set(name, info)
311
312 // Publish state change event
313 broker.Publish(pubsub.UpdatedEvent, Event{
314 Type: EventStateChanged,
315 Name: name,
316 State: state,
317 Error: err,
318 Counts: counts,
319 })
320}
321
322func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
323 timeout := mcpTimeout(m)
324 mcpCtx, cancel := context.WithCancel(ctx)
325 cancelTimer := time.AfterFunc(timeout, cancel)
326
327 transport, err := createTransport(mcpCtx, m, resolver)
328 if err != nil {
329 updateState(name, StateError, err, nil, Counts{})
330 slog.Error("error creating mcp client", "error", err, "name", name)
331 cancel()
332 cancelTimer.Stop()
333 return nil, err
334 }
335
336 client := mcp.NewClient(
337 &mcp.Implementation{
338 Name: "crush",
339 Version: version.Version,
340 Title: "Crush",
341 },
342 &mcp.ClientOptions{
343 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
344 broker.Publish(pubsub.UpdatedEvent, Event{
345 Type: EventToolsListChanged,
346 Name: name,
347 })
348 },
349 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
350 broker.Publish(pubsub.UpdatedEvent, Event{
351 Type: EventPromptsListChanged,
352 Name: name,
353 })
354 },
355 LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
356 slog.Info("mcp log", "name", name, "data", req.Params.Data)
357 },
358 KeepAlive: time.Minute * 10,
359 },
360 )
361
362 session, err := client.Connect(mcpCtx, transport, nil)
363 if err != nil {
364 err = maybeStdioErr(err, transport)
365 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
366 slog.Error("error starting mcp client", "error", err, "name", name)
367 cancel()
368 cancelTimer.Stop()
369 return nil, err
370 }
371
372 cancelTimer.Stop()
373 slog.Info("Initialized mcp client", "name", name)
374 return session, nil
375}
376
377// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
378// to parse, and the cli will then close it, causing the EOF error.
379// so, if we got an EOF err, and the transport is STDIO, we try to exec it
380// again with a timeout and collect the output so we can add details to the
381// error.
382// this happens particularly when starting things with npx, e.g. if node can't
383// be found or some other error like that.
384func maybeStdioErr(err error, transport mcp.Transport) error {
385 if !errors.Is(err, io.EOF) {
386 return err
387 }
388 ct, ok := transport.(*mcp.CommandTransport)
389 if !ok {
390 return err
391 }
392 if err2 := stdioCheck(ct.Command); err2 != nil {
393 err = errors.Join(err, err2)
394 }
395 return err
396}
397
// maybeTimeoutErr replaces an error that is (or wraps) context.Canceled with
// a "timed out after <timeout>" error. Any other error, including nil, is
// returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.Canceled) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
404
405func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
406 switch m.Type {
407 case config.MCPStdio:
408 command, err := resolver.ResolveValue(m.Command)
409 if err != nil {
410 return nil, fmt.Errorf("invalid mcp command: %w", err)
411 }
412 if strings.TrimSpace(command) == "" {
413 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
414 }
415 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
416 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
417 return &mcp.CommandTransport{
418 Command: cmd,
419 }, nil
420 case config.MCPHttp:
421 if strings.TrimSpace(m.URL) == "" {
422 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
423 }
424 client := &http.Client{
425 Transport: &headerRoundTripper{
426 headers: m.ResolvedHeaders(),
427 },
428 }
429 return &mcp.StreamableClientTransport{
430 Endpoint: m.URL,
431 HTTPClient: client,
432 }, nil
433 case config.MCPSSE:
434 if strings.TrimSpace(m.URL) == "" {
435 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
436 }
437 client := &http.Client{
438 Transport: &headerRoundTripper{
439 headers: m.ResolvedHeaders(),
440 },
441 }
442 return &mcp.SSEClientTransport{
443 Endpoint: m.URL,
444 HTTPClient: client,
445 }, nil
446 default:
447 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
448 }
449}
450
// headerRoundTripper is an http.RoundTripper that applies a fixed set of
// headers to every request before sending it through http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to the values set on each outgoing request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, replacing any existing values for those header names.
//
// The caller's request is never modified: the http.RoundTripper contract
// forbids changing it, so the headers are applied to a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			// Clone keeps a nil header map nil; Set needs a real map.
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
461
462func mcpTimeout(m config.MCPConfig) time.Duration {
463 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
464}
465
// stdioCheck runs the command described by old once more, for at most five
// seconds, to capture whatever it prints when it fails.
//
// It returns nil when the command succeeds or is still running when the five
// seconds are up. Otherwise it returns the run error with the command's
// combined stdout/stderr appended. The new process uses old's path, arguments
// and environment.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// old.Args already contains the command name as Args[0], and
	// exec.CommandContext prepends the name itself. Passing old.Args
	// unchanged would hand the program its own name as its first argument
	// and run a different command line than the one that failed.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}