1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "os/exec"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/version"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
29var (
30 sessions = csync.NewMap[string, *mcp.ClientSession]()
31 states = csync.NewMap[string, ClientInfo]()
32 broker = pubsub.NewBroker[Event]()
33)
34
// State is the lifecycle state of a single MCP client.
type State int

// The states an MCP client can be in. The zero value is StateDisabled.
const (
	// StateDisabled marks an MCP that is disabled in the configuration.
	StateDisabled State = iota
	// StateStarting marks an MCP whose session is being set up.
	StateStarting
	// StateConnected marks an MCP with an established session.
	StateConnected
	// StateError marks an MCP whose setup or health check failed.
	StateError
)

// String returns the lower-case name of the state, or "unknown" for any value
// that is not one of the declared constants.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
59
// EventType identifies what kind of change an Event reports.
type EventType uint

// The kinds of events published on the MCP broker.
const (
	// EventStateChanged reports that an MCP client moved to a new State.
	EventStateChanged EventType = iota
	// EventToolsListChanged reports a tool-list-changed notification.
	EventToolsListChanged
	// EventPromptsListChanged reports a prompt-list-changed notification.
	EventPromptsListChanged
	// EventResourcesListChanged reports a resource-list-changed notification.
	EventResourcesListChanged
)
69
70// Event represents an event in the MCP system
71type Event struct {
72 Type EventType
73 Name string
74 State State
75 Error error
76 Counts Counts
77}
78
// Counts records how many tools, prompts and resources an MCP exposes.
type Counts struct {
	// Tools is the number of tools.
	Tools int
	// Prompts is the number of prompts.
	Prompts int
	// Resources is the number of resources.
	Resources int
}
85
86// ClientInfo holds information about an MCP client's state
87type ClientInfo struct {
88 Name string
89 State State
90 Error error
91 Client *mcp.ClientSession
92 Counts Counts
93 ConnectedAt time.Time
94}
95
96// SubscribeEvents returns a channel for MCP events
97func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
98 return broker.Subscribe(ctx)
99}
100
101// GetStates returns the current state of all MCP clients
102func GetStates() map[string]ClientInfo {
103 return maps.Collect(states.Seq2())
104}
105
106// GetState returns the state of a specific MCP client
107func GetState(name string) (ClientInfo, bool) {
108 return states.Get(name)
109}
110
111// Close closes all MCP clients. This should be called during application shutdown.
112func Close() error {
113 var errs []error
114 for name, c := range sessions.Seq2() {
115 if err := c.Close(); err != nil &&
116 !errors.Is(err, io.EOF) &&
117 !errors.Is(err, context.Canceled) &&
118 err.Error() != "signal: killed" {
119 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
120 }
121 }
122 broker.Shutdown()
123 return errors.Join(errs...)
124}
125
126// Initialize initializes MCP clients based on the provided configuration.
127func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
128 var wg sync.WaitGroup
129 // Initialize states for all configured MCPs
130 for name, m := range cfg.MCP {
131 if m.Disabled {
132 updateState(name, StateDisabled, nil, nil, Counts{})
133 slog.Debug("skipping disabled mcp", "name", name)
134 continue
135 }
136
137 // Set initial starting state
138 updateState(name, StateStarting, nil, nil, Counts{})
139
140 wg.Add(1)
141 go func(name string, m config.MCPConfig) {
142 defer func() {
143 wg.Done()
144 if r := recover(); r != nil {
145 var err error
146 switch v := r.(type) {
147 case error:
148 err = v
149 case string:
150 err = fmt.Errorf("panic: %s", v)
151 default:
152 err = fmt.Errorf("panic: %v", v)
153 }
154 updateState(name, StateError, err, nil, Counts{})
155 slog.Error("panic in mcp client initialization", "error", err, "name", name)
156 }
157 }()
158
159 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
160 defer cancel()
161
162 session, err := createSession(ctx, name, m, cfg.Resolver())
163 if err != nil {
164 return
165 }
166
167 tools, err := getTools(ctx, session)
168 if err != nil {
169 slog.Error("error listing tools", "error", err)
170 updateState(name, StateError, err, nil, Counts{})
171 session.Close()
172 return
173 }
174
175 prompts, err := getPrompts(ctx, session)
176 if err != nil {
177 slog.Error("error listing prompts", "error", err)
178 updateState(name, StateError, err, nil, Counts{})
179 session.Close()
180 return
181 }
182
183 resources, err := getResources(ctx, session)
184 if err != nil {
185 slog.Error("error listing resources", "error", err)
186 updateState(name, StateError, err, nil, Counts{})
187 session.Close()
188 return
189 }
190
191 updateTools(name, tools)
192 updatePrompts(name, prompts)
193 updateResources(name, resources)
194 sessions.Set(name, session)
195
196 updateState(name, StateConnected, nil, session, Counts{
197 Tools: len(tools),
198 Prompts: len(prompts),
199 Resources: len(resources),
200 })
201 }(name, m)
202 }
203 wg.Wait()
204}
205
206func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
207 sess, ok := sessions.Get(name)
208 if !ok {
209 return nil, fmt.Errorf("mcp '%s' not available", name)
210 }
211
212 cfg := config.Get()
213 m := cfg.MCP[name]
214 state, _ := states.Get(name)
215
216 timeout := mcpTimeout(m)
217 pingCtx, cancel := context.WithTimeout(ctx, timeout)
218 defer cancel()
219 err := sess.Ping(pingCtx, nil)
220 if err == nil {
221 return sess, nil
222 }
223 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
224
225 sess, err = createSession(ctx, name, m, cfg.Resolver())
226 if err != nil {
227 return nil, err
228 }
229
230 updateState(name, StateConnected, nil, sess, state.Counts)
231 sessions.Set(name, sess)
232 return sess, nil
233}
234
235// updateState updates the state of an MCP client and publishes an event
236func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
237 info := ClientInfo{
238 Name: name,
239 State: state,
240 Error: err,
241 Client: client,
242 Counts: counts,
243 }
244 switch state {
245 case StateConnected:
246 info.ConnectedAt = time.Now()
247 case StateError:
248 updateTools(name, nil)
249 sessions.Del(name)
250 }
251 states.Set(name, info)
252
253 // Publish state change event
254 broker.Publish(pubsub.UpdatedEvent, Event{
255 Type: EventStateChanged,
256 Name: name,
257 State: state,
258 Error: err,
259 Counts: counts,
260 })
261}
262
263func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
264 timeout := mcpTimeout(m)
265 mcpCtx, cancel := context.WithCancel(ctx)
266 cancelTimer := time.AfterFunc(timeout, cancel)
267
268 transport, err := createTransport(mcpCtx, m, resolver)
269 if err != nil {
270 updateState(name, StateError, err, nil, Counts{})
271 slog.Error("error creating mcp client", "error", err, "name", name)
272 cancel()
273 cancelTimer.Stop()
274 return nil, err
275 }
276
277 client := mcp.NewClient(
278 &mcp.Implementation{
279 Name: "crush",
280 Version: version.Version,
281 Title: "Crush",
282 },
283 &mcp.ClientOptions{
284 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
285 broker.Publish(pubsub.UpdatedEvent, Event{
286 Type: EventToolsListChanged,
287 Name: name,
288 })
289 },
290 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
291 broker.Publish(pubsub.UpdatedEvent, Event{
292 Type: EventPromptsListChanged,
293 Name: name,
294 })
295 },
296 ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
297 broker.Publish(pubsub.UpdatedEvent, Event{
298 Type: EventResourcesListChanged,
299 Name: name,
300 })
301 },
302 LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
303 slog.Log(ctx, parseLevel(req.Params.Level), "mcp server log", "data", req.Params.Data)
304 },
305 KeepAlive: time.Minute * 10,
306 },
307 )
308
309 session, err := client.Connect(mcpCtx, transport, nil)
310 if err != nil {
311 err = maybeStdioErr(err, transport)
312 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
313 slog.Error("error starting mcp client", "error", err, "name", name)
314 cancel()
315 cancelTimer.Stop()
316 return nil, err
317 }
318
319 cancelTimer.Stop()
320 slog.Info("Initialized mcp client", "name", name)
321 return session, nil
322}
323
324func parseLevel(level mcp.LoggingLevel) slog.Level {
325 switch level {
326 case "debug":
327 return slog.LevelDebug
328 case "info", "notice":
329 return slog.LevelInfo
330 case "warning", "warn":
331 return slog.LevelWarn
332 case "error", "emergency", "alert", "fatal":
333 fallthrough
334 default:
335 return slog.LevelError
336 }
337}
338
339// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
340// to parse, and the cli will then close it, causing the EOF error.
341// so, if we got an EOF err, and the transport is STDIO, we try to exec it
342// again with a timeout and collect the output so we can add details to the
343// error.
344// this happens particularly when starting things with npx, e.g. if node can't
345// be found or some other error like that.
346func maybeStdioErr(err error, transport mcp.Transport) error {
347 if !errors.Is(err, io.EOF) {
348 return err
349 }
350 ct, ok := transport.(*mcp.CommandTransport)
351 if !ok {
352 return err
353 }
354 if err2 := stdioCheck(ct.Command); err2 != nil {
355 err = errors.Join(err, err2)
356 }
357 return err
358}
359
// maybeTimeoutErr replaces a context.Canceled error (direct or wrapped) with
// a "timed out after <timeout>" error. Any other error, including nil, is
// returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.Canceled) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
366
367func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
368 switch m.Type {
369 case config.MCPStdio:
370 command, err := resolver.ResolveValue(m.Command)
371 if err != nil {
372 return nil, fmt.Errorf("invalid mcp command: %w", err)
373 }
374 if strings.TrimSpace(command) == "" {
375 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
376 }
377 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
378 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
379 return &mcp.CommandTransport{
380 Command: cmd,
381 }, nil
382 case config.MCPHttp:
383 if strings.TrimSpace(m.URL) == "" {
384 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
385 }
386 client := &http.Client{
387 Transport: &headerRoundTripper{
388 headers: m.ResolvedHeaders(),
389 },
390 }
391 return &mcp.StreamableClientTransport{
392 Endpoint: m.URL,
393 HTTPClient: client,
394 }, nil
395 case config.MCPSSE:
396 if strings.TrimSpace(m.URL) == "" {
397 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
398 }
399 client := &http.Client{
400 Transport: &headerRoundTripper{
401 headers: m.ResolvedHeaders(),
402 },
403 }
404 return &mcp.SSEClientTransport{
405 Endpoint: m.URL,
406 HTTPClient: client,
407 }, nil
408 default:
409 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
410 }
411}
412
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every outgoing request and sends it through http.DefaultTransport.
type headerRoundTripper struct {
	// headers holds the header name/value pairs applied to each request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, replacing any existing values for the same header names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone and req itself is left untouched. When
// no headers are configured, req is forwarded as is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		// A hand-built request may carry a nil header map; Set would panic.
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
423
424func mcpTimeout(m config.MCPConfig) time.Duration {
425 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
426}
427
// stdioCheck re-runs the command described by old for at most five seconds
// and reports why it fails to start, if it does.
//
// It returns nil when the command exits successfully or is still running
// when the five seconds elapse. Otherwise it returns the command's error
// wrapped together with its combined stdout and stderr output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// old.Args[0] is the command name itself and exec.CommandContext adds it
	// again, so only the real arguments are forwarded. Passing old.Args
	// unchanged would run "cmd cmd arg..." instead of "cmd arg...".
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}

	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}