1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
// Package-level state shared by every MCP client managed by this package.
var (
	// sessions holds the live client session of each MCP, keyed by MCP name.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states holds the most recently recorded ClientInfo of each MCP, keyed
	// by MCP name.
	states = csync.NewMap[string, ClientInfo]()
	// broker delivers MCP events to subscribers (see SubscribeEvents).
	broker = pubsub.NewBroker[Event]()
	// initOnce guards the closing of initDone so it happens at most once.
	initOnce sync.Once
	// initDone is closed by Initialize once all clients have been processed.
	initDone = make(chan struct{})
)
35
// State describes where an MCP client currently is in its lifecycle.
type State int

// The lifecycle states an MCP client can be in.
const (
	// StateDisabled marks a client that is disabled.
	StateDisabled State = iota
	// StateStarting marks a client whose session is being set up.
	StateStarting
	// StateConnected marks a client with an established session.
	StateConnected
	// StateError marks a client whose setup or health check failed.
	StateError
)

// String returns the lower-case name of the state, or "unknown" for any
// value that is not one of the declared constants.
func (s State) String() string {
	// Indexed by the constant values so the table stays in sync with them.
	labels := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
60
// EventType represents the type of MCP event.
type EventType uint

// The kinds of event published on the MCP broker.
const (
	// EventStateChanged is published by updateState whenever a client's
	// state is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a server notifies that its
	// tool list changed.
	EventToolsListChanged
	// EventPromptsListChanged is published when a server notifies that its
	// prompt list changed.
	EventPromptsListChanged
)
69
// Event represents an event in the MCP system.
//
// State-change events carry State, Error and Counts; the list-changed events
// published from the client handlers in this file set only Type and Name.
type Event struct {
	// Type is the kind of event.
	Type EventType
	// Name is the name of the MCP the event refers to.
	Name string
	// State is the client state recorded by a state-change event.
	State State
	// Error is the error recorded with the state, if any.
	Error error
	// Counts holds the tool and prompt counts recorded with the state.
	Counts Counts
}
78
// Counts holds the number of tools and prompts available from an MCP.
type Counts struct {
	// Tools is the number of tools.
	Tools int
	// Prompts is the number of prompts.
	Prompts int
}
84
// ClientInfo holds information about an MCP client's state as last recorded
// by updateState.
type ClientInfo struct {
	// Name is the MCP's name.
	Name string
	// State is the recorded lifecycle state.
	State State
	// Error is the error recorded with the state, if any.
	Error error
	// Client is the session passed to updateState; callers in this file pass
	// nil for every state other than StateConnected.
	Client *mcp.ClientSession
	// Counts holds the tool and prompt counts recorded with the state.
	Counts Counts
	// ConnectedAt is the time the state was recorded as StateConnected; it is
	// the zero time for every other state.
	ConnectedAt time.Time
}
94
// SubscribeEvents subscribes to MCP events on the package broker and returns
// the resulting channel. The subscription is made with ctx; its lifetime is
// presumably tied to ctx (confirm against pubsub.Broker.Subscribe).
func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
	return broker.Subscribe(ctx)
}
99
// GetStates returns the current state of all MCP clients, keyed by MCP name.
// The map comes from states.Copy(), so it is presumably a snapshot that the
// caller may modify freely (confirm against csync.Map.Copy).
func GetStates() map[string]ClientInfo {
	return states.Copy()
}
104
// GetState returns the recorded state of the MCP client called name. The
// boolean reports whether a state has been recorded for that name.
func GetState(name string) (ClientInfo, bool) {
	return states.Get(name)
}
109
110// Close closes all MCP clients. This should be called during application shutdown.
111func Close() error {
112 var wg sync.WaitGroup
113 done := make(chan struct{}, 1)
114 go func() {
115 for name, session := range sessions.Seq2() {
116 wg.Go(func() {
117 if err := session.Close(); err != nil &&
118 !errors.Is(err, io.EOF) &&
119 !errors.Is(err, context.Canceled) &&
120 err.Error() != "signal: killed" {
121 slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
122 }
123 })
124 }
125 wg.Wait()
126 done <- struct{}{}
127 }()
128 select {
129 case <-done:
130 case <-time.After(5 * time.Second):
131 }
132 broker.Shutdown()
133 return nil
134}
135
136// Initialize initializes MCP clients based on the provided configuration.
137func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Service) {
138 slog.Info("Initializing MCP clients")
139 var wg sync.WaitGroup
140 // Initialize states for all configured MCPs
141 for name, m := range cfg.MCP() {
142 if m.Disabled {
143 updateState(name, StateDisabled, nil, nil, Counts{})
144 slog.Debug("Skipping disabled MCP", "name", name)
145 continue
146 }
147
148 // Set initial starting state
149 updateState(name, StateStarting, nil, nil, Counts{})
150
151 wg.Add(1)
152 go func(name string, m config.MCPConfig) {
153 defer func() {
154 wg.Done()
155 if r := recover(); r != nil {
156 var err error
157 switch v := r.(type) {
158 case error:
159 err = v
160 case string:
161 err = fmt.Errorf("panic: %s", v)
162 default:
163 err = fmt.Errorf("panic: %v", v)
164 }
165 updateState(name, StateError, err, nil, Counts{})
166 slog.Error("Panic in MCP client initialization", "error", err, "name", name)
167 }
168 }()
169
170 // createSession handles its own timeout internally.
171 session, err := createSession(ctx, name, m, cfg.Resolver())
172 if err != nil {
173 return
174 }
175
176 tools, err := getTools(ctx, session)
177 if err != nil {
178 slog.Error("Error listing tools", "error", err)
179 updateState(name, StateError, err, nil, Counts{})
180 session.Close()
181 return
182 }
183
184 prompts, err := getPrompts(ctx, session)
185 if err != nil {
186 slog.Error("Error listing prompts", "error", err)
187 updateState(name, StateError, err, nil, Counts{})
188 session.Close()
189 return
190 }
191
192 toolCount := updateTools(cfg, name, tools)
193 updatePrompts(name, prompts)
194 sessions.Set(name, session)
195
196 updateState(name, StateConnected, nil, session, Counts{
197 Tools: toolCount,
198 Prompts: len(prompts),
199 })
200 }(name, m)
201 }
202 wg.Wait()
203 initOnce.Do(func() { close(initDone) })
204}
205
206// WaitForInit blocks until MCP initialization is complete.
207// If Initialize was never called, this returns immediately.
208func WaitForInit(ctx context.Context) error {
209 select {
210 case <-initDone:
211 return nil
212 case <-ctx.Done():
213 return ctx.Err()
214 }
215}
216
217func getOrRenewClient(ctx context.Context, cfg *config.Service, name string) (*mcp.ClientSession, error) {
218 sess, ok := sessions.Get(name)
219 if !ok {
220 return nil, fmt.Errorf("mcp '%s' not available", name)
221 }
222
223 m := cfg.MCP()[name]
224 state, _ := states.Get(name)
225
226 timeout := mcpTimeout(m)
227 pingCtx, cancel := context.WithTimeout(ctx, timeout)
228 defer cancel()
229 err := sess.Ping(pingCtx, nil)
230 if err == nil {
231 return sess, nil
232 }
233 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
234
235 sess, err = createSession(ctx, name, m, cfg.Resolver())
236 if err != nil {
237 return nil, err
238 }
239
240 updateState(name, StateConnected, nil, sess, state.Counts)
241 sessions.Set(name, sess)
242 return sess, nil
243}
244
245// updateState updates the state of an MCP client and publishes an event
246func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
247 info := ClientInfo{
248 Name: name,
249 State: state,
250 Error: err,
251 Client: client,
252 Counts: counts,
253 }
254 switch state {
255 case StateConnected:
256 info.ConnectedAt = time.Now()
257 case StateError:
258 sessions.Del(name)
259 }
260 states.Set(name, info)
261
262 // Publish state change event
263 broker.Publish(pubsub.UpdatedEvent, Event{
264 Type: EventStateChanged,
265 Name: name,
266 State: state,
267 Error: err,
268 Counts: counts,
269 })
270}
271
272func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
273 timeout := mcpTimeout(m)
274 mcpCtx, cancel := context.WithCancel(ctx)
275 cancelTimer := time.AfterFunc(timeout, cancel)
276
277 transport, err := createTransport(mcpCtx, m, resolver)
278 if err != nil {
279 updateState(name, StateError, err, nil, Counts{})
280 slog.Error("Error creating MCP client", "error", err, "name", name)
281 cancel()
282 cancelTimer.Stop()
283 return nil, err
284 }
285
286 client := mcp.NewClient(
287 &mcp.Implementation{
288 Name: "crush",
289 Version: version.Version,
290 Title: "Crush",
291 },
292 &mcp.ClientOptions{
293 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
294 broker.Publish(pubsub.UpdatedEvent, Event{
295 Type: EventToolsListChanged,
296 Name: name,
297 })
298 },
299 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
300 broker.Publish(pubsub.UpdatedEvent, Event{
301 Type: EventPromptsListChanged,
302 Name: name,
303 })
304 },
305 LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
306 slog.Info("MCP log", "name", name, "data", req.Params.Data)
307 },
308 },
309 )
310
311 session, err := client.Connect(mcpCtx, transport, nil)
312 if err != nil {
313 err = maybeStdioErr(err, transport)
314 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
315 slog.Error("MCP client failed to initialize", "error", err, "name", name)
316 cancel()
317 cancelTimer.Stop()
318 return nil, err
319 }
320
321 cancelTimer.Stop()
322 slog.Debug("MCP client initialized", "name", name)
323 return session, nil
324}
325
326// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
327// to parse, and the cli will then close it, causing the EOF error.
328// so, if we got an EOF err, and the transport is STDIO, we try to exec it
329// again with a timeout and collect the output so we can add details to the
330// error.
331// this happens particularly when starting things with npx, e.g. if node can't
332// be found or some other error like that.
333func maybeStdioErr(err error, transport mcp.Transport) error {
334 if !errors.Is(err, io.EOF) {
335 return err
336 }
337 ct, ok := transport.(*mcp.CommandTransport)
338 if !ok {
339 return err
340 }
341 if err2 := stdioCheck(ct.Command); err2 != nil {
342 err = errors.Join(err, err2)
343 }
344 return err
345}
346
// maybeTimeoutErr replaces a context-expiry error with a readable
// "timed out after <timeout>" error and returns every other error, including
// nil, unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are treated as timeouts:
// createSession enforces its timeout by cancelling the context, while
// getOrRenewClient pings under context.WithTimeout, which ends with
// DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
353
354func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
355 switch m.Type {
356 case config.MCPStdio:
357 command, err := resolver.ResolveValue(m.Command)
358 if err != nil {
359 return nil, fmt.Errorf("invalid mcp command: %w", err)
360 }
361 if strings.TrimSpace(command) == "" {
362 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
363 }
364 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
365 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
366 return &mcp.CommandTransport{
367 Command: cmd,
368 }, nil
369 case config.MCPHttp:
370 if strings.TrimSpace(m.URL) == "" {
371 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
372 }
373 client := &http.Client{
374 Transport: &headerRoundTripper{
375 headers: m.ResolvedHeaders(),
376 },
377 }
378 return &mcp.StreamableClientTransport{
379 Endpoint: m.URL,
380 HTTPClient: client,
381 }, nil
382 case config.MCPSSE:
383 if strings.TrimSpace(m.URL) == "" {
384 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
385 }
386 client := &http.Client{
387 Transport: &headerRoundTripper{
388 headers: m.ResolvedHeaders(),
389 },
390 }
391 return &mcp.SSEClientTransport{
392 Endpoint: m.URL,
393 HTTPClient: client,
394 }, nil
395 default:
396 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
397 }
398}
399
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every outgoing request and then delegates to http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to the values set on each request.
	headers map[string]string
}

// RoundTrip implements http.RoundTripper. It sends req through
// http.DefaultTransport with the configured headers set, overriding any
// existing values for the same header names.
//
// The headers are applied to a clone because a RoundTripper must not modify
// the request it is given; the caller's req is left untouched.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			// Clone keeps a nil header map nil; Set needs a real map.
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
410
411func mcpTimeout(m config.MCPConfig) time.Duration {
412 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
413}
414
// stdioCheck runs the command behind a failed stdio transport again to
// capture what it prints, so the output can be attached to the connection
// error.
//
// The command is re-run with the same path, arguments, environment and
// working directory as old, bounded by a five second timeout. It returns nil
// when the re-run exits successfully or when the timeout expires; otherwise
// it returns the run error wrapped together with the combined stdout and
// stderr output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// old.Args[0] is the command name itself. exec.CommandContext adds that
	// entry on its own, so forwarding all of old.Args would hand the command
	// its own name as an extra first argument.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}

	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}