1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
28var (
29 sessions = csync.NewMap[string, *mcp.ClientSession]()
30 states = csync.NewMap[string, ClientInfo]()
31 broker = pubsub.NewBroker[Event]()
32)
33
// State is the lifecycle stage of a single MCP client.
type State int

// Lifecycle stages an MCP client can be in.
const (
	StateDisabled  State = iota // configured but disabled; never started
	StateStarting               // connection attempt in progress
	StateConnected              // session established
	StateError                  // last attempt failed; see ClientInfo.Error
)

// String returns the lowercase name of s, or "unknown" for any value that is
// not one of the declared State constants.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
58
// EventType tells subscribers which kind of change an Event reports.
type EventType uint

// Kinds of Event published on the package broker.
const (
	// EventStateChanged is published by updateState on every state update.
	EventStateChanged EventType = 0
	// EventToolsListChanged is published when a server notifies the client
	// that its tool list changed.
	EventToolsListChanged EventType = 1
	// EventPromptsListChanged is published when a server notifies the client
	// that its prompt list changed.
	EventPromptsListChanged EventType = 2
)
67
68// Event represents an event in the MCP system
69type Event struct {
70 Type EventType
71 Name string
72 State State
73 Error error
74 Counts Counts
75}
76
77// Counts number of available tools, prompts, etc.
78type Counts struct {
79 Tools int
80 Prompts int
81}
82
83// ClientInfo holds information about an MCP client's state
84type ClientInfo struct {
85 Name string
86 State State
87 Error error
88 Client *mcp.ClientSession
89 Counts Counts
90 ConnectedAt time.Time
91}
92
93// SubscribeEvents returns a channel for MCP events
94func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
95 return broker.Subscribe(ctx)
96}
97
98// GetStates returns the current state of all MCP clients
99func GetStates() map[string]ClientInfo {
100 return states.Copy()
101}
102
103// GetState returns the state of a specific MCP client
104func GetState(name string) (ClientInfo, bool) {
105 return states.Get(name)
106}
107
108// Close closes all MCP clients. This should be called during application shutdown.
109func Close() error {
110 var errs []error
111 var wg sync.WaitGroup
112 for name, session := range sessions.Seq2() {
113 wg.Go(func() {
114 done := make(chan bool, 1)
115 go func() {
116 if err := session.Close(); err != nil &&
117 !errors.Is(err, io.EOF) &&
118 !errors.Is(err, context.Canceled) &&
119 err.Error() != "signal: killed" {
120 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
121 }
122 done <- true
123 }()
124 select {
125 case <-done:
126 case <-time.After(time.Millisecond * 250):
127 }
128 })
129 }
130 wg.Wait()
131 broker.Shutdown()
132 return errors.Join(errs...)
133}
134
135// Initialize initializes MCP clients based on the provided configuration.
136func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
137 var wg sync.WaitGroup
138 // Initialize states for all configured MCPs
139 for name, m := range cfg.MCP {
140 if m.Disabled {
141 updateState(name, StateDisabled, nil, nil, Counts{})
142 slog.Debug("skipping disabled mcp", "name", name)
143 continue
144 }
145
146 // Set initial starting state
147 updateState(name, StateStarting, nil, nil, Counts{})
148
149 wg.Add(1)
150 go func(name string, m config.MCPConfig) {
151 defer func() {
152 wg.Done()
153 if r := recover(); r != nil {
154 var err error
155 switch v := r.(type) {
156 case error:
157 err = v
158 case string:
159 err = fmt.Errorf("panic: %s", v)
160 default:
161 err = fmt.Errorf("panic: %v", v)
162 }
163 updateState(name, StateError, err, nil, Counts{})
164 slog.Error("panic in mcp client initialization", "error", err, "name", name)
165 }
166 }()
167
168 // createSession handles its own timeout internally.
169 session, err := createSession(ctx, name, m, cfg.Resolver())
170 if err != nil {
171 return
172 }
173
174 tools, err := getTools(ctx, session)
175 if err != nil {
176 slog.Error("error listing tools", "error", err)
177 updateState(name, StateError, err, nil, Counts{})
178 session.Close()
179 return
180 }
181
182 prompts, err := getPrompts(ctx, session)
183 if err != nil {
184 slog.Error("error listing prompts", "error", err)
185 updateState(name, StateError, err, nil, Counts{})
186 session.Close()
187 return
188 }
189
190 toolCount := updateTools(name, tools)
191 updatePrompts(name, prompts)
192 sessions.Set(name, session)
193
194 updateState(name, StateConnected, nil, session, Counts{
195 Tools: toolCount,
196 Prompts: len(prompts),
197 })
198 }(name, m)
199 }
200 wg.Wait()
201}
202
203func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
204 sess, ok := sessions.Get(name)
205 if !ok {
206 return nil, fmt.Errorf("mcp '%s' not available", name)
207 }
208
209 cfg := config.Get()
210 m := cfg.MCP[name]
211 state, _ := states.Get(name)
212
213 timeout := mcpTimeout(m)
214 pingCtx, cancel := context.WithTimeout(ctx, timeout)
215 defer cancel()
216 err := sess.Ping(pingCtx, nil)
217 if err == nil {
218 return sess, nil
219 }
220 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
221
222 sess, err = createSession(ctx, name, m, cfg.Resolver())
223 if err != nil {
224 return nil, err
225 }
226
227 updateState(name, StateConnected, nil, sess, state.Counts)
228 sessions.Set(name, sess)
229 return sess, nil
230}
231
232// updateState updates the state of an MCP client and publishes an event
233func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
234 info := ClientInfo{
235 Name: name,
236 State: state,
237 Error: err,
238 Client: client,
239 Counts: counts,
240 }
241 switch state {
242 case StateConnected:
243 info.ConnectedAt = time.Now()
244 case StateError:
245 sessions.Del(name)
246 }
247 states.Set(name, info)
248
249 // Publish state change event
250 broker.Publish(pubsub.UpdatedEvent, Event{
251 Type: EventStateChanged,
252 Name: name,
253 State: state,
254 Error: err,
255 Counts: counts,
256 })
257}
258
259func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
260 timeout := mcpTimeout(m)
261 mcpCtx, cancel := context.WithCancel(ctx)
262 cancelTimer := time.AfterFunc(timeout, cancel)
263
264 transport, err := createTransport(mcpCtx, m, resolver)
265 if err != nil {
266 updateState(name, StateError, err, nil, Counts{})
267 slog.Error("error creating mcp client", "error", err, "name", name)
268 cancel()
269 cancelTimer.Stop()
270 return nil, err
271 }
272
273 client := mcp.NewClient(
274 &mcp.Implementation{
275 Name: "crush",
276 Version: version.Version,
277 Title: "Crush",
278 },
279 &mcp.ClientOptions{
280 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
281 broker.Publish(pubsub.UpdatedEvent, Event{
282 Type: EventToolsListChanged,
283 Name: name,
284 })
285 },
286 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
287 broker.Publish(pubsub.UpdatedEvent, Event{
288 Type: EventPromptsListChanged,
289 Name: name,
290 })
291 },
292 LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
293 slog.Info("MCP log", "name", name, "data", req.Params.Data)
294 },
295 },
296 )
297
298 session, err := client.Connect(mcpCtx, transport, nil)
299 if err != nil {
300 err = maybeStdioErr(err, transport)
301 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
302 slog.Error("MCP client failed to initialize", "error", err, "name", name)
303 cancel()
304 cancelTimer.Stop()
305 return nil, err
306 }
307
308 cancelTimer.Stop()
309 slog.Info("MCP client initialized", "name", name)
310 return session, nil
311}
312
313// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
314// to parse, and the cli will then close it, causing the EOF error.
315// so, if we got an EOF err, and the transport is STDIO, we try to exec it
316// again with a timeout and collect the output so we can add details to the
317// error.
318// this happens particularly when starting things with npx, e.g. if node can't
319// be found or some other error like that.
320func maybeStdioErr(err error, transport mcp.Transport) error {
321 if !errors.Is(err, io.EOF) {
322 return err
323 }
324 ct, ok := transport.(*mcp.CommandTransport)
325 if !ok {
326 return err
327 }
328 if err2 := stdioCheck(ct.Command); err2 != nil {
329 err = errors.Join(err, err2)
330 }
331 return err
332}
333
// maybeTimeoutErr turns a context cancellation or deadline error into a
// "timed out after <timeout>" error so the user can see which limit was hit.
// Any other error, including nil, is returned unchanged.
//
// Both context errors must be handled. createSession enforces its timeout by
// cancelling a context (context.Canceled). getOrRenewClient's ping uses
// context.WithTimeout (context.DeadlineExceeded), which previously slipped
// through unconverted.
//
// NOTE(review): a cancellation of the caller's parent context is also
// reported as a timeout.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
340
341func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
342 switch m.Type {
343 case config.MCPStdio:
344 command, err := resolver.ResolveValue(m.Command)
345 if err != nil {
346 return nil, fmt.Errorf("invalid mcp command: %w", err)
347 }
348 if strings.TrimSpace(command) == "" {
349 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
350 }
351 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
352 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
353 return &mcp.CommandTransport{
354 Command: cmd,
355 }, nil
356 case config.MCPHttp:
357 if strings.TrimSpace(m.URL) == "" {
358 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
359 }
360 client := &http.Client{
361 Transport: &headerRoundTripper{
362 headers: m.ResolvedHeaders(),
363 },
364 }
365 return &mcp.StreamableClientTransport{
366 Endpoint: m.URL,
367 HTTPClient: client,
368 }, nil
369 case config.MCPSSE:
370 if strings.TrimSpace(m.URL) == "" {
371 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
372 }
373 client := &http.Client{
374 Transport: &headerRoundTripper{
375 headers: m.ResolvedHeaders(),
376 },
377 }
378 return &mcp.SSEClientTransport{
379 Endpoint: m.URL,
380 HTTPClient: client,
381 }, nil
382 default:
383 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
384 }
385}
386
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request and then sends it through http.DefaultTransport. Each
// configured header replaces any existing values with the same name.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers applied.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone rather than on req itself. The old code
// mutated req.Header in place, which is also unsafe if the caller reuses the
// request.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	clone := req.Clone(req.Context())
	if clone.Header == nil {
		// Clone keeps a nil Header nil; Set on a nil map would panic.
		clone.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		clone.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(clone)
}
397
398func mcpTimeout(m config.MCPConfig) time.Duration {
399 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
400}
401
// stdioCheck re-runs the command described by old, limited to 5 seconds, to
// capture what it prints when it fails.
//
// It returns nil if the command succeeds, or if it is still running when the
// 5 second limit hits (it did not fail fast). Otherwise it returns the run
// error annotated with the command's combined stdout/stderr output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// old.Args[0] is argv[0] (the program name), not a real argument.
	// Passing the whole slice again, as the previous code did, shifted every
	// argument by one and ran a different command line than the original.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}