init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
 28func parseLevel(level mcp.LoggingLevel) slog.Level {
 29	switch level {
 30	case "info":
 31		return slog.LevelInfo
 32	case "notice":
 33		return slog.LevelInfo
 34	case "warning":
 35		return slog.LevelWarn
 36	default:
 37		return slog.LevelDebug
 38	}
 39}
 40
// ClientSession wraps an mcp.ClientSession with a context cancel function so
// that the context created during session establishment is properly cleaned up
// on close.
//
// The SDK session is embedded, so its methods are available directly on the
// wrapper. The cancel field holds the cancel function of the context the
// session was connected with; it is invoked by Close.
type ClientSession struct {
	*mcp.ClientSession
	cancel context.CancelFunc
}
 48
 49// Close cancels the session context and then closes the underlying session.
 50func (s *ClientSession) Close() error {
 51	s.cancel()
 52	return s.ClientSession.Close()
 53}
 54
// Package-level state shared by all MCP clients:
//
//   - sessions maps an MCP name to its currently registered session.
//   - states maps an MCP name to the most recently recorded ClientInfo.
//   - broker fans out Event values to subscribers (see SubscribeEvents).
//   - initOnce guards the closing of initDone so it happens at most once.
//   - initDone is closed by Initialize once every client goroutine has
//     returned; WaitForInit blocks on it.
var (
	sessions = csync.NewMap[string, *ClientSession]()
	states   = csync.NewMap[string, ClientInfo]()
	broker   = pubsub.NewBroker[Event]()
	initOnce sync.Once
	initDone = make(chan struct{})
)
 62
 63// State represents the current state of an MCP client
 64type State int
 65
 66const (
 67	StateDisabled State = iota
 68	StateStarting
 69	StateConnected
 70	StateError
 71)
 72
 73func (s State) String() string {
 74	switch s {
 75	case StateDisabled:
 76		return "disabled"
 77	case StateStarting:
 78		return "starting"
 79	case StateConnected:
 80		return "connected"
 81	case StateError:
 82		return "error"
 83	default:
 84		return "unknown"
 85	}
 86}
 87
// EventType represents the type of MCP event
type EventType uint

const (
	// EventStateChanged is published by updateState whenever a client's
	// state is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a server reports that its
	// tool list changed.
	EventToolsListChanged
	// EventPromptsListChanged is published when a server reports that its
	// prompt list changed.
	EventPromptsListChanged
	// EventResourcesListChanged is published when a server reports that its
	// resource list changed.
	EventResourcesListChanged
)
 97
// Event represents an event in the MCP system
//
// Type says what happened and Name identifies the MCP it concerns. State,
// Error and Counts are filled in for EventStateChanged events; the
// list-changed events published in this file set only Type and Name.
type Event struct {
	Type   EventType
	Name   string
	State  State
	Error  error
	Counts Counts
}
106
// Counts holds the number of tools, prompts and resources available from an
// MCP client.
type Counts struct {
	Tools     int
	Prompts   int
	Resources int
}
113
// ClientInfo holds information about an MCP client's state
//
// Name is the MCP's configuration key. State and Error describe the most
// recently recorded outcome. Client is the session supplied when the state
// was recorded (callers in this file pass nil unless the client is
// connected). Counts carries the tool/prompt/resource totals supplied with
// the state. ConnectedAt is set to the current time when the state recorded
// is StateConnected and is the zero time otherwise.
type ClientInfo struct {
	Name        string
	State       State
	Error       error
	Client      *ClientSession
	Counts      Counts
	ConnectedAt time.Time
}
123
124// SubscribeEvents returns a channel for MCP events
125func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
126	return broker.Subscribe(ctx)
127}
128
129// GetStates returns the current state of all MCP clients
130func GetStates() map[string]ClientInfo {
131	return states.Copy()
132}
133
134// GetState returns the state of a specific MCP client
135func GetState(name string) (ClientInfo, bool) {
136	return states.Get(name)
137}
138
139// Close closes all MCP clients. This should be called during application shutdown.
140func Close(ctx context.Context) error {
141	var wg sync.WaitGroup
142	for name, session := range sessions.Seq2() {
143		wg.Go(func() {
144			done := make(chan error, 1)
145			go func() {
146				done <- session.Close()
147			}()
148			select {
149			case err := <-done:
150				if err != nil &&
151					!errors.Is(err, io.EOF) &&
152					!errors.Is(err, context.Canceled) &&
153					err.Error() != "signal: killed" {
154					slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
155				}
156			case <-ctx.Done():
157			}
158		})
159	}
160	wg.Wait()
161	broker.Shutdown()
162	return nil
163}
164
165// Initialize initializes MCP clients based on the provided configuration.
166func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
167	slog.Info("Initializing MCP clients")
168	var wg sync.WaitGroup
169	// Initialize states for all configured MCPs
170	for name, m := range cfg.MCP {
171		if m.Disabled {
172			updateState(name, StateDisabled, nil, nil, Counts{})
173			slog.Debug("Skipping disabled MCP", "name", name)
174			continue
175		}
176
177		// Set initial starting state
178		updateState(name, StateStarting, nil, nil, Counts{})
179
180		wg.Add(1)
181		go func(name string, m config.MCPConfig) {
182			defer func() {
183				wg.Done()
184				if r := recover(); r != nil {
185					var err error
186					switch v := r.(type) {
187					case error:
188						err = v
189					case string:
190						err = fmt.Errorf("panic: %s", v)
191					default:
192						err = fmt.Errorf("panic: %v", v)
193					}
194					updateState(name, StateError, err, nil, Counts{})
195					slog.Error("Panic in MCP client initialization", "error", err, "name", name)
196				}
197			}()
198
199			// createSession handles its own timeout internally.
200			session, err := createSession(ctx, name, m, cfg.Resolver())
201			if err != nil {
202				return
203			}
204
205			tools, err := getTools(ctx, session)
206			if err != nil {
207				slog.Error("Error listing tools", "error", err)
208				updateState(name, StateError, err, nil, Counts{})
209				session.Close()
210				return
211			}
212
213			prompts, err := getPrompts(ctx, session)
214			if err != nil {
215				slog.Error("Error listing prompts", "error", err)
216				updateState(name, StateError, err, nil, Counts{})
217				session.Close()
218				return
219			}
220
221			resources, err := getResources(ctx, session)
222			if err != nil {
223				slog.Error("Error listing resources", "error", err)
224				updateState(name, StateError, err, nil, Counts{})
225				session.Close()
226				return
227			}
228
229			toolCount := updateTools(cfg, name, tools)
230			updatePrompts(name, prompts)
231			resourceCount := updateResources(name, resources)
232			sessions.Set(name, session)
233
234			updateState(name, StateConnected, nil, session, Counts{
235				Tools:     toolCount,
236				Prompts:   len(prompts),
237				Resources: resourceCount,
238			})
239		}(name, m)
240	}
241	wg.Wait()
242	initOnce.Do(func() { close(initDone) })
243}
244
245// WaitForInit blocks until MCP initialization is complete.
246// If Initialize was never called, this returns immediately.
247func WaitForInit(ctx context.Context) error {
248	select {
249	case <-initDone:
250		return nil
251	case <-ctx.Done():
252		return ctx.Err()
253	}
254}
255
256func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*ClientSession, error) {
257	sess, ok := sessions.Get(name)
258	if !ok {
259		return nil, fmt.Errorf("mcp '%s' not available", name)
260	}
261
262	m := cfg.MCP[name]
263	state, _ := states.Get(name)
264
265	timeout := mcpTimeout(m)
266	pingCtx, cancel := context.WithTimeout(ctx, timeout)
267	defer cancel()
268	err := sess.Ping(pingCtx, nil)
269	if err == nil {
270		return sess, nil
271	}
272	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
273
274	sess, err = createSession(ctx, name, m, cfg.Resolver())
275	if err != nil {
276		return nil, err
277	}
278
279	updateState(name, StateConnected, nil, sess, state.Counts)
280	sessions.Set(name, sess)
281	return sess, nil
282}
283
284// updateState updates the state of an MCP client and publishes an event
285func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
286	info := ClientInfo{
287		Name:   name,
288		State:  state,
289		Error:  err,
290		Client: client,
291		Counts: counts,
292	}
293	switch state {
294	case StateConnected:
295		info.ConnectedAt = time.Now()
296	case StateError:
297		sessions.Del(name)
298	}
299	states.Set(name, info)
300
301	// Publish state change event
302	broker.Publish(pubsub.UpdatedEvent, Event{
303		Type:   EventStateChanged,
304		Name:   name,
305		State:  state,
306		Error:  err,
307		Counts: counts,
308	})
309}
310
311func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
312	timeout := mcpTimeout(m)
313	mcpCtx, cancel := context.WithCancel(ctx)
314	cancelTimer := time.AfterFunc(timeout, cancel)
315
316	transport, err := createTransport(mcpCtx, m, resolver)
317	if err != nil {
318		updateState(name, StateError, err, nil, Counts{})
319		slog.Error("Error creating MCP client", "error", err, "name", name)
320		cancel()
321		cancelTimer.Stop()
322		return nil, err
323	}
324
325	client := mcp.NewClient(
326		&mcp.Implementation{
327			Name:    "crush",
328			Version: version.Version,
329			Title:   "Crush",
330		},
331		&mcp.ClientOptions{
332			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
333				broker.Publish(pubsub.UpdatedEvent, Event{
334					Type: EventToolsListChanged,
335					Name: name,
336				})
337			},
338			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
339				broker.Publish(pubsub.UpdatedEvent, Event{
340					Type: EventPromptsListChanged,
341					Name: name,
342				})
343			},
344			ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
345				broker.Publish(pubsub.UpdatedEvent, Event{
346					Type: EventResourcesListChanged,
347					Name: name,
348				})
349			},
350			LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
351				level := parseLevel(req.Params.Level)
352				slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
353			},
354		},
355	)
356
357	session, err := client.Connect(mcpCtx, transport, nil)
358	if err != nil {
359		err = maybeStdioErr(err, transport)
360		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
361		slog.Error("MCP client failed to initialize", "error", err, "name", name)
362		cancel()
363		cancelTimer.Stop()
364		return nil, err
365	}
366
367	cancelTimer.Stop()
368	slog.Debug("MCP client initialized", "name", name)
369	return &ClientSession{session, cancel}, nil
370}
371
372// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
373// to parse, and the cli will then close it, causing the EOF error.
374// so, if we got an EOF err, and the transport is STDIO, we try to exec it
375// again with a timeout and collect the output so we can add details to the
376// error.
377// this happens particularly when starting things with npx, e.g. if node can't
378// be found or some other error like that.
379func maybeStdioErr(err error, transport mcp.Transport) error {
380	if !errors.Is(err, io.EOF) {
381		return err
382	}
383	ct, ok := transport.(*mcp.CommandTransport)
384	if !ok {
385		return err
386	}
387	if err2 := stdioCheck(ct.Command); err2 != nil {
388		err = errors.Join(err, err2)
389	}
390	return err
391}
392
// maybeTimeoutErr replaces a context-expiry error with a readable
// "timed out after <timeout>" error and returns every other error (including
// nil) unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are recognized: session
// setup enforces its timeout by cancelling a context, while the health-check
// ping uses context.WithTimeout, whose expiry is reported as
// context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
399
400func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
401	switch m.Type {
402	case config.MCPStdio:
403		command, err := resolver.ResolveValue(m.Command)
404		if err != nil {
405			return nil, fmt.Errorf("invalid mcp command: %w", err)
406		}
407		if strings.TrimSpace(command) == "" {
408			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
409		}
410		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
411		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
412		return &mcp.CommandTransport{
413			Command: cmd,
414		}, nil
415	case config.MCPHttp:
416		if strings.TrimSpace(m.URL) == "" {
417			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
418		}
419		client := &http.Client{
420			Transport: &headerRoundTripper{
421				headers: m.ResolvedHeaders(),
422			},
423		}
424		return &mcp.StreamableClientTransport{
425			Endpoint:   m.URL,
426			HTTPClient: client,
427		}, nil
428	case config.MCPSSE:
429		if strings.TrimSpace(m.URL) == "" {
430			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
431		}
432		client := &http.Client{
433			Transport: &headerRoundTripper{
434				headers: m.ResolvedHeaders(),
435			},
436		}
437		return &mcp.SSEClientTransport{
438			Endpoint:   m.URL,
439			HTTPClient: client,
440		}, nil
441	default:
442		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
443	}
444}
445
// headerRoundTripper is an http.RoundTripper that sets a fixed collection of
// headers on every outgoing request and then delegates to
// http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to the values to set on each request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers applied, overriding any existing values for the same names.
//
// The headers are set on a clone of req: a RoundTripper must not modify the
// request it is given, and the caller may reuse or retry it.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
456
457func mcpTimeout(m config.MCPConfig) time.Duration {
458	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
459}
460
// stdioCheck re-runs the command described by old, with the same arguments,
// environment and working directory, for at most five seconds and reports
// what went wrong.
//
// It returns nil when the command exits successfully or is still running when
// the five seconds elapse. Otherwise it returns the run error wrapped
// together with the command's combined stdout and stderr.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// old.Args includes the command itself as Args[0], whereas CommandContext
	// expects only the arguments that follow it. Passing old.Args unchanged
	// would run "cmd cmd args...".
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}