init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"syscall"
 18	"time"
 19
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/home"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/modelcontextprotocol/go-sdk/mcp"
 27)
 28
 29func parseLevel(level mcp.LoggingLevel) slog.Level {
 30	switch level {
 31	case "info":
 32		return slog.LevelInfo
 33	case "notice":
 34		return slog.LevelInfo
 35	case "warning":
 36		return slog.LevelWarn
 37	default:
 38		return slog.LevelDebug
 39	}
 40}
 41
// ClientSession wraps an mcp.ClientSession with a context cancel function so
// that the context created during session establishment is properly cleaned up
// on close.
//
// NOTE: createSession builds this type with an unkeyed composite literal, so
// the order of the fields below must not change without updating that call.
type ClientSession struct {
	// The embedded SDK session; its methods are promoted onto ClientSession.
	*mcp.ClientSession
	// cancel releases the context the session was connected with.
	cancel context.CancelFunc
}
 49
 50// Close cancels the session context and then closes the underlying session.
 51func (s *ClientSession) Close() error {
 52	s.cancel()
 53	return s.ClientSession.Close()
 54}
 55
// Package-level registries shared by every MCP client in this package:
//
//   - sessions: the registered session for each MCP, keyed by MCP name.
//   - states: the most recently recorded ClientInfo for each MCP name.
//   - broker: publishes Event values to SubscribeEvents subscribers.
//   - initOnce: guards the single close of initDone.
//   - initDone: closed at the end of Initialize; WaitForInit waits on it.
var (
	sessions = csync.NewMap[string, *ClientSession]()
	states   = csync.NewMap[string, ClientInfo]()
	broker   = pubsub.NewBroker[Event]()
	initOnce sync.Once
	initDone = make(chan struct{})
)
 63
 64// State represents the current state of an MCP client
 65type State int
 66
 67const (
 68	StateDisabled State = iota
 69	StateStarting
 70	StateConnected
 71	StateError
 72)
 73
 74func (s State) String() string {
 75	switch s {
 76	case StateDisabled:
 77		return "disabled"
 78	case StateStarting:
 79		return "starting"
 80	case StateConnected:
 81		return "connected"
 82	case StateError:
 83		return "error"
 84	default:
 85		return "unknown"
 86	}
 87}
 88
// EventType represents the type of MCP event.
type EventType uint

// Event kinds published on the package broker. EventStateChanged is emitted
// by updateState; the three list-changed kinds are emitted by the
// corresponding change handlers registered in createSession.
const (
	EventStateChanged EventType = iota
	EventToolsListChanged
	EventPromptsListChanged
	EventResourcesListChanged
)
 98
// Event represents an event in the MCP system.
//
// Type and Name (the MCP the event is about) are always set. State, Error,
// and Counts are filled in for EventStateChanged events published by
// updateState; the list-changed events published from createSession leave
// them at their zero values.
type Event struct {
	Type   EventType
	Name   string
	State  State
	Error  error
	Counts Counts
}
107
// Counts holds the number of tools, prompts, and resources recorded for an
// MCP client.
type Counts struct {
	Tools     int
	Prompts   int
	Resources int
}
114
// ClientInfo holds information about an MCP client's state.
//
// Values are written by updateState: Error carries the failure that was
// passed in (nil otherwise), Client is the session passed alongside the state
// (nil for every non-connected state set in this file), and ConnectedAt is
// set to the current time only when the state is StateConnected.
type ClientInfo struct {
	Name        string
	State       State
	Error       error
	Client      *ClientSession
	Counts      Counts
	ConnectedAt time.Time
}
124
// SubscribeEvents returns a channel for MCP events.
//
// The subscription is created on the package broker with ctx; see
// pubsub.Broker.Subscribe for how ctx governs the subscription's lifetime.
func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
	return broker.Subscribe(ctx)
}
129
// GetStates returns the current state of all MCP clients, keyed by MCP name.
// The result is whatever the state registry's Copy method produces
// (presumably an independent snapshot; verify against csync.Map).
func GetStates() map[string]ClientInfo {
	return states.Copy()
}
134
// GetState returns the state of a specific MCP client. The boolean reports
// whether any state has been recorded under name.
func GetState(name string) (ClientInfo, bool) {
	return states.Get(name)
}
139
140// isIgnorableCloseErr returns true for errors that are expected during MCP
141// session shutdown and can be safely suppressed.
142func isIgnorableCloseErr(err error) bool {
143	return err == nil ||
144		errors.Is(err, io.EOF) ||
145		errors.Is(err, context.Canceled) ||
146		isKilledErr(err)
147}
148
// isKilledErr reports whether err wraps an *exec.ExitError whose wait status
// shows the process was terminated by SIGKILL. Any other error, including
// nil, yields false.
func isKilledErr(err error) bool {
	var exitErr *exec.ExitError
	if errors.As(err, &exitErr) {
		if status, ok := exitErr.Sys().(syscall.WaitStatus); ok {
			return status.Signaled() && status.Signal() == syscall.SIGKILL
		}
	}
	return false
}
159
160// Close closes all MCP clients. This should be called during application shutdown.
161func Close(ctx context.Context) error {
162	var wg sync.WaitGroup
163	for name, session := range sessions.Seq2() {
164		wg.Go(func() {
165			done := make(chan error, 1)
166			go func() {
167				done <- session.Close()
168			}()
169			select {
170			case err := <-done:
171				if !isIgnorableCloseErr(err) {
172					slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
173				}
174			case <-ctx.Done():
175			}
176		})
177	}
178	wg.Wait()
179	broker.Shutdown()
180	return nil
181}
182
183// Initialize initializes MCP clients based on the provided configuration.
184func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) {
185	slog.Info("Initializing MCP clients")
186	var wg sync.WaitGroup
187	// Initialize states for all configured MCPs
188	for name, m := range cfg.Config().MCP {
189		if m.Disabled {
190			updateState(name, StateDisabled, nil, nil, Counts{})
191			slog.Debug("Skipping disabled MCP", "name", name)
192			continue
193		}
194
195		// Set initial starting state
196		wg.Add(1)
197		go func(name string, m config.MCPConfig) {
198			defer func() {
199				wg.Done()
200				if r := recover(); r != nil {
201					var err error
202					switch v := r.(type) {
203					case error:
204						err = v
205					case string:
206						err = fmt.Errorf("panic: %s", v)
207					default:
208						err = fmt.Errorf("panic: %v", v)
209					}
210					updateState(name, StateError, err, nil, Counts{})
211					slog.Error("Panic in MCP client initialization", "error", err, "name", name)
212				}
213			}()
214
215			if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
216				slog.Debug("Failed to initialize MCP client", "name", name, "error", err)
217			}
218		}(name, m)
219	}
220	wg.Wait()
221	initOnce.Do(func() { close(initDone) })
222}
223
// WaitForInit blocks until MCP initialization is complete or ctx is done.
//
// It returns nil once initDone has been closed and ctx.Err() if ctx finishes
// first. Within this file initDone is closed only at the end of Initialize,
// so if Initialize has not completed (or is never called) this call stays
// blocked until ctx is done; callers should pass a context that can be
// cancelled or has a deadline.
func WaitForInit(ctx context.Context) error {
	select {
	case <-initDone:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
234
235// InitializeSingle initializes a single MCP client by name.
236func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) error {
237	m, exists := cfg.Config().MCP[name]
238	if !exists {
239		return fmt.Errorf("mcp '%s' not found in configuration", name)
240	}
241
242	if m.Disabled {
243		updateState(name, StateDisabled, nil, nil, Counts{})
244		slog.Debug("Skipping disabled MCP", "name", name)
245		return nil
246	}
247
248	return initClient(ctx, cfg, name, m, cfg.Resolver())
249}
250
251// initClient initializes a single MCP client with the given configuration.
252func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
253	updateState(name, StateStarting, nil, nil, Counts{})
254
255	session, err := createSession(ctx, name, m, resolver)
256	if err != nil {
257		return err
258	}
259
260	tools, err := getTools(ctx, session)
261	if err != nil {
262		slog.Error("Error listing tools", "name", name, "error", err)
263		updateState(name, StateError, err, nil, Counts{})
264		closeSessionOnInitError(name, session)
265		return err
266	}
267
268	prompts, err := getPrompts(ctx, session)
269	if err != nil {
270		slog.Error("Error listing prompts", "name", name, "error", err)
271		updateState(name, StateError, err, nil, Counts{})
272		closeSessionOnInitError(name, session)
273		return err
274	}
275
276	resources, err := getResources(ctx, session)
277	if err != nil {
278		slog.Error("Error listing resources", "name", name, "error", err)
279		updateState(name, StateError, err, nil, Counts{})
280		closeSessionOnInitError(name, session)
281		return err
282	}
283
284	toolCount := updateTools(cfg, name, tools)
285	updatePrompts(name, prompts)
286	resourceCount := updateResources(name, resources)
287	sessions.Set(name, session)
288
289	updateState(name, StateConnected, nil, session, Counts{
290		Tools:     toolCount,
291		Prompts:   len(prompts),
292		Resources: resourceCount,
293	})
294
295	return nil
296}
297
298// closeSessionOnInitError closes a session that failed during initialization,
299// suppressing expected shutdown errors. Uses a fixed timeout to avoid blocking
300// indefinitely if the parent context has no deadline.
301//
302// On timeout the Close goroutine may outlive this function, but since
303// session.Close cancels the session context internally, it will unblock
304// shortly after.
305func closeSessionOnInitError(name string, session *ClientSession) {
306	const closeTimeout = 5 * time.Second
307	ctx, cancel := context.WithTimeout(context.Background(), closeTimeout)
308	defer cancel()
309
310	done := make(chan error, 1)
311	go func() {
312		done <- session.Close()
313	}()
314
315	select {
316	case err := <-done:
317		if !isIgnorableCloseErr(err) {
318			slog.Warn("Failed to close MCP session after init error", "name", name, "error", err)
319		}
320	case <-ctx.Done():
321		slog.Warn("Timed out waiting to close MCP session after init error", "name", name, "error", ctx.Err())
322	}
323}
324
325// DisableSingle disables and closes a single MCP client by name.
326func DisableSingle(cfg *config.ConfigStore, name string) {
327	session, ok := sessions.Get(name)
328	if ok {
329		if err := session.Close(); !isIgnorableCloseErr(err) {
330			slog.Warn("Error closing MCP session", "name", name, "error", err)
331		}
332		sessions.Del(name)
333	}
334
335	// Clear tools, prompts, and resources for this MCP.
336	updateTools(cfg, name, nil)
337	updatePrompts(name, nil)
338	updateResources(name, nil)
339
340	// Update state to disabled.
341	updateState(name, StateDisabled, nil, nil, Counts{})
342
343	slog.Info("Disabled MCP client", "name", name)
344}
345
346func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
347	sess, ok := sessions.Get(name)
348	if !ok {
349		return nil, fmt.Errorf("mcp '%s' not available", name)
350	}
351
352	m := cfg.Config().MCP[name]
353	state, _ := states.Get(name)
354
355	timeout := mcpTimeout(m)
356	pingCtx, cancel := context.WithTimeout(ctx, timeout)
357	defer cancel()
358	err := sess.Ping(pingCtx, nil)
359	if err == nil {
360		return sess, nil
361	}
362	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
363
364	sess, err = createSession(ctx, name, m, cfg.Resolver())
365	if err != nil {
366		return nil, err
367	}
368
369	updateState(name, StateConnected, nil, sess, state.Counts)
370	sessions.Set(name, sess)
371	return sess, nil
372}
373
374// updateState updates the state of an MCP client and publishes an event
375func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
376	info := ClientInfo{
377		Name:   name,
378		State:  state,
379		Error:  err,
380		Client: client,
381		Counts: counts,
382	}
383	switch state {
384	case StateConnected:
385		info.ConnectedAt = time.Now()
386	case StateError:
387		sessions.Del(name)
388	}
389	states.Set(name, info)
390
391	// Publish state change event
392	broker.Publish(pubsub.UpdatedEvent, Event{
393		Type:   EventStateChanged,
394		Name:   name,
395		State:  state,
396		Error:  err,
397		Counts: counts,
398	})
399}
400
401func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
402	timeout := mcpTimeout(m)
403	mcpCtx, cancel := context.WithCancel(ctx)
404	cancelTimer := time.AfterFunc(timeout, cancel)
405
406	transport, err := createTransport(mcpCtx, m, resolver)
407	if err != nil {
408		updateState(name, StateError, err, nil, Counts{})
409		slog.Error("Error creating MCP client", "error", err, "name", name)
410		cancel()
411		cancelTimer.Stop()
412		return nil, err
413	}
414
415	client := mcp.NewClient(
416		&mcp.Implementation{
417			Name:    "crush",
418			Version: version.Version,
419			Title:   "Crush",
420		},
421		&mcp.ClientOptions{
422			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
423				broker.Publish(pubsub.UpdatedEvent, Event{
424					Type: EventToolsListChanged,
425					Name: name,
426				})
427			},
428			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
429				broker.Publish(pubsub.UpdatedEvent, Event{
430					Type: EventPromptsListChanged,
431					Name: name,
432				})
433			},
434			ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
435				broker.Publish(pubsub.UpdatedEvent, Event{
436					Type: EventResourcesListChanged,
437					Name: name,
438				})
439			},
440			LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
441				level := parseLevel(req.Params.Level)
442				slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
443			},
444		},
445	)
446
447	session, err := client.Connect(mcpCtx, transport, nil)
448	if err != nil {
449		err = maybeStdioErr(err, transport)
450		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
451		slog.Error("MCP client failed to initialize", "error", err, "name", name)
452		cancel()
453		cancelTimer.Stop()
454		return nil, err
455	}
456
457	cancelTimer.Stop()
458	slog.Debug("MCP client initialized", "name", name)
459	return &ClientSession{session, cancel}, nil
460}
461
462// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
463// to parse, and the cli will then close it, causing the EOF error.
464// so, if we got an EOF err, and the transport is STDIO, we try to exec it
465// again with a timeout and collect the output so we can add details to the
466// error.
467// this happens particularly when starting things with npx, e.g. if node can't
468// be found or some other error like that.
469func maybeStdioErr(err error, transport mcp.Transport) error {
470	if !errors.Is(err, io.EOF) {
471		return err
472	}
473	ct, ok := transport.(*mcp.CommandTransport)
474	if !ok {
475		return err
476	}
477	if err2 := stdioCheck(ct.Command); err2 != nil {
478		err = errors.Join(err, err2)
479	}
480	return err
481}
482
// maybeTimeoutErr replaces context-expiry errors with a readable
// "timed out after <timeout>" error and returns every other error (including
// nil) unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are treated as
// timeouts: createSession enforces its timeout by cancelling the context,
// while the ping in getOrRenewClient uses context.WithTimeout, whose expiry
// is reported as DeadlineExceeded.
//
// NOTE(review): a cancellation of the caller's own context is
// indistinguishable here and is also reported as a timeout.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
489
490func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
491	switch m.Type {
492	case config.MCPStdio:
493		command, err := resolver.ResolveValue(m.Command)
494		if err != nil {
495			return nil, fmt.Errorf("invalid mcp command: %w", err)
496		}
497		if strings.TrimSpace(command) == "" {
498			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
499		}
500		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
501		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
502		return &mcp.CommandTransport{
503			Command: cmd,
504		}, nil
505	case config.MCPHttp:
506		if strings.TrimSpace(m.URL) == "" {
507			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
508		}
509		client := &http.Client{
510			Transport: &headerRoundTripper{
511				headers: m.ResolvedHeaders(),
512			},
513		}
514		return &mcp.StreamableClientTransport{
515			Endpoint:   m.URL,
516			HTTPClient: client,
517		}, nil
518	case config.MCPSSE:
519		if strings.TrimSpace(m.URL) == "" {
520			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
521		}
522		client := &http.Client{
523			Transport: &headerRoundTripper{
524				headers: m.ResolvedHeaders(),
525			},
526		}
527		return &mcp.SSEClientTransport{
528			Endpoint:   m.URL,
529			HTTPClient: client,
530		}, nil
531	default:
532		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
533	}
534}
535
// headerRoundTripper is an http.RoundTripper that applies a fixed set of
// headers to every outgoing request before delegating to
// http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with rt.headers applied,
// overriding any existing values for those header names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone and req itself is left untouched. When there
// are no headers to add, req is forwarded as is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
546
547func mcpTimeout(m config.MCPConfig) time.Duration {
548	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
549}
550
// stdioCheck re-runs the command described by old for up to five seconds and
// reports why it failed.
//
// It returns nil when the command exits successfully or is still running when
// the timeout expires (i.e. it started fine). Otherwise it returns the run
// error followed by the command's combined stdout and stderr.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// exec.Cmd.Args includes the program name as Args[0], while
	// exec.CommandContext prepends the name itself. Passing old.Args through
	// unchanged would run e.g. "npx npx pkg" instead of "npx pkg".
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}