1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "os/exec"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/charmbracelet/crush/internal/config"
21 "github.com/charmbracelet/crush/internal/csync"
22 "github.com/charmbracelet/crush/internal/home"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/version"
26 "github.com/modelcontextprotocol/go-sdk/mcp"
27)
28
29var (
30 sessions = csync.NewMap[string, *mcp.ClientSession]()
31 states = csync.NewMap[string, ClientInfo]()
32 broker = pubsub.NewBroker[Event]()
33)
34
// State is the lifecycle state of a single MCP client. Its zero value is
// StateDisabled.
type State int

const (
	// StateDisabled marks an MCP that is disabled in the configuration.
	StateDisabled State = iota
	// StateStarting marks an MCP whose connection attempt is in progress.
	StateStarting
	// StateConnected marks an MCP with an established session.
	StateConnected
	// StateError marks an MCP whose connection or health check failed.
	StateError
)

// String returns the lower-case name of s, or "unknown" for any value that is
// not one of the declared State constants.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
59
// EventType identifies which kind of change an Event describes.
type EventType uint

const (
	// EventStateChanged is published by updateState every time a client's
	// State is recorded; State, Error and Counts are set on the event.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a connected server sends a
	// tool-list-changed notification; only Type and Name are set.
	EventToolsListChanged
	// EventPromptsListChanged is published when a connected server sends a
	// prompt-list-changed notification; only Type and Name are set.
	EventPromptsListChanged
)
68
// Event is the payload published on the package's event broker and received
// through SubscribeEvents.
type Event struct {
	// Type says which kind of change this event reports.
	Type EventType
	// Name is the MCP's name, i.e. its key in the configuration's MCP map.
	Name string
	// State, Error and Counts mirror the values recorded by updateState; in
	// this file they are only populated for EventStateChanged events.
	State  State
	Error  error
	Counts Counts
}
77
// Counts holds how many tools and prompts an MCP server exposed, as last
// recorded for that client (set from the listed tools/prompts on connect).
type Counts struct {
	Tools   int
	Prompts int
}
83
// ClientInfo is a snapshot of one MCP client's state as stored by updateState
// and returned by GetState and GetStates.
type ClientInfo struct {
	// Name is the MCP's configuration key.
	Name  string
	State State
	// Error is the failure recorded alongside State (nil when none was given).
	Error error
	// Client is the live session; the callers in this file only pass a
	// non-nil session together with StateConnected.
	Client *mcp.ClientSession
	Counts Counts
	// ConnectedAt is the time the state was recorded when State is
	// StateConnected, and the zero time otherwise.
	ConnectedAt time.Time
}
93
94// SubscribeEvents returns a channel for MCP events
95func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
96 return broker.Subscribe(ctx)
97}
98
99// GetStates returns the current state of all MCP clients
100func GetStates() map[string]ClientInfo {
101 return maps.Collect(states.Seq2())
102}
103
104// GetState returns the state of a specific MCP client
105func GetState(name string) (ClientInfo, bool) {
106 return states.Get(name)
107}
108
109// Close closes all MCP clients. This should be called during application shutdown.
110func Close() error {
111 var errs []error
112 var wg sync.WaitGroup
113 for name, session := range sessions.Seq2() {
114 wg.Go(func() {
115 done := make(chan bool, 1)
116 go func() {
117 if err := session.Close(); err != nil &&
118 !errors.Is(err, io.EOF) &&
119 !errors.Is(err, context.Canceled) &&
120 err.Error() != "signal: killed" {
121 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
122 }
123 done <- true
124 }()
125 select {
126 case <-done:
127 case <-time.After(time.Millisecond * 250):
128 }
129 })
130 }
131 wg.Wait()
132 broker.Shutdown()
133 return errors.Join(errs...)
134}
135
136// Initialize initializes MCP clients based on the provided configuration.
137func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
138 var wg sync.WaitGroup
139 // Initialize states for all configured MCPs
140 for name, m := range cfg.MCP {
141 if m.Disabled {
142 updateState(name, StateDisabled, nil, nil, Counts{})
143 slog.Debug("skipping disabled mcp", "name", name)
144 continue
145 }
146
147 // Set initial starting state
148 updateState(name, StateStarting, nil, nil, Counts{})
149
150 wg.Add(1)
151 go func(name string, m config.MCPConfig) {
152 defer func() {
153 wg.Done()
154 if r := recover(); r != nil {
155 var err error
156 switch v := r.(type) {
157 case error:
158 err = v
159 case string:
160 err = fmt.Errorf("panic: %s", v)
161 default:
162 err = fmt.Errorf("panic: %v", v)
163 }
164 updateState(name, StateError, err, nil, Counts{})
165 slog.Error("panic in mcp client initialization", "error", err, "name", name)
166 }
167 }()
168
169 // createSession handles its own timeout internally.
170 session, err := createSession(ctx, name, m, cfg.Resolver())
171 if err != nil {
172 return
173 }
174
175 tools, err := getTools(ctx, session)
176 if err != nil {
177 slog.Error("error listing tools", "error", err)
178 updateState(name, StateError, err, nil, Counts{})
179 session.Close()
180 return
181 }
182
183 prompts, err := getPrompts(ctx, session)
184 if err != nil {
185 slog.Error("error listing prompts", "error", err)
186 updateState(name, StateError, err, nil, Counts{})
187 session.Close()
188 return
189 }
190
191 updateTools(name, tools)
192 updatePrompts(name, prompts)
193 sessions.Set(name, session)
194
195 updateState(name, StateConnected, nil, session, Counts{
196 Tools: len(tools),
197 Prompts: len(prompts),
198 })
199 }(name, m)
200 }
201 wg.Wait()
202}
203
204func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
205 sess, ok := sessions.Get(name)
206 if !ok {
207 return nil, fmt.Errorf("mcp '%s' not available", name)
208 }
209
210 cfg := config.Get()
211 m := cfg.MCP[name]
212 state, _ := states.Get(name)
213
214 timeout := mcpTimeout(m)
215 pingCtx, cancel := context.WithTimeout(ctx, timeout)
216 defer cancel()
217 err := sess.Ping(pingCtx, nil)
218 if err == nil {
219 return sess, nil
220 }
221 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
222
223 sess, err = createSession(ctx, name, m, cfg.Resolver())
224 if err != nil {
225 return nil, err
226 }
227
228 updateState(name, StateConnected, nil, sess, state.Counts)
229 sessions.Set(name, sess)
230 return sess, nil
231}
232
233// updateState updates the state of an MCP client and publishes an event
234func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
235 info := ClientInfo{
236 Name: name,
237 State: state,
238 Error: err,
239 Client: client,
240 Counts: counts,
241 }
242 switch state {
243 case StateConnected:
244 info.ConnectedAt = time.Now()
245 case StateError:
246 sessions.Del(name)
247 }
248 states.Set(name, info)
249
250 // Publish state change event
251 broker.Publish(pubsub.UpdatedEvent, Event{
252 Type: EventStateChanged,
253 Name: name,
254 State: state,
255 Error: err,
256 Counts: counts,
257 })
258}
259
260func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
261 timeout := mcpTimeout(m)
262 mcpCtx, cancel := context.WithCancel(ctx)
263 cancelTimer := time.AfterFunc(timeout, cancel)
264
265 transport, err := createTransport(mcpCtx, m, resolver)
266 if err != nil {
267 updateState(name, StateError, err, nil, Counts{})
268 slog.Error("error creating mcp client", "error", err, "name", name)
269 cancel()
270 cancelTimer.Stop()
271 return nil, err
272 }
273
274 client := mcp.NewClient(
275 &mcp.Implementation{
276 Name: "crush",
277 Version: version.Version,
278 Title: "Crush",
279 },
280 &mcp.ClientOptions{
281 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
282 broker.Publish(pubsub.UpdatedEvent, Event{
283 Type: EventToolsListChanged,
284 Name: name,
285 })
286 },
287 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
288 broker.Publish(pubsub.UpdatedEvent, Event{
289 Type: EventPromptsListChanged,
290 Name: name,
291 })
292 },
293 LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
294 slog.Info("MCP log", "name", name, "data", req.Params.Data)
295 },
296 },
297 )
298
299 session, err := client.Connect(mcpCtx, transport, nil)
300 if err != nil {
301 err = maybeStdioErr(err, transport)
302 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
303 slog.Error("MCP client failed to initialize", "error", err, "name", name)
304 cancel()
305 cancelTimer.Stop()
306 return nil, err
307 }
308
309 cancelTimer.Stop()
310 slog.Info("MCP client initialized", "name", name)
311 return session, nil
312}
313
314// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
315// to parse, and the cli will then close it, causing the EOF error.
316// so, if we got an EOF err, and the transport is STDIO, we try to exec it
317// again with a timeout and collect the output so we can add details to the
318// error.
319// this happens particularly when starting things with npx, e.g. if node can't
320// be found or some other error like that.
321func maybeStdioErr(err error, transport mcp.Transport) error {
322 if !errors.Is(err, io.EOF) {
323 return err
324 }
325 ct, ok := transport.(*mcp.CommandTransport)
326 if !ok {
327 return err
328 }
329 if err2 := stdioCheck(ct.Command); err2 != nil {
330 err = errors.Join(err, err2)
331 }
332 return err
333}
334
// maybeTimeoutErr turns a context cancellation or deadline error into a
// readable "timed out after <timeout>" error, and returns any other error
// (including nil) unchanged.
//
// Both sentinels are handled. createSession enforces its timeout by cancelling
// a context (context.Canceled), while pings are bounded with
// context.WithTimeout, whose expiry surfaces as context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
341
342func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
343 switch m.Type {
344 case config.MCPStdio:
345 command, err := resolver.ResolveValue(m.Command)
346 if err != nil {
347 return nil, fmt.Errorf("invalid mcp command: %w", err)
348 }
349 if strings.TrimSpace(command) == "" {
350 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
351 }
352 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
353 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
354 return &mcp.CommandTransport{
355 Command: cmd,
356 }, nil
357 case config.MCPHttp:
358 if strings.TrimSpace(m.URL) == "" {
359 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
360 }
361 client := &http.Client{
362 Transport: &headerRoundTripper{
363 headers: m.ResolvedHeaders(),
364 },
365 }
366 return &mcp.StreamableClientTransport{
367 Endpoint: m.URL,
368 HTTPClient: client,
369 }, nil
370 case config.MCPSSE:
371 if strings.TrimSpace(m.URL) == "" {
372 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
373 }
374 client := &http.Client{
375 Transport: &headerRoundTripper{
376 headers: m.ResolvedHeaders(),
377 },
378 }
379 return &mcp.SSEClientTransport{
380 Endpoint: m.URL,
381 HTTPClient: client,
382 }, nil
383 default:
384 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
385 }
386}
387
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request and delegates to http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header name to value; each entry replaces any existing
	// value of that header on the outgoing request.
	headers map[string]string
}

// RoundTrip sends req with the configured headers added.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone (whose Header map is a deep copy) instead.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			// Setting a header on a nil map would panic.
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
398
399func mcpTimeout(m config.MCPConfig) time.Duration {
400 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
401}
402
// stdioCheck re-runs the command of a failed stdio MCP server to capture what
// it prints, so that output can be attached to the connection error.
//
// A started exec.Cmd cannot be reused, so a new command is built with the
// same path, argv, environment and working directory as old, bounded by a
// 5 second timeout. It returns nil when old is nil, when the command
// succeeds, or when it is still running at the timeout (presumably a server
// that started and is waiting for input). Otherwise it returns the run error
// with the trimmed combined stdout/stderr appended when there is any.
func stdioCheck(old *exec.Cmd) error {
	if old == nil {
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args already starts with argv[0] (the command name), so copy it
	// verbatim. Passing it as extra arguments would repeat argv[0] as the
	// first real argument and run a different command line.
	if len(old.Args) > 0 {
		cmd.Args = old.Args
	}
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	output := strings.TrimSpace(string(out))
	if output == "" {
		return err
	}
	return fmt.Errorf("%w: %s", err, output)
}