init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
var (
	// sessions caches the live session of each connected MCP server, keyed by
	// the server's configured name. updateState removes the entry when a
	// server moves to StateError.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states holds the latest ClientInfo recorded by updateState for each
	// configured server, keyed by name.
	states   = csync.NewMap[string, ClientInfo]()
	// broker publishes Event values to subscribers registered through
	// SubscribeEvents.
	broker   = pubsub.NewBroker[Event]()
)
 33
// State represents the current state of an MCP client.
type State int

const (
	// StateDisabled marks a server that is disabled in the configuration and
	// therefore not started.
	StateDisabled State = iota
	// StateStarting marks a server whose session is being set up.
	StateStarting
	// StateConnected marks a server with a working session.
	StateConnected
	// StateError marks a server that failed to start, or whose session
	// failed a health check and could not be recreated.
	StateError
)
 43
 44func (s State) String() string {
 45	switch s {
 46	case StateDisabled:
 47		return "disabled"
 48	case StateStarting:
 49		return "starting"
 50	case StateConnected:
 51		return "connected"
 52	case StateError:
 53		return "error"
 54	default:
 55		return "unknown"
 56	}
 57}
 58
// EventType represents the type of MCP event.
type EventType uint

const (
	// EventStateChanged is published by updateState whenever a client's
	// State is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a server notifies the client
	// that its tool list changed.
	EventToolsListChanged
	// EventPromptsListChanged is published when a server notifies the client
	// that its prompt list changed.
	EventPromptsListChanged
)
 67
// Event represents an event in the MCP system, delivered to subscribers
// through SubscribeEvents.
//
// Type says what happened and Name identifies the MCP server by its
// configured name. For EventStateChanged events, State, Error and Counts
// mirror the values passed to updateState. The list-changed events published
// from createSession set only Type and Name and leave the rest zero.
type Event struct {
	Type   EventType
	Name   string
	State  State
	Error  error
	Counts Counts
}
 76
// Counts records how many tools and prompts an MCP server provides.
//
// Initialize fills Tools from the value returned by updateTools and Prompts
// from the number of prompts listed by the server.
type Counts struct {
	Tools   int
	Prompts int
}
 82
// ClientInfo holds information about an MCP client's state, as recorded by
// updateState and returned by GetState and GetStates.
//
// Error and Client are whatever updateState was given. Within this file,
// Error is only non-nil for StateError and Client is only non-nil for
// StateConnected. ConnectedAt is the time of the update when State is
// StateConnected and is the zero time otherwise.
type ClientInfo struct {
	Name        string
	State       State
	Error       error
	Client      *mcp.ClientSession
	Counts      Counts
	ConnectedAt time.Time
}
 92
 93// SubscribeEvents returns a channel for MCP events
 94func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
 95	return broker.Subscribe(ctx)
 96}
 97
 98// GetStates returns the current state of all MCP clients
 99func GetStates() map[string]ClientInfo {
100	return states.Copy()
101}
102
103// GetState returns the state of a specific MCP client
104func GetState(name string) (ClientInfo, bool) {
105	return states.Get(name)
106}
107
108// Close closes all MCP clients. This should be called during application shutdown.
109func Close() error {
110	var errs []error
111	var wg sync.WaitGroup
112	for name, session := range sessions.Seq2() {
113		wg.Go(func() {
114			done := make(chan bool, 1)
115			go func() {
116				if err := session.Close(); err != nil &&
117					!errors.Is(err, io.EOF) &&
118					!errors.Is(err, context.Canceled) &&
119					err.Error() != "signal: killed" {
120					errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
121				}
122				done <- true
123			}()
124			select {
125			case <-done:
126			case <-time.After(time.Millisecond * 250):
127			}
128		})
129	}
130	wg.Wait()
131	broker.Shutdown()
132	return errors.Join(errs...)
133}
134
135// Initialize initializes MCP clients based on the provided configuration.
136func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
137	var wg sync.WaitGroup
138	// Initialize states for all configured MCPs
139	for name, m := range cfg.MCP {
140		if m.Disabled {
141			updateState(name, StateDisabled, nil, nil, Counts{})
142			slog.Debug("skipping disabled mcp", "name", name)
143			continue
144		}
145
146		// Set initial starting state
147		updateState(name, StateStarting, nil, nil, Counts{})
148
149		wg.Add(1)
150		go func(name string, m config.MCPConfig) {
151			defer func() {
152				wg.Done()
153				if r := recover(); r != nil {
154					var err error
155					switch v := r.(type) {
156					case error:
157						err = v
158					case string:
159						err = fmt.Errorf("panic: %s", v)
160					default:
161						err = fmt.Errorf("panic: %v", v)
162					}
163					updateState(name, StateError, err, nil, Counts{})
164					slog.Error("panic in mcp client initialization", "error", err, "name", name)
165				}
166			}()
167
168			// createSession handles its own timeout internally.
169			session, err := createSession(ctx, name, m, cfg.Resolver())
170			if err != nil {
171				return
172			}
173
174			tools, err := getTools(ctx, session)
175			if err != nil {
176				slog.Error("error listing tools", "error", err)
177				updateState(name, StateError, err, nil, Counts{})
178				session.Close()
179				return
180			}
181
182			prompts, err := getPrompts(ctx, session)
183			if err != nil {
184				slog.Error("error listing prompts", "error", err)
185				updateState(name, StateError, err, nil, Counts{})
186				session.Close()
187				return
188			}
189
190			toolCount := updateTools(name, tools)
191			updatePrompts(name, prompts)
192			sessions.Set(name, session)
193
194			updateState(name, StateConnected, nil, session, Counts{
195				Tools:   toolCount,
196				Prompts: len(prompts),
197			})
198		}(name, m)
199	}
200	wg.Wait()
201}
202
203func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
204	sess, ok := sessions.Get(name)
205	if !ok {
206		return nil, fmt.Errorf("mcp '%s' not available", name)
207	}
208
209	cfg := config.Get()
210	m := cfg.MCP[name]
211	state, _ := states.Get(name)
212
213	timeout := mcpTimeout(m)
214	pingCtx, cancel := context.WithTimeout(ctx, timeout)
215	defer cancel()
216	err := sess.Ping(pingCtx, nil)
217	if err == nil {
218		return sess, nil
219	}
220	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
221
222	sess, err = createSession(ctx, name, m, cfg.Resolver())
223	if err != nil {
224		return nil, err
225	}
226
227	updateState(name, StateConnected, nil, sess, state.Counts)
228	sessions.Set(name, sess)
229	return sess, nil
230}
231
232// updateState updates the state of an MCP client and publishes an event
233func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
234	info := ClientInfo{
235		Name:   name,
236		State:  state,
237		Error:  err,
238		Client: client,
239		Counts: counts,
240	}
241	switch state {
242	case StateConnected:
243		info.ConnectedAt = time.Now()
244	case StateError:
245		sessions.Del(name)
246	}
247	states.Set(name, info)
248
249	// Publish state change event
250	broker.Publish(pubsub.UpdatedEvent, Event{
251		Type:   EventStateChanged,
252		Name:   name,
253		State:  state,
254		Error:  err,
255		Counts: counts,
256	})
257}
258
259func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
260	timeout := mcpTimeout(m)
261	mcpCtx, cancel := context.WithCancel(ctx)
262	cancelTimer := time.AfterFunc(timeout, cancel)
263
264	transport, err := createTransport(mcpCtx, m, resolver)
265	if err != nil {
266		updateState(name, StateError, err, nil, Counts{})
267		slog.Error("error creating mcp client", "error", err, "name", name)
268		cancel()
269		cancelTimer.Stop()
270		return nil, err
271	}
272
273	client := mcp.NewClient(
274		&mcp.Implementation{
275			Name:    "crush",
276			Version: version.Version,
277			Title:   "Crush",
278		},
279		&mcp.ClientOptions{
280			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
281				broker.Publish(pubsub.UpdatedEvent, Event{
282					Type: EventToolsListChanged,
283					Name: name,
284				})
285			},
286			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
287				broker.Publish(pubsub.UpdatedEvent, Event{
288					Type: EventPromptsListChanged,
289					Name: name,
290				})
291			},
292			LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
293				slog.Info("MCP log", "name", name, "data", req.Params.Data)
294			},
295		},
296	)
297
298	session, err := client.Connect(mcpCtx, transport, nil)
299	if err != nil {
300		err = maybeStdioErr(err, transport)
301		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
302		slog.Error("MCP client failed to initialize", "error", err, "name", name)
303		cancel()
304		cancelTimer.Stop()
305		return nil, err
306	}
307
308	cancelTimer.Stop()
309	slog.Info("MCP client initialized", "name", name)
310	return session, nil
311}
312
313// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
314// to parse, and the cli will then close it, causing the EOF error.
315// so, if we got an EOF err, and the transport is STDIO, we try to exec it
316// again with a timeout and collect the output so we can add details to the
317// error.
318// this happens particularly when starting things with npx, e.g. if node can't
319// be found or some other error like that.
320func maybeStdioErr(err error, transport mcp.Transport) error {
321	if !errors.Is(err, io.EOF) {
322		return err
323	}
324	ct, ok := transport.(*mcp.CommandTransport)
325	if !ok {
326		return err
327	}
328	if err2 := stdioCheck(ct.Command); err2 != nil {
329		err = errors.Join(err, err2)
330	}
331	return err
332}
333
// maybeTimeoutErr turns a context-cancellation or deadline error into a
// "timed out after <timeout>" error. Both ways this file enforces timeouts
// are covered:
//   - createSession cancels its context when the timeout fires, which yields
//     context.Canceled;
//   - context.WithTimeout (used for pings) yields context.DeadlineExceeded.
//
// Any other error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
340
341func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
342	switch m.Type {
343	case config.MCPStdio:
344		command, err := resolver.ResolveValue(m.Command)
345		if err != nil {
346			return nil, fmt.Errorf("invalid mcp command: %w", err)
347		}
348		if strings.TrimSpace(command) == "" {
349			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
350		}
351		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
352		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
353		return &mcp.CommandTransport{
354			Command: cmd,
355		}, nil
356	case config.MCPHttp:
357		if strings.TrimSpace(m.URL) == "" {
358			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
359		}
360		client := &http.Client{
361			Transport: &headerRoundTripper{
362				headers: m.ResolvedHeaders(),
363			},
364		}
365		return &mcp.StreamableClientTransport{
366			Endpoint:   m.URL,
367			HTTPClient: client,
368		}, nil
369	case config.MCPSSE:
370		if strings.TrimSpace(m.URL) == "" {
371			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
372		}
373		client := &http.Client{
374			Transport: &headerRoundTripper{
375				headers: m.ResolvedHeaders(),
376			},
377		}
378		return &mcp.SSEClientTransport{
379			Endpoint:   m.URL,
380			HTTPClient: client,
381		}, nil
382	default:
383		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
384	}
385}
386
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request and then hands it to http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps header names to values. Each is applied with Header.Set,
	// so it replaces any value the caller already set for that name.
	headers map[string]string
}

// RoundTrip implements http.RoundTripper.
//
// The RoundTripper contract forbids modifying the caller's request, so the
// headers are applied to a clone. The original request is left as it was.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
397
398func mcpTimeout(m config.MCPConfig) time.Duration {
399	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
400}
401
// stdioCheck runs the command described by old again so that its output can
// be surfaced. old is a stdio MCP server command that failed to connect.
//
// The re-run uses old's path, arguments, environment and working directory,
// and is limited to 5 seconds. It returns nil when:
//   - the command exits successfully, or
//   - it is still running when the limit expires (presumably a server that
//     started fine).
//
// Otherwise it returns the exec error followed by the trimmed combined
// stdout/stderr.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// old.Args is the full argv, including argv[0]. exec.CommandContext
	// supplies its own argv[0], so pass only the real arguments; passing
	// old.Args... would repeat the program name as the first argument.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, strings.TrimSpace(string(out)))
}