init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"maps"
 13	"net/http"
 14	"os"
 15	"os/exec"
 16	"strings"
 17	"sync"
 18	"time"
 19
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/home"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/modelcontextprotocol/go-sdk/mcp"
 27)
 28
// Package-level registries shared by all MCP clients in this package.
//
//   - sessions maps an MCP name to its client session. Entries are stored by
//     Initialize and getOrRenewClient after a successful connection and are
//     removed by updateState when a client enters StateError.
//   - states maps an MCP name to the most recent ClientInfo written by
//     updateState.
//   - broker publishes Event values; consumers obtain a channel through
//     SubscribeEvents, and Close shuts the broker down.
var (
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	states   = csync.NewMap[string, ClientInfo]()
	broker   = pubsub.NewBroker[Event]()
)
 34
 35// State represents the current state of an MCP client
 36type State int
 37
 38const (
 39	StateDisabled State = iota
 40	StateStarting
 41	StateConnected
 42	StateError
 43)
 44
 45func (s State) String() string {
 46	switch s {
 47	case StateDisabled:
 48		return "disabled"
 49	case StateStarting:
 50		return "starting"
 51	case StateConnected:
 52		return "connected"
 53	case StateError:
 54		return "error"
 55	default:
 56		return "unknown"
 57	}
 58}
 59
// EventType represents the type of MCP event.
type EventType uint

// Event types published on the package broker.
//
// EventStateChanged is published by updateState whenever a client's state is
// recorded. The three list-changed types are published by the handlers that
// createSession registers on the MCP client (tool, prompt and resource list
// change handlers respectively).
const (
	EventStateChanged EventType = iota
	EventToolsListChanged
	EventPromptsListChanged
	EventResourcesListChanged
)
 69
// Event represents an event in the MCP system.
//
// State-change events published by updateState populate every field. The
// list-changed events published from createSession set only Type and Name.
type Event struct {
	Type   EventType // kind of event
	Name   string    // name of the MCP the event refers to
	State  State     // state recorded for the MCP
	Error  error     // error passed to updateState, if any
	Counts Counts    // counts passed to updateState
}
 78
// Counts holds the number of tools, prompts and resources listed for an MCP
// client.
type Counts struct {
	Tools     int // number of tools
	Prompts   int // number of prompts
	Resources int // number of resources
}
 85
// ClientInfo holds information about an MCP client's state.
//
// Values are built by updateState and stored in the states map.
type ClientInfo struct {
	Name        string             // name of the MCP
	State       State              // most recently recorded state
	Error       error              // error passed to updateState, if any
	Client      *mcp.ClientSession // session passed to updateState; may be nil
	Counts      Counts             // tool, prompt and resource counts
	ConnectedAt time.Time          // set when State is StateConnected
}
 95
 96// SubscribeEvents returns a channel for MCP events
 97func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
 98	return broker.Subscribe(ctx)
 99}
100
101// GetStates returns the current state of all MCP clients
102func GetStates() map[string]ClientInfo {
103	return maps.Collect(states.Seq2())
104}
105
106// GetState returns the state of a specific MCP client
107func GetState(name string) (ClientInfo, bool) {
108	return states.Get(name)
109}
110
111// Close closes all MCP clients. This should be called during application shutdown.
112func Close() error {
113	var errs []error
114	for name, c := range sessions.Seq2() {
115		if err := c.Close(); err != nil &&
116			!errors.Is(err, io.EOF) &&
117			!errors.Is(err, context.Canceled) &&
118			err.Error() != "signal: killed" {
119			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
120		}
121	}
122	broker.Shutdown()
123	return errors.Join(errs...)
124}
125
126// Initialize initializes MCP clients based on the provided configuration.
127func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
128	var wg sync.WaitGroup
129	// Initialize states for all configured MCPs
130	for name, m := range cfg.MCP {
131		if m.Disabled {
132			updateState(name, StateDisabled, nil, nil, Counts{})
133			slog.Debug("skipping disabled mcp", "name", name)
134			continue
135		}
136
137		// Set initial starting state
138		updateState(name, StateStarting, nil, nil, Counts{})
139
140		wg.Add(1)
141		go func(name string, m config.MCPConfig) {
142			defer func() {
143				wg.Done()
144				if r := recover(); r != nil {
145					var err error
146					switch v := r.(type) {
147					case error:
148						err = v
149					case string:
150						err = fmt.Errorf("panic: %s", v)
151					default:
152						err = fmt.Errorf("panic: %v", v)
153					}
154					updateState(name, StateError, err, nil, Counts{})
155					slog.Error("panic in mcp client initialization", "error", err, "name", name)
156				}
157			}()
158
159			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
160			defer cancel()
161
162			session, err := createSession(ctx, name, m, cfg.Resolver())
163			if err != nil {
164				return
165			}
166
167			tools, err := getTools(ctx, session)
168			if err != nil {
169				slog.Error("error listing tools", "error", err)
170				updateState(name, StateError, err, nil, Counts{})
171				session.Close()
172				return
173			}
174
175			prompts, err := getPrompts(ctx, session)
176			if err != nil {
177				slog.Error("error listing prompts", "error", err)
178				updateState(name, StateError, err, nil, Counts{})
179				session.Close()
180				return
181			}
182
183			resources, err := getResources(ctx, session)
184			if err != nil {
185				slog.Error("error listing resources", "error", err)
186				updateState(name, StateError, err, nil, Counts{})
187				session.Close()
188				return
189			}
190
191			updateTools(name, tools)
192			updatePrompts(name, prompts)
193			updateResources(name, resources)
194			sessions.Set(name, session)
195
196			updateState(name, StateConnected, nil, session, Counts{
197				Tools:     len(tools),
198				Prompts:   len(prompts),
199				Resources: len(resources),
200			})
201		}(name, m)
202	}
203	wg.Wait()
204}
205
206func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
207	sess, ok := sessions.Get(name)
208	if !ok {
209		return nil, fmt.Errorf("mcp '%s' not available", name)
210	}
211
212	cfg := config.Get()
213	m := cfg.MCP[name]
214	state, _ := states.Get(name)
215
216	timeout := mcpTimeout(m)
217	pingCtx, cancel := context.WithTimeout(ctx, timeout)
218	defer cancel()
219	err := sess.Ping(pingCtx, nil)
220	if err == nil {
221		return sess, nil
222	}
223	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
224
225	sess, err = createSession(ctx, name, m, cfg.Resolver())
226	if err != nil {
227		return nil, err
228	}
229
230	updateState(name, StateConnected, nil, sess, state.Counts)
231	sessions.Set(name, sess)
232	return sess, nil
233}
234
235// updateState updates the state of an MCP client and publishes an event
236func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
237	info := ClientInfo{
238		Name:   name,
239		State:  state,
240		Error:  err,
241		Client: client,
242		Counts: counts,
243	}
244	switch state {
245	case StateConnected:
246		info.ConnectedAt = time.Now()
247	case StateError:
248		updateTools(name, nil)
249		sessions.Del(name)
250	}
251	states.Set(name, info)
252
253	// Publish state change event
254	broker.Publish(pubsub.UpdatedEvent, Event{
255		Type:   EventStateChanged,
256		Name:   name,
257		State:  state,
258		Error:  err,
259		Counts: counts,
260	})
261}
262
263func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
264	timeout := mcpTimeout(m)
265	mcpCtx, cancel := context.WithCancel(ctx)
266	cancelTimer := time.AfterFunc(timeout, cancel)
267
268	transport, err := createTransport(mcpCtx, m, resolver)
269	if err != nil {
270		updateState(name, StateError, err, nil, Counts{})
271		slog.Error("error creating mcp client", "error", err, "name", name)
272		cancel()
273		cancelTimer.Stop()
274		return nil, err
275	}
276
277	client := mcp.NewClient(
278		&mcp.Implementation{
279			Name:    "crush",
280			Version: version.Version,
281			Title:   "Crush",
282		},
283		&mcp.ClientOptions{
284			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
285				broker.Publish(pubsub.UpdatedEvent, Event{
286					Type: EventToolsListChanged,
287					Name: name,
288				})
289			},
290			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
291				broker.Publish(pubsub.UpdatedEvent, Event{
292					Type: EventPromptsListChanged,
293					Name: name,
294				})
295			},
296			ResourceListChangedHandler: func(context.Context, *mcp.ResourceListChangedRequest) {
297				broker.Publish(pubsub.UpdatedEvent, Event{
298					Type: EventResourcesListChanged,
299					Name: name,
300				})
301			},
302			LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
303				slog.Log(ctx, parseLevel(req.Params.Level), "mcp server log", "data", req.Params.Data)
304			},
305			KeepAlive: time.Minute * 10,
306		},
307	)
308
309	session, err := client.Connect(mcpCtx, transport, nil)
310	if err != nil {
311		err = maybeStdioErr(err, transport)
312		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
313		slog.Error("error starting mcp client", "error", err, "name", name)
314		cancel()
315		cancelTimer.Stop()
316		return nil, err
317	}
318
319	cancelTimer.Stop()
320	slog.Info("Initialized mcp client", "name", name)
321	return session, nil
322}
323
324func parseLevel(level mcp.LoggingLevel) slog.Level {
325	switch level {
326	case "debug":
327		return slog.LevelDebug
328	case "info", "notice":
329		return slog.LevelInfo
330	case "warning", "warn":
331		return slog.LevelWarn
332	case "error", "emergency", "alert", "fatal":
333		fallthrough
334	default:
335		return slog.LevelError
336	}
337}
338
339// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
340// to parse, and the cli will then close it, causing the EOF error.
341// so, if we got an EOF err, and the transport is STDIO, we try to exec it
342// again with a timeout and collect the output so we can add details to the
343// error.
344// this happens particularly when starting things with npx, e.g. if node can't
345// be found or some other error like that.
346func maybeStdioErr(err error, transport mcp.Transport) error {
347	if !errors.Is(err, io.EOF) {
348		return err
349	}
350	ct, ok := transport.(*mcp.CommandTransport)
351	if !ok {
352		return err
353	}
354	if err2 := stdioCheck(ct.Command); err2 != nil {
355		err = errors.Join(err, err2)
356	}
357	return err
358}
359
// maybeTimeoutErr replaces context cancellation errors with a readable
// "timed out after <timeout>" error and returns every other error (including
// nil) unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are treated as timeouts:
// createSession enforces its timeout by cancelling a context from a timer
// (yielding context.Canceled), while getOrRenewClient pings under
// context.WithTimeout (yielding context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
366
367func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
368	switch m.Type {
369	case config.MCPStdio:
370		command, err := resolver.ResolveValue(m.Command)
371		if err != nil {
372			return nil, fmt.Errorf("invalid mcp command: %w", err)
373		}
374		if strings.TrimSpace(command) == "" {
375			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
376		}
377		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
378		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
379		return &mcp.CommandTransport{
380			Command: cmd,
381		}, nil
382	case config.MCPHttp:
383		if strings.TrimSpace(m.URL) == "" {
384			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
385		}
386		client := &http.Client{
387			Transport: &headerRoundTripper{
388				headers: m.ResolvedHeaders(),
389			},
390		}
391		return &mcp.StreamableClientTransport{
392			Endpoint:   m.URL,
393			HTTPClient: client,
394		}, nil
395	case config.MCPSSE:
396		if strings.TrimSpace(m.URL) == "" {
397			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
398		}
399		client := &http.Client{
400			Transport: &headerRoundTripper{
401				headers: m.ResolvedHeaders(),
402			},
403		}
404		return &mcp.SSEClientTransport{
405			Endpoint:   m.URL,
406			HTTPClient: client,
407		}, nil
408	default:
409		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
410	}
411}
412
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string // header name -> value, set on each request
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers applied, overriding any existing values for the same names.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone and req itself is left untouched. When no
// headers are configured the request is forwarded as-is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header, len(rt.headers))
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
423
424func mcpTimeout(m config.MCPConfig) time.Duration {
425	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
426}
427
// stdioCheck re-runs the command described by old for at most five seconds
// and reports why it fails, if it does.
//
// It returns nil when the command exits successfully or is still running when
// the five-second deadline expires. Otherwise it returns the execution error
// followed by the command's combined stdout and stderr output.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// old.Args[0] is the command name itself, and exec.CommandContext adds
	// that entry again, so only the real arguments are forwarded. Passing
	// old.Args whole would run e.g. "npx npx -y pkg" instead of "npx -y pkg".
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}
437	return fmt.Errorf("%w: %s", err, string(out))
438}