init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
 28func parseLevel(level mcp.LoggingLevel) slog.Level {
 29	switch level {
 30	case "info":
 31		return slog.LevelInfo
 32	case "notice":
 33		return slog.LevelInfo
 34	case "warning":
 35		return slog.LevelWarn
 36	default:
 37		return slog.LevelDebug
 38	}
 39}
 40
 41// ClientSession wraps an mcp.ClientSession with a context cancel function so
 42// that the context created during session establishment is properly cleaned up
 43// on close.
 44type ClientSession struct {
 45	*mcp.ClientSession
 46	cancel context.CancelFunc
 47}
 48
 49// Close cancels the session context and then closes the underlying session.
 50func (s *ClientSession) Close() error {
 51	s.cancel()
 52	return s.ClientSession.Close()
 53}
 54
 55var (
 56	sessions = csync.NewMap[string, *ClientSession]()
 57	states   = csync.NewMap[string, ClientInfo]()
 58	broker   = pubsub.NewBroker[Event]()
 59	initOnce sync.Once
 60	initDone = make(chan struct{})
 61)
 62
 63// State represents the current state of an MCP client
 64type State int
 65
 66const (
 67	StateDisabled State = iota
 68	StateStarting
 69	StateConnected
 70	StateError
 71)
 72
 73func (s State) String() string {
 74	switch s {
 75	case StateDisabled:
 76		return "disabled"
 77	case StateStarting:
 78		return "starting"
 79	case StateConnected:
 80		return "connected"
 81	case StateError:
 82		return "error"
 83	default:
 84		return "unknown"
 85	}
 86}
 87
 88// EventType represents the type of MCP event
 89type EventType uint
 90
 91const (
 92	EventStateChanged EventType = iota
 93	EventToolsListChanged
 94	EventPromptsListChanged
 95	EventResourcesListChanged
 96)
 97
 98// Event represents an event in the MCP system
 99type Event struct {
100	Type   EventType
101	Name   string
102	State  State
103	Error  error
104	Counts Counts
105}
106
// Counts records how many tools, prompts and resources a client exposes.
type Counts struct {
	Tools, Prompts, Resources int
}
113
114// ClientInfo holds information about an MCP client's state
115type ClientInfo struct {
116	Name        string
117	State       State
118	Error       error
119	Client      *ClientSession
120	Counts      Counts
121	ConnectedAt time.Time
122}
123
124// SubscribeEvents returns a channel for MCP events
125func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
126	return broker.Subscribe(ctx)
127}
128
129// GetStates returns the current state of all MCP clients
130func GetStates() map[string]ClientInfo {
131	return states.Copy()
132}
133
134// GetState returns the state of a specific MCP client
135func GetState(name string) (ClientInfo, bool) {
136	return states.Get(name)
137}
138
139// Close closes all MCP clients. This should be called during application shutdown.
140func Close(ctx context.Context) error {
141	var wg sync.WaitGroup
142	for name, session := range sessions.Seq2() {
143		wg.Go(func() {
144			done := make(chan error, 1)
145			go func() {
146				done <- session.Close()
147			}()
148			select {
149			case err := <-done:
150				if err != nil &&
151					!errors.Is(err, io.EOF) &&
152					!errors.Is(err, context.Canceled) &&
153					err.Error() != "signal: killed" {
154					slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
155				}
156			case <-ctx.Done():
157			}
158		})
159	}
160	wg.Wait()
161	broker.Shutdown()
162	return nil
163}
164
165// Initialize initializes MCP clients based on the provided configuration.
166func Initialize(ctx context.Context, permissions permission.Service, cfg *config.ConfigStore) {
167	slog.Info("Initializing MCP clients")
168	var wg sync.WaitGroup
169	// Initialize states for all configured MCPs
170	for name, m := range cfg.Config().MCP {
171		if m.Disabled {
172			updateState(name, StateDisabled, nil, nil, Counts{})
173			slog.Debug("Skipping disabled MCP", "name", name)
174			continue
175		}
176
177		// Set initial starting state
178		wg.Add(1)
179		go func(name string, m config.MCPConfig) {
180			defer func() {
181				wg.Done()
182				if r := recover(); r != nil {
183					var err error
184					switch v := r.(type) {
185					case error:
186						err = v
187					case string:
188						err = fmt.Errorf("panic: %s", v)
189					default:
190						err = fmt.Errorf("panic: %v", v)
191					}
192					updateState(name, StateError, err, nil, Counts{})
193					slog.Error("Panic in MCP client initialization", "error", err, "name", name)
194				}
195			}()
196
197			if err := initClient(ctx, cfg, name, m, cfg.Resolver()); err != nil {
198				slog.Debug("failed to initialize mcp client", "name", name, "error", err)
199			}
200		}(name, m)
201	}
202	wg.Wait()
203	initOnce.Do(func() { close(initDone) })
204}
205
206// WaitForInit blocks until MCP initialization is complete.
207// If Initialize was never called, this returns immediately.
208func WaitForInit(ctx context.Context) error {
209	select {
210	case <-initDone:
211		return nil
212	case <-ctx.Done():
213		return ctx.Err()
214	}
215}
216
217// InitializeSingle initializes a single MCP client by name.
218func InitializeSingle(ctx context.Context, name string, cfg *config.ConfigStore) error {
219	m, exists := cfg.Config().MCP[name]
220	if !exists {
221		return fmt.Errorf("mcp '%s' not found in configuration", name)
222	}
223
224	if m.Disabled {
225		updateState(name, StateDisabled, nil, nil, Counts{})
226		slog.Debug("skipping disabled mcp", "name", name)
227		return nil
228	}
229
230	return initClient(ctx, cfg, name, m, cfg.Resolver())
231}
232
233// initClient initializes a single MCP client with the given configuration.
234func initClient(ctx context.Context, cfg *config.ConfigStore, name string, m config.MCPConfig, resolver config.VariableResolver) error {
235	// Set initial starting state.
236	updateState(name, StateStarting, nil, nil, Counts{})
237
238	// createSession handles its own timeout internally.
239	session, err := createSession(ctx, name, m, resolver)
240	if err != nil {
241		return err
242	}
243
244	tools, err := getTools(ctx, session)
245	if err != nil {
246		slog.Error("Error listing tools", "error", err)
247		updateState(name, StateError, err, nil, Counts{})
248		session.Close()
249		return err
250	}
251
252	prompts, err := getPrompts(ctx, session)
253	if err != nil {
254		slog.Error("Error listing prompts", "error", err)
255		updateState(name, StateError, err, nil, Counts{})
256		session.Close()
257		return err
258	}
259
260	toolCount := updateTools(cfg, name, tools)
261	updatePrompts(name, prompts)
262	sessions.Set(name, session)
263
264	updateState(name, StateConnected, nil, session, Counts{
265		Tools:   toolCount,
266		Prompts: len(prompts),
267	})
268
269	return nil
270}
271
272// DisableSingle disables and closes a single MCP client by name.
273func DisableSingle(cfg *config.ConfigStore, name string) error {
274	session, ok := sessions.Get(name)
275	if ok {
276		if err := session.Close(); err != nil &&
277			!errors.Is(err, io.EOF) &&
278			!errors.Is(err, context.Canceled) &&
279			err.Error() != "signal: killed" {
280			slog.Warn("error closing mcp session", "name", name, "error", err)
281		}
282		sessions.Del(name)
283	}
284
285	// Clear tools and prompts for this MCP.
286	updateTools(cfg, name, nil)
287	updatePrompts(name, nil)
288
289	// Update state to disabled.
290	updateState(name, StateDisabled, nil, nil, Counts{})
291
292	slog.Info("Disabled mcp client", "name", name)
293	return nil
294}
295
296func getOrRenewClient(ctx context.Context, cfg *config.ConfigStore, name string) (*ClientSession, error) {
297	sess, ok := sessions.Get(name)
298	if !ok {
299		return nil, fmt.Errorf("mcp '%s' not available", name)
300	}
301
302	m := cfg.Config().MCP[name]
303	state, _ := states.Get(name)
304
305	timeout := mcpTimeout(m)
306	pingCtx, cancel := context.WithTimeout(ctx, timeout)
307	defer cancel()
308	err := sess.Ping(pingCtx, nil)
309	if err == nil {
310		return sess, nil
311	}
312	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
313
314	sess, err = createSession(ctx, name, m, cfg.Resolver())
315	if err != nil {
316		return nil, err
317	}
318
319	updateState(name, StateConnected, nil, sess, state.Counts)
320	sessions.Set(name, sess)
321	return sess, nil
322}
323
324// updateState updates the state of an MCP client and publishes an event
325func updateState(name string, state State, err error, client *ClientSession, counts Counts) {
326	info := ClientInfo{
327		Name:   name,
328		State:  state,
329		Error:  err,
330		Client: client,
331		Counts: counts,
332	}
333	switch state {
334	case StateConnected:
335		info.ConnectedAt = time.Now()
336	case StateError:
337		sessions.Del(name)
338	}
339	states.Set(name, info)
340
341	// Publish state change event
342	broker.Publish(pubsub.UpdatedEvent, Event{
343		Type:   EventStateChanged,
344		Name:   name,
345		State:  state,
346		Error:  err,
347		Counts: counts,
348	})
349}
350
351func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*ClientSession, error) {
352	timeout := mcpTimeout(m)
353	mcpCtx, cancel := context.WithCancel(ctx)
354	cancelTimer := time.AfterFunc(timeout, cancel)
355
356	transport, err := createTransport(mcpCtx, m, resolver)
357	if err != nil {
358		updateState(name, StateError, err, nil, Counts{})
359		slog.Error("Error creating MCP client", "error", err, "name", name)
360		cancel()
361		cancelTimer.Stop()
362		return nil, err
363	}
364
365	client := mcp.NewClient(
366		&mcp.Implementation{
367			Name:    "crush",
368			Version: version.Version,
369			Title:   "Crush",
370		},
371		&mcp.ClientOptions{
372			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
373				broker.Publish(pubsub.UpdatedEvent, Event{
374					Type: EventToolsListChanged,
375					Name: name,
376				})
377			},
378			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
379				broker.Publish(pubsub.UpdatedEvent, Event{
380					Type: EventPromptsListChanged,
381					Name: name,
382				})
383			},
384			ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
385				broker.Publish(pubsub.UpdatedEvent, Event{
386					Type: EventResourcesListChanged,
387					Name: name,
388				})
389			},
390			LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
391				level := parseLevel(req.Params.Level)
392				slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
393			},
394		},
395	)
396
397	session, err := client.Connect(mcpCtx, transport, nil)
398	if err != nil {
399		err = maybeStdioErr(err, transport)
400		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
401		slog.Error("MCP client failed to initialize", "error", err, "name", name)
402		cancel()
403		cancelTimer.Stop()
404		return nil, err
405	}
406
407	cancelTimer.Stop()
408	slog.Debug("MCP client initialized", "name", name)
409	return &ClientSession{session, cancel}, nil
410}
411
412// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
413// to parse, and the cli will then close it, causing the EOF error.
414// so, if we got an EOF err, and the transport is STDIO, we try to exec it
415// again with a timeout and collect the output so we can add details to the
416// error.
417// this happens particularly when starting things with npx, e.g. if node can't
418// be found or some other error like that.
419func maybeStdioErr(err error, transport mcp.Transport) error {
420	if !errors.Is(err, io.EOF) {
421		return err
422	}
423	ct, ok := transport.(*mcp.CommandTransport)
424	if !ok {
425		return err
426	}
427	if err2 := stdioCheck(ct.Command); err2 != nil {
428		err = errors.Join(err, err2)
429	}
430	return err
431}
432
// maybeTimeoutErr turns a context cancellation or deadline error into a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// Both sentinels are checked because callers impose their timeouts
// differently: createSession cancels its context from a timer (yielding
// context.Canceled), while getOrRenewClient pings under context.WithTimeout
// (yielding context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
439
440func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
441	switch m.Type {
442	case config.MCPStdio:
443		command, err := resolver.ResolveValue(m.Command)
444		if err != nil {
445			return nil, fmt.Errorf("invalid mcp command: %w", err)
446		}
447		if strings.TrimSpace(command) == "" {
448			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
449		}
450		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
451		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
452		return &mcp.CommandTransport{
453			Command: cmd,
454		}, nil
455	case config.MCPHttp:
456		if strings.TrimSpace(m.URL) == "" {
457			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
458		}
459		client := &http.Client{
460			Transport: &headerRoundTripper{
461				headers: m.ResolvedHeaders(),
462			},
463		}
464		return &mcp.StreamableClientTransport{
465			Endpoint:   m.URL,
466			HTTPClient: client,
467		}, nil
468	case config.MCPSSE:
469		if strings.TrimSpace(m.URL) == "" {
470			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
471		}
472		client := &http.Client{
473			Transport: &headerRoundTripper{
474				headers: m.ResolvedHeaders(),
475			},
476		}
477		return &mcp.SSEClientTransport{
478			Endpoint:   m.URL,
479			HTTPClient: client,
480		}, nil
481	default:
482		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
483	}
484}
485
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request and then delegates to http.DefaultTransport.
type headerRoundTripper struct {
	// headers are set on each request, replacing any existing values for the
	// same keys.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with rt.headers applied.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone. The clone shares req's body, which
// DefaultTransport consumes and closes as usual.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
496
497func mcpTimeout(m config.MCPConfig) time.Duration {
498	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
499}
500
// stdioCheck re-runs the command described by old (same executable, argv,
// environment and working directory) for at most 5 seconds and captures its
// combined stdout/stderr. This lets a stdio MCP server that failed before
// speaking the protocol be diagnosed.
//
// It returns nil if the command exits successfully or is still running when
// the 5-second limit expires. Otherwise it returns the run error annotated
// with the captured output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args is the full argv and already starts with the program name.
	// Passing it as extra arguments would duplicate argv[0] (running e.g.
	// "npx npx pkg"), so copy it verbatim instead.
	if len(old.Args) > 0 {
		cmd.Args = old.Args
	}
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}