1package mcp
2
3import (
4 "cmp"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "log/slog"
10 "maps"
11 "net/http"
12 "os"
13 "os/exec"
14 "strings"
15 "sync"
16 "time"
17
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/home"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/pubsub"
23 "github.com/charmbracelet/crush/internal/version"
24 "github.com/modelcontextprotocol/go-sdk/mcp"
25)
26
var (
	// sessions holds the live client session of each MCP, keyed by MCP name.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states holds the most recently recorded ClientInfo of each MCP, keyed
	// by MCP name.
	states = csync.NewMap[string, ClientInfo]()
	// broker delivers Events to the subscribers obtained through
	// SubscribeEvents.
	broker = pubsub.NewBroker[Event]()
)
32
// State describes where an MCP client currently is in its lifecycle.
type State int

// Lifecycle states of an MCP client. The zero value is StateDisabled.
const (
	StateDisabled State = iota
	StateStarting
	StateConnected
	StateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// outside the declared constants.
func (s State) String() string {
	names := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(names) {
		return "unknown"
	}
	return names[s]
}
57
// EventType represents the type of MCP event
type EventType string

const (
	// EventStateChanged is the type of the event published whenever the
	// recorded state of an MCP client is updated.
	EventStateChanged EventType = "state_changed"
	// EventToolsListChanged is published when a server notifies the client
	// that its tool list changed.
	EventToolsListChanged EventType = "tools_list_changed"
	// EventPromptsListChanged is published when a server notifies the client
	// that its prompt list changed.
	EventPromptsListChanged EventType = "prompts_list_changed"
)
66
// Event represents an event in the MCP system
//
// State-change events carry State, Error and Counts; the list-changed events
// published from the client handlers set only Type and Name.
type Event struct {
	// Type identifies what kind of event this is.
	Type EventType
	// Name is the name of the MCP the event refers to.
	Name string
	// State is the state recorded for the MCP in a state-change event.
	State State
	// Error is the error recorded with the state change, if any.
	Error error
	// Counts is the tool and prompt counts recorded with the state change.
	Counts Counts
}
75
// Counts holds the number of tools and prompts available from an MCP.
type Counts struct {
	// Tools is the number of tools listed for the MCP.
	Tools int
	// Prompts is the number of prompts listed for the MCP.
	Prompts int
}
81
// ClientInfo holds information about an MCP client's state
type ClientInfo struct {
	// Name is the name of the MCP this record belongs to.
	Name string
	// State is the most recently recorded lifecycle state.
	State State
	// Error is the error recorded together with State, if any.
	Error error
	// Client is the session passed along with the state update; it is nil
	// for the disabled, starting and error updates made in this file.
	Client *mcp.ClientSession
	// Counts is the tool and prompt counts recorded with the state.
	Counts Counts
	// ConnectedAt is set to the current time when the state is recorded as
	// StateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
91
92// SubscribeEvents returns a channel for MCP events
93func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
94 return broker.Subscribe(ctx)
95}
96
97// GetStates returns the current state of all MCP clients
98func GetStates() map[string]ClientInfo {
99 return maps.Collect(states.Seq2())
100}
101
102// GetState returns the state of a specific MCP client
103func GetState(name string) (ClientInfo, bool) {
104 return states.Get(name)
105}
106
107// Close closes all MCP clients. This should be called during application shutdown.
108func Close() error {
109 var errs []error
110 for name, c := range sessions.Seq2() {
111 if err := c.Close(); err != nil &&
112 !errors.Is(err, io.EOF) &&
113 !errors.Is(err, context.Canceled) &&
114 err.Error() != "signal: killed" {
115 errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
116 }
117 }
118 broker.Shutdown()
119 return errors.Join(errs...)
120}
121
122// Initialize initializes MCP clients based on the provided configuration.
123func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
124 var wg sync.WaitGroup
125 // Initialize states for all configured MCPs
126 for name, m := range cfg.MCP {
127 if m.Disabled {
128 updateState(name, StateDisabled, nil, nil, Counts{})
129 slog.Debug("skipping disabled mcp", "name", name)
130 continue
131 }
132
133 // Set initial starting state
134 updateState(name, StateStarting, nil, nil, Counts{})
135
136 wg.Add(1)
137 go func(name string, m config.MCPConfig) {
138 defer func() {
139 wg.Done()
140 if r := recover(); r != nil {
141 var err error
142 switch v := r.(type) {
143 case error:
144 err = v
145 case string:
146 err = fmt.Errorf("panic: %s", v)
147 default:
148 err = fmt.Errorf("panic: %v", v)
149 }
150 updateState(name, StateError, err, nil, Counts{})
151 slog.Error("panic in mcp client initialization", "error", err, "name", name)
152 }
153 }()
154
155 ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
156 defer cancel()
157
158 session, err := createSession(ctx, name, m, cfg.Resolver())
159 if err != nil {
160 return
161 }
162
163 tools, err := getTools(ctx, name, permissions, session, cfg.WorkingDir())
164 if err != nil {
165 slog.Error("error listing tools", "error", err)
166 updateState(name, StateError, err, nil, Counts{})
167 session.Close()
168 return
169 }
170
171 prompts, err := getPrompts(ctx, session)
172 if err != nil {
173 slog.Error("error listing prompts", "error", err)
174 updateState(name, StateError, err, nil, Counts{})
175 session.Close()
176 return
177 }
178
179 updateTools(name, tools)
180 updatePrompts(name, prompts)
181 sessions.Set(name, session)
182
183 updateState(name, StateConnected, nil, session, Counts{
184 Tools: len(tools),
185 Prompts: len(prompts),
186 })
187 }(name, m)
188 }
189 wg.Wait()
190}
191
192func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
193 sess, ok := sessions.Get(name)
194 if !ok {
195 return nil, fmt.Errorf("mcp '%s' not available", name)
196 }
197
198 cfg := config.Get()
199 m := cfg.MCP[name]
200 state, _ := states.Get(name)
201
202 timeout := mcpTimeout(m)
203 pingCtx, cancel := context.WithTimeout(ctx, timeout)
204 defer cancel()
205 err := sess.Ping(pingCtx, nil)
206 if err == nil {
207 return sess, nil
208 }
209 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
210
211 sess, err = createSession(ctx, name, m, cfg.Resolver())
212 if err != nil {
213 return nil, err
214 }
215
216 updateState(name, StateConnected, nil, sess, state.Counts)
217 sessions.Set(name, sess)
218 return sess, nil
219}
220
221// updateState updates the state of an MCP client and publishes an event
222func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
223 info := ClientInfo{
224 Name: name,
225 State: state,
226 Error: err,
227 Client: client,
228 Counts: counts,
229 }
230 switch state {
231 case StateConnected:
232 info.ConnectedAt = time.Now()
233 case StateError:
234 updateTools(name, nil)
235 sessions.Del(name)
236 }
237 states.Set(name, info)
238
239 // Publish state change event
240 broker.Publish(pubsub.UpdatedEvent, Event{
241 Type: EventStateChanged,
242 Name: name,
243 State: state,
244 Error: err,
245 Counts: counts,
246 })
247}
248
249func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
250 timeout := mcpTimeout(m)
251 mcpCtx, cancel := context.WithCancel(ctx)
252 cancelTimer := time.AfterFunc(timeout, cancel)
253
254 transport, err := createTransport(mcpCtx, m, resolver)
255 if err != nil {
256 updateState(name, StateError, err, nil, Counts{})
257 slog.Error("error creating mcp client", "error", err, "name", name)
258 cancel()
259 cancelTimer.Stop()
260 return nil, err
261 }
262
263 client := mcp.NewClient(
264 &mcp.Implementation{
265 Name: "crush",
266 Version: version.Version,
267 Title: "Crush",
268 },
269 &mcp.ClientOptions{
270 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
271 broker.Publish(pubsub.UpdatedEvent, Event{
272 Type: EventToolsListChanged,
273 Name: name,
274 })
275 },
276 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
277 broker.Publish(pubsub.UpdatedEvent, Event{
278 Type: EventPromptsListChanged,
279 Name: name,
280 })
281 },
282 KeepAlive: time.Minute * 10,
283 },
284 )
285
286 session, err := client.Connect(mcpCtx, transport, nil)
287 if err != nil {
288 err = maybeStdioErr(err, transport)
289 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
290 slog.Error("error starting mcp client", "error", err, "name", name)
291 cancel()
292 cancelTimer.Stop()
293 return nil, err
294 }
295
296 cancelTimer.Stop()
297 slog.Info("Initialized mcp client", "name", name)
298 return session, nil
299}
300
301// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
302// to parse, and the cli will then close it, causing the EOF error.
303// so, if we got an EOF err, and the transport is STDIO, we try to exec it
304// again with a timeout and collect the output so we can add details to the
305// error.
306// this happens particularly when starting things with npx, e.g. if node can't
307// be found or some other error like that.
308func maybeStdioErr(err error, transport mcp.Transport) error {
309 if !errors.Is(err, io.EOF) {
310 return err
311 }
312 ct, ok := transport.(*mcp.CommandTransport)
313 if !ok {
314 return err
315 }
316 if err2 := stdioCheck(ct.Command); err2 != nil {
317 err = errors.Join(err, err2)
318 }
319 return err
320}
321
// maybeTimeoutErr replaces a context cancellation or deadline error with a
// message stating the timeout that was in effect.
//
// Timeouts in this package show up in two forms: context.Canceled when a
// timer cancels the context, and context.DeadlineExceeded when the context
// was created with a deadline. Both are reported as
// "timed out after <timeout>". Any other error, including nil, is returned
// unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
328
329func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
330 switch m.Type {
331 case config.MCPStdio:
332 command, err := resolver.ResolveValue(m.Command)
333 if err != nil {
334 return nil, fmt.Errorf("invalid mcp command: %w", err)
335 }
336 if strings.TrimSpace(command) == "" {
337 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
338 }
339 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
340 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
341 return &mcp.CommandTransport{
342 Command: cmd,
343 }, nil
344 case config.MCPHttp:
345 if strings.TrimSpace(m.URL) == "" {
346 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
347 }
348 client := &http.Client{
349 Transport: &headerRoundTripper{
350 headers: m.ResolvedHeaders(),
351 },
352 }
353 return &mcp.StreamableClientTransport{
354 Endpoint: m.URL,
355 HTTPClient: client,
356 }, nil
357 case config.MCPSSE:
358 if strings.TrimSpace(m.URL) == "" {
359 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
360 }
361 client := &http.Client{
362 Transport: &headerRoundTripper{
363 headers: m.ResolvedHeaders(),
364 },
365 }
366 return &mcp.SSEClientTransport{
367 Endpoint: m.URL,
368 HTTPClient: client,
369 }, nil
370 default:
371 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
372 }
373}
374
// headerRoundTripper is an http.RoundTripper that sets a fixed collection of
// headers on every request before handing it to http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to the values set on each request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, replacing any existing values for those header names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone and req itself is left untouched.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			// Header.Set on a nil map would panic.
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
385
386func mcpTimeout(m config.MCPConfig) time.Duration {
387 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
388}
389
// stdioCheck runs the command described by old once more, for at most five
// seconds, and reports what it printed if it failed.
//
// The command is re-created from old's path, arguments, environment and
// working directory, because an exec.Cmd cannot be started twice. It returns
// nil when the command exits successfully or is still running when the five
// seconds elapse; otherwise it returns the run error followed by the
// combined stdout and stderr output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// old.Args holds the command itself as Args[0]; exec.CommandContext adds
	// the program name again, so only the real arguments are forwarded.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}

	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}