init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
 28func parseLevel(level mcp.LoggingLevel) slog.Level {
 29	switch level {
 30	case "info":
 31		return slog.LevelInfo
 32	case "notice":
 33		return slog.LevelInfo
 34	case "warning":
 35		return slog.LevelWarn
 36	default:
 37		return slog.LevelDebug
 38	}
 39}
 40
 41var (
 42	sessions = csync.NewMap[string, *mcp.ClientSession]()
 43	states   = csync.NewMap[string, ClientInfo]()
 44	broker   = pubsub.NewBroker[Event]()
 45	initOnce sync.Once
 46	initDone = make(chan struct{})
 47)
 48
 49// State represents the current state of an MCP client
 50type State int
 51
 52const (
 53	StateDisabled State = iota
 54	StateStarting
 55	StateConnected
 56	StateError
 57)
 58
 59func (s State) String() string {
 60	switch s {
 61	case StateDisabled:
 62		return "disabled"
 63	case StateStarting:
 64		return "starting"
 65	case StateConnected:
 66		return "connected"
 67	case StateError:
 68		return "error"
 69	default:
 70		return "unknown"
 71	}
 72}
 73
 74// EventType represents the type of MCP event
 75type EventType uint
 76
 77const (
 78	EventStateChanged EventType = iota
 79	EventToolsListChanged
 80	EventPromptsListChanged
 81	EventResourcesListChanged
 82)
 83
 84// Event represents an event in the MCP system
 85type Event struct {
 86	Type   EventType
 87	Name   string
 88	State  State
 89	Error  error
 90	Counts Counts
 91}
 92
 93// Counts number of available tools, prompts, etc.
 94type Counts struct {
 95	Tools     int
 96	Prompts   int
 97	Resources int
 98}
 99
// ClientInfo is the most recently recorded state of one MCP client. updateState
// replaces it wholesale on every transition; GetStates/GetState read it.
type ClientInfo struct {
	// Name is the MCP's key in the configuration (cfg.MCP).
	Name string
	// State is the client's lifecycle state at the last update.
	State State
	// Error is the error passed with the last update; in this file it is
	// non-nil only together with StateError.
	Error error
	// Client is the session passed with the last update; in this file it is
	// non-nil only together with StateConnected.
	Client *mcp.ClientSession
	// Counts is the number of tools/prompts/resources known at the last update.
	Counts Counts
	// ConnectedAt is when the client last entered StateConnected; updateState
	// leaves it zero for every other state.
	ConnectedAt time.Time
}
109
// SubscribeEvents subscribes to every Event this package publishes (state
// changes and list-changed notifications). The subscription is tied to ctx;
// see pubsub.Broker.Subscribe for the exact lifetime rules.
func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
	events := broker.Subscribe(ctx)
	return events
}

// GetStates returns a copy of the recorded ClientInfo of every known MCP
// client, keyed by name.
func GetStates() map[string]ClientInfo {
	snapshot := states.Copy()
	return snapshot
}

// GetState looks up the recorded ClientInfo for one MCP client. The boolean
// is false when nothing has been recorded under name.
func GetState(name string) (ClientInfo, bool) {
	info, found := states.Get(name)
	return info, found
}
124
// Close shuts down every registered MCP session concurrently and then stops
// the event broker. Call it during application shutdown.
//
// Once ctx is done, sessions that have not finished closing are no longer
// waited for. Close failures are logged (expected shutdown errors such as
// io.EOF, context cancellation or a killed process are ignored) and never
// returned: the result is always nil.
func Close(ctx context.Context) error {
	ignorable := func(err error) bool {
		return errors.Is(err, io.EOF) ||
			errors.Is(err, context.Canceled) ||
			err.Error() == "signal: killed"
	}

	var pending sync.WaitGroup
	for name, session := range sessions.Seq2() {
		pending.Go(func() {
			// Buffered so the closer goroutine can always finish, even if we
			// stop waiting for it because ctx is done.
			result := make(chan error, 1)
			go func() { result <- session.Close() }()

			select {
			case <-ctx.Done():
				return
			case err := <-result:
				if err == nil || ignorable(err) {
					return
				}
				slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
			}
		})
	}
	pending.Wait()

	broker.Shutdown()
	return nil
}
150
// Initialize connects every enabled MCP client in cfg.MCP concurrently and
// blocks until each one has either connected or failed.
//
// Disabled clients are recorded as StateDisabled; the others go through
// StateStarting and end in StateConnected or StateError (every transition is
// published via updateState). A panic while initializing one client is
// recovered and recorded as that client's StateError. When the first call
// completes, initDone is closed so WaitForInit callers are released.
//
// permissions is currently unused by this function.
func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
	slog.Info("Initializing MCP clients")
	var wg sync.WaitGroup
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateState(name, StateDisabled, nil, nil, Counts{})
			slog.Debug("Skipping disabled MCP", "name", name)
			continue
		}

		// Set initial starting state.
		updateState(name, StateStarting, nil, nil, Counts{})

		// wg.Go marks the goroutine done only after the function returns,
		// i.e. after the deferred recover below has recorded any panic. The
		// goroutine therefore can never look finished to wg.Wait (and thus
		// to WaitForInit) while its client is still shown as "starting".
		wg.Go(func() {
			var session *mcp.ClientSession
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				// updateState(StateError) drops the session from the registry,
				// so close it here rather than leaking it.
				if session != nil {
					session.Close()
				}
				updateState(name, StateError, err, nil, Counts{})
				slog.Error("Panic in MCP client initialization", "error", err, "name", name)
			}()

			// createSession handles its own timeout internally and records
			// StateError itself on failure.
			var err error
			session, err = createSession(ctx, name, m, cfg.Resolver())
			if err != nil {
				return
			}

			tools, err := getTools(ctx, session)
			if err != nil {
				slog.Error("Error listing tools", "error", err, "name", name)
				updateState(name, StateError, err, nil, Counts{})
				session.Close()
				return
			}

			prompts, err := getPrompts(ctx, session)
			if err != nil {
				slog.Error("Error listing prompts", "error", err, "name", name)
				updateState(name, StateError, err, nil, Counts{})
				session.Close()
				return
			}

			resources, err := getResources(ctx, session)
			if err != nil {
				slog.Error("Error listing resources", "error", err, "name", name)
				updateState(name, StateError, err, nil, Counts{})
				session.Close()
				return
			}

			toolCount := updateTools(cfg, name, tools)
			updatePrompts(name, prompts)
			resourceCount := updateResources(name, resources)
			sessions.Set(name, session)

			updateState(name, StateConnected, nil, session, Counts{
				Tools:     toolCount,
				Prompts:   len(prompts),
				Resources: resourceCount,
			})
		})
	}
	wg.Wait()
	initOnce.Do(func() { close(initDone) })
}
230
// WaitForInit blocks until the first call to Initialize has finished (every
// configured MCP client connected or failed) or until ctx is done.
//
// It returns nil once initialization has completed, otherwise ctx.Err().
// If Initialize is never called, WaitForInit does not return until ctx is
// done, so pass a cancellable or deadline-bound ctx when Initialize may not run.
func WaitForInit(ctx context.Context) error {
	// Report success deterministically when initialization has already
	// finished, even if ctx is also done; a single select would choose
	// between two ready cases at random.
	select {
	case <-initDone:
		return nil
	default:
	}

	select {
	case <-initDone:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
241
// getOrRenewClient returns the registered session for the named MCP after
// checking it with a ping bounded by mcpTimeout. If the ping fails, the client
// is marked StateError, the stale session is closed, and a new session is
// created and registered in its place (keeping the previous Counts).
//
// It returns an error when no session is registered under name, or when
// re-creating the session fails.
func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*mcp.ClientSession, error) {
	sess, ok := sessions.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	m := cfg.MCP[name]
	state, _ := states.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := sess.Ping(pingCtx, nil)
	if err == nil {
		return sess, nil
	}
	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)

	// updateState(StateError) removed the stale session from the registry, so
	// nothing else will close it. Without this, a stdio server process would
	// keep running. Closing may block while the process is terminated (Close
	// guards against the same thing), so do it in the background.
	go func(stale *mcp.ClientSession) {
		if cerr := stale.Close(); cerr != nil {
			slog.Debug("Failed to close stale MCP session", "name", name, "error", cerr)
		}
	}(sess)

	sess, err = createSession(ctx, name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	updateState(name, StateConnected, nil, sess, state.Counts)
	sessions.Set(name, sess)
	return sess, nil
}
269
270// updateState updates the state of an MCP client and publishes an event
271func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
272	info := ClientInfo{
273		Name:   name,
274		State:  state,
275		Error:  err,
276		Client: client,
277		Counts: counts,
278	}
279	switch state {
280	case StateConnected:
281		info.ConnectedAt = time.Now()
282	case StateError:
283		sessions.Del(name)
284	}
285	states.Set(name, info)
286
287	// Publish state change event
288	broker.Publish(pubsub.UpdatedEvent, Event{
289		Type:   EventStateChanged,
290		Name:   name,
291		State:  state,
292		Error:  err,
293		Counts: counts,
294	})
295}
296
// createSession builds the transport described by m, connects an MCP client
// over it (forwarding list-changed notifications to the broker and server log
// messages to slog), and returns the session. On every failure path it records
// StateError via updateState and logs. Recording success is left to the caller.
//
// Setup is bounded by mcpTimeout(m). The bound is a timer that cancels mcpCtx
// instead of context.WithTimeout, because mcpCtx must outlive this call on
// success: a stdio transport's process is started with
// exec.CommandContext(mcpCtx, ...), so cancelling mcpCtx later would kill the
// server. On success the timer is disarmed and mcpCtx is left to end with ctx.
func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
	timeout := mcpTimeout(m)
	mcpCtx, cancel := context.WithCancel(ctx)
	cancelTimer := time.AfterFunc(timeout, cancel)

	transport, err := createTransport(mcpCtx, m, resolver)
	if err != nil {
		updateState(name, StateError, err, nil, Counts{})
		slog.Error("Error creating MCP client", "error", err, "name", name)
		cancel()
		cancelTimer.Stop()
		return nil, err
	}

	client := mcp.NewClient(
		&mcp.Implementation{
			Name:    "crush",
			Version: version.Version,
			Title:   "Crush",
		},
		&mcp.ClientOptions{
			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
				broker.Publish(pubsub.UpdatedEvent, Event{
					Type: EventToolsListChanged,
					Name: name,
				})
			},
			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
				broker.Publish(pubsub.UpdatedEvent, Event{
					Type: EventPromptsListChanged,
					Name: name,
				})
			},
			ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
				broker.Publish(pubsub.UpdatedEvent, Event{
					Type: EventResourcesListChanged,
					Name: name,
				})
			},
			LoggingMessageHandler: func(ctx context.Context, req *mcp.LoggingMessageRequest) {
				level := parseLevel(req.Params.Level)
				slog.Log(ctx, level, "MCP log", "name", name, "logger", req.Params.Logger, "data", req.Params.Data)
			},
		},
	)

	session, err := client.Connect(mcpCtx, transport, nil)
	if err != nil {
		err = maybeStdioErr(err, transport)
		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
		slog.Error("MCP client failed to initialize", "error", err, "name", name)
		cancel()
		cancelTimer.Stop()
		return nil, err
	}

	// Stop reports false if the timer already fired, i.e. the timeout elapsed
	// just as Connect returned. mcpCtx has then been cancelled (for stdio this
	// kills the server process), so the session is unusable: report the
	// timeout instead of handing back a dead session as a success.
	if !cancelTimer.Stop() {
		session.Close()
		err := fmt.Errorf("timed out after %s", timeout)
		updateState(name, StateError, err, nil, Counts{})
		slog.Error("MCP client failed to initialize", "error", err, "name", name)
		return nil, err
	}

	slog.Debug("MCP client initialized", "name", name)
	return session, nil
}
357
// maybeStdioErr adds diagnostic output to an io.EOF failure from a stdio MCP
// server.
//
// When a stdio server prints something that is not JSON (typically an error
// message, e.g. npx failing because node cannot be found), the client cannot
// parse it and closes the connection, which surfaces only as io.EOF. For that
// case the server's command is re-run via stdioCheck and any failure output it
// produces is joined onto err. Every other error, and every non-stdio
// transport, is returned unchanged.
func maybeStdioErr(err error, transport mcp.Transport) error {
	ct, isCommand := transport.(*mcp.CommandTransport)
	if !isCommand || !errors.Is(err, io.EOF) {
		return err
	}

	detail := stdioCheck(ct.Command)
	if detail == nil {
		return err
	}
	return errors.Join(err, detail)
}
378
// maybeTimeoutErr turns a context-cancellation error (possibly wrapped) into a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// It is intended for contexts cancelled by a timeout timer (as in
// createSession), so cancellation for any other reason is reported the same way.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.Canceled) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
385
// createTransport builds the client-side transport described by m.
//
//   - config.MCPStdio: resolves m.Command through resolver, expands it with
//     home.Long, and runs it with m.Args. The process inherits the current
//     environment plus m.ResolvedEnv() and is bound to ctx via
//     exec.CommandContext.
//   - config.MCPHttp / config.MCPSSE: connects to m.URL with the streamable
//     HTTP or SSE transport, adding m.ResolvedHeaders() to every request.
//
// It returns an error for an unresolvable or blank command, a blank URL, or
// an unsupported m.Type.
func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		if err != nil {
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		}
		if strings.TrimSpace(command) == "" {
			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
		}
		proc := exec.CommandContext(ctx, home.Long(command), m.Args...)
		proc.Env = append(os.Environ(), m.ResolvedEnv()...)
		return &mcp.CommandTransport{Command: proc}, nil

	case config.MCPHttp, config.MCPSSE:
		kind := "http"
		if m.Type == config.MCPSSE {
			kind = "sse"
		}
		if strings.TrimSpace(m.URL) == "" {
			return nil, fmt.Errorf("mcp %s config requires a non-empty 'url' field", kind)
		}
		httpClient := &http.Client{
			Transport: &headerRoundTripper{headers: m.ResolvedHeaders()},
		}
		if kind == "sse" {
			return &mcp.SSEClientTransport{Endpoint: m.URL, HTTPClient: httpClient}, nil
		}
		return &mcp.StreamableClientTransport{Endpoint: m.URL, HTTPClient: httpClient}, nil

	default:
		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
	}
}
431
432type headerRoundTripper struct {
433	headers map[string]string
434}
435
436func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
437	for k, v := range rt.headers {
438		req.Header.Set(k, v)
439	}
440	return http.DefaultTransport.RoundTrip(req)
441}
442
443func mcpTimeout(m config.MCPConfig) time.Duration {
444	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
445}
446
// stdioCheck re-runs old's command line (same path, argv, environment and
// working directory) for at most 5 seconds, to capture why it failed.
//
// It returns nil when the re-run exits successfully or is still running at
// the 5 second limit, since neither case explains the original failure.
// Otherwise it returns the exit error wrapped together with the command's
// combined stdout/stderr output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path)
	// old.Args already starts with argv[0]. Passing old.Args... as extra
	// arguments would prepend the program name to the real arguments
	// (e.g. running "npx npx -y pkg"), so copy argv wholesale instead.
	cmd.Args = old.Args
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, out)
}