1// Package mcp provides functionality for managing Model Context Protocol (MCP)
2// clients within the Crush application.
3package mcp
4
5import (
6 "cmp"
7 "context"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "os"
14 "os/exec"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/home"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/version"
25 "github.com/modelcontextprotocol/go-sdk/mcp"
26)
27
28var (
29 sessions = csync.NewMap[string, *mcp.ClientSession]()
30 states = csync.NewMap[string, ClientInfo]()
31 broker = pubsub.NewBroker[Event]()
32)
33
// State describes where an MCP client currently is in its lifecycle.
type State int

// The lifecycle states of an MCP client. StateDisabled is the zero value.
const (
	StateDisabled State = iota
	StateStarting
	StateConnected
	StateError
)

// String returns the lowercase name of the state, or "unknown" for any value
// that is not one of the declared constants.
func (s State) String() string {
	labels := [...]string{
		StateDisabled:  "disabled",
		StateStarting:  "starting",
		StateConnected: "connected",
		StateError:     "error",
	}
	if s < 0 || int(s) >= len(labels) {
		return "unknown"
	}
	return labels[s]
}
58
// EventType identifies what kind of change an Event reports.
type EventType uint

const (
	// EventStateChanged reports that an MCP client's state was updated.
	EventStateChanged EventType = iota
	// EventToolsListChanged reports that a server announced a change to its
	// tool list.
	EventToolsListChanged
	// EventPromptsListChanged reports that a server announced a change to
	// its prompt list.
	EventPromptsListChanged
)
67
68// Event represents an event in the MCP system
69type Event struct {
70 Type EventType
71 Name string
72 State State
73 Error error
74 Counts Counts
75}
76
// Counts tallies what an MCP server currently offers.
type Counts struct {
	// Tools is the number of available tools.
	Tools int
	// Prompts is the number of available prompts.
	Prompts int
}
82
83// ClientInfo holds information about an MCP client's state
84type ClientInfo struct {
85 Name string
86 State State
87 Error error
88 Client *mcp.ClientSession
89 Counts Counts
90 ConnectedAt time.Time
91}
92
93// SubscribeEvents returns a channel for MCP events
94func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
95 return broker.Subscribe(ctx)
96}
97
98// GetStates returns the current state of all MCP clients
99func GetStates() map[string]ClientInfo {
100 return states.Copy()
101}
102
103// GetState returns the state of a specific MCP client
104func GetState(name string) (ClientInfo, bool) {
105 return states.Get(name)
106}
107
108// Close closes all MCP clients. This should be called during application shutdown.
109func Close() error {
110 var wg sync.WaitGroup
111 done := make(chan struct{}, 1)
112 go func() {
113 for name, session := range sessions.Seq2() {
114 wg.Go(func() {
115 if err := session.Close(); err != nil &&
116 !errors.Is(err, io.EOF) &&
117 !errors.Is(err, context.Canceled) &&
118 err.Error() != "signal: killed" {
119 slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
120 }
121 })
122 }
123 wg.Wait()
124 done <- struct{}{}
125 }()
126 select {
127 case <-done:
128 case <-time.After(5 * time.Second):
129 }
130 broker.Shutdown()
131 return nil
132}
133
134// Initialize initializes MCP clients based on the provided configuration.
135func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
136 var wg sync.WaitGroup
137 // Initialize states for all configured MCPs
138 for name, m := range cfg.MCP {
139 if m.Disabled {
140 updateState(name, StateDisabled, nil, nil, Counts{})
141 slog.Debug("skipping disabled mcp", "name", name)
142 continue
143 }
144
145 // Set initial starting state
146 updateState(name, StateStarting, nil, nil, Counts{})
147
148 wg.Add(1)
149 go func(name string, m config.MCPConfig) {
150 defer func() {
151 wg.Done()
152 if r := recover(); r != nil {
153 var err error
154 switch v := r.(type) {
155 case error:
156 err = v
157 case string:
158 err = fmt.Errorf("panic: %s", v)
159 default:
160 err = fmt.Errorf("panic: %v", v)
161 }
162 updateState(name, StateError, err, nil, Counts{})
163 slog.Error("panic in mcp client initialization", "error", err, "name", name)
164 }
165 }()
166
167 // createSession handles its own timeout internally.
168 session, err := createSession(ctx, name, m, cfg.Resolver())
169 if err != nil {
170 return
171 }
172
173 tools, err := getTools(ctx, session)
174 if err != nil {
175 slog.Error("error listing tools", "error", err)
176 updateState(name, StateError, err, nil, Counts{})
177 session.Close()
178 return
179 }
180
181 prompts, err := getPrompts(ctx, session)
182 if err != nil {
183 slog.Error("error listing prompts", "error", err)
184 updateState(name, StateError, err, nil, Counts{})
185 session.Close()
186 return
187 }
188
189 toolCount := updateTools(name, tools)
190 updatePrompts(name, prompts)
191 sessions.Set(name, session)
192
193 updateState(name, StateConnected, nil, session, Counts{
194 Tools: toolCount,
195 Prompts: len(prompts),
196 })
197 }(name, m)
198 }
199 wg.Wait()
200}
201
202func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
203 sess, ok := sessions.Get(name)
204 if !ok {
205 return nil, fmt.Errorf("mcp '%s' not available", name)
206 }
207
208 cfg := config.Get()
209 m := cfg.MCP[name]
210 state, _ := states.Get(name)
211
212 timeout := mcpTimeout(m)
213 pingCtx, cancel := context.WithTimeout(ctx, timeout)
214 defer cancel()
215 err := sess.Ping(pingCtx, nil)
216 if err == nil {
217 return sess, nil
218 }
219 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
220
221 sess, err = createSession(ctx, name, m, cfg.Resolver())
222 if err != nil {
223 return nil, err
224 }
225
226 updateState(name, StateConnected, nil, sess, state.Counts)
227 sessions.Set(name, sess)
228 return sess, nil
229}
230
231// updateState updates the state of an MCP client and publishes an event
232func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
233 info := ClientInfo{
234 Name: name,
235 State: state,
236 Error: err,
237 Client: client,
238 Counts: counts,
239 }
240 switch state {
241 case StateConnected:
242 info.ConnectedAt = time.Now()
243 case StateError:
244 sessions.Del(name)
245 }
246 states.Set(name, info)
247
248 // Publish state change event
249 broker.Publish(pubsub.UpdatedEvent, Event{
250 Type: EventStateChanged,
251 Name: name,
252 State: state,
253 Error: err,
254 Counts: counts,
255 })
256}
257
258func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
259 timeout := mcpTimeout(m)
260 mcpCtx, cancel := context.WithCancel(ctx)
261 cancelTimer := time.AfterFunc(timeout, cancel)
262
263 transport, err := createTransport(mcpCtx, m, resolver)
264 if err != nil {
265 updateState(name, StateError, err, nil, Counts{})
266 slog.Error("error creating mcp client", "error", err, "name", name)
267 cancel()
268 cancelTimer.Stop()
269 return nil, err
270 }
271
272 client := mcp.NewClient(
273 &mcp.Implementation{
274 Name: "crush",
275 Version: version.Version,
276 Title: "Crush",
277 },
278 &mcp.ClientOptions{
279 ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
280 broker.Publish(pubsub.UpdatedEvent, Event{
281 Type: EventToolsListChanged,
282 Name: name,
283 })
284 },
285 PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
286 broker.Publish(pubsub.UpdatedEvent, Event{
287 Type: EventPromptsListChanged,
288 Name: name,
289 })
290 },
291 LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
292 slog.Info("MCP log", "name", name, "data", req.Params.Data)
293 },
294 },
295 )
296
297 session, err := client.Connect(mcpCtx, transport, nil)
298 if err != nil {
299 err = maybeStdioErr(err, transport)
300 updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
301 slog.Error("MCP client failed to initialize", "error", err, "name", name)
302 cancel()
303 cancelTimer.Stop()
304 return nil, err
305 }
306
307 cancelTimer.Stop()
308 slog.Info("MCP client initialized", "name", name)
309 return session, nil
310}
311
312// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
313// to parse, and the cli will then close it, causing the EOF error.
314// so, if we got an EOF err, and the transport is STDIO, we try to exec it
315// again with a timeout and collect the output so we can add details to the
316// error.
317// this happens particularly when starting things with npx, e.g. if node can't
318// be found or some other error like that.
319func maybeStdioErr(err error, transport mcp.Transport) error {
320 if !errors.Is(err, io.EOF) {
321 return err
322 }
323 ct, ok := transport.(*mcp.CommandTransport)
324 if !ok {
325 return err
326 }
327 if err2 := stdioCheck(ct.Command); err2 != nil {
328 err = errors.Join(err, err2)
329 }
330 return err
331}
332
// maybeTimeoutErr replaces a context error with a readable "timed out after
// <timeout>" error and returns every other error (including nil) unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are recognised: session
// creation enforces its timeout by cancelling a context from a timer, while
// the ping in getOrRenewClient uses context.WithTimeout and therefore fails
// with a deadline error instead.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
339
340func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
341 switch m.Type {
342 case config.MCPStdio:
343 command, err := resolver.ResolveValue(m.Command)
344 if err != nil {
345 return nil, fmt.Errorf("invalid mcp command: %w", err)
346 }
347 if strings.TrimSpace(command) == "" {
348 return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
349 }
350 cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
351 cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
352 return &mcp.CommandTransport{
353 Command: cmd,
354 }, nil
355 case config.MCPHttp:
356 if strings.TrimSpace(m.URL) == "" {
357 return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
358 }
359 client := &http.Client{
360 Transport: &headerRoundTripper{
361 headers: m.ResolvedHeaders(),
362 },
363 }
364 return &mcp.StreamableClientTransport{
365 Endpoint: m.URL,
366 HTTPClient: client,
367 }, nil
368 case config.MCPSSE:
369 if strings.TrimSpace(m.URL) == "" {
370 return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
371 }
372 client := &http.Client{
373 Transport: &headerRoundTripper{
374 headers: m.ResolvedHeaders(),
375 },
376 }
377 return &mcp.SSEClientTransport{
378 Endpoint: m.URL,
379 HTTPClient: client,
380 }, nil
381 default:
382 return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
383 }
384}
385
// headerRoundTripper is an http.RoundTripper that applies a fixed set of
// headers to every outgoing request and then delegates to
// http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to the values set on each request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, replacing any existing values for those header names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are applied to a clone and req itself is left untouched. When
// no headers are configured the request is forwarded as is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		// Cloning a request without headers yields a nil map, which cannot
		// be written to.
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
396
397func mcpTimeout(m config.MCPConfig) time.Duration {
398 return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
399}
400
// stdioCheck re-runs the command described by old to find out why it failed,
// using the same executable, arguments and environment.
//
// The run is limited to five seconds. It returns nil when the command exits
// successfully or is still running when the limit is reached (in both cases
// there is nothing useful to report), and otherwise an error that wraps the
// command's failure and carries its combined stdout/stderr output. A nil old
// yields nil.
func stdioCheck(old *exec.Cmd) error {
	if old == nil {
		return nil
	}

	// Cmd.Args includes the command itself as Args[0], whereas
	// exec.CommandContext expects only the arguments that follow it. Passing
	// Args unchanged would run the command with its own name as the first
	// argument.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}