init.go

  1package mcp
  2
  3import (
  4	"cmp"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"log/slog"
 10	"maps"
 11	"net/http"
 12	"os"
 13	"os/exec"
 14	"strings"
 15	"sync"
 16	"time"
 17
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/home"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/modelcontextprotocol/go-sdk/mcp"
 25)
 26
// Package-level registries shared by every function in this file.
var (
	// sessions maps an MCP name to its live client session. Entries are added
	// once a client connects and removed by updateState on StateError.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states maps an MCP name to the last ClientInfo recorded by updateState.
	states = csync.NewMap[string, ClientInfo]()
	// broker fans out Event values to subscribers (see SubscribeEvents).
	broker = pubsub.NewBroker[Event]()
)
 32
 33// State represents the current state of an MCP client
 34type State int
 35
 36const (
 37	StateDisabled State = iota
 38	StateStarting
 39	StateConnected
 40	StateError
 41)
 42
 43func (s State) String() string {
 44	switch s {
 45	case StateDisabled:
 46		return "disabled"
 47	case StateStarting:
 48		return "starting"
 49	case StateConnected:
 50		return "connected"
 51	case StateError:
 52		return "error"
 53	default:
 54		return "unknown"
 55	}
 56}
 57
// EventType identifies the kind of Event published on the MCP event broker.
type EventType string

const (
	// EventStateChanged is published by updateState whenever a client's
	// State is recorded.
	EventStateChanged       EventType = "state_changed"
	// EventToolsListChanged is published when a server notifies that its
	// tool list changed.
	EventToolsListChanged   EventType = "tools_list_changed"
	// EventPromptsListChanged is published when a server notifies that its
	// prompt list changed.
	EventPromptsListChanged EventType = "prompts_list_changed"
)
 66
// Event is the payload published on the MCP event broker.
//
// For EventStateChanged every field is populated (see updateState). The
// list-changed events set only Type and Name.
type Event struct {
	Type   EventType // kind of event
	Name   string    // MCP name as configured
	State  State     // client state (EventStateChanged only)
	Error  error     // error that accompanied the state change, if any
	Counts Counts    // tool/prompt counts reported with the state change
}
 75
// Counts holds the number of tools and prompts an MCP client exposes.
type Counts struct {
	Tools   int
	Prompts int
}
 81
// ClientInfo is the last recorded state of an MCP client. It is stored by
// updateState and returned by GetState and GetStates.
type ClientInfo struct {
	Name        string
	State       State
	Error       error              // error passed with the state, if any
	Client      *mcp.ClientSession // callers in this file pass a session only with StateConnected
	Counts      Counts
	ConnectedAt time.Time // set by updateState for StateConnected; zero otherwise
}
 91
 92// SubscribeEvents returns a channel for MCP events
 93func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
 94	return broker.Subscribe(ctx)
 95}
 96
 97// GetStates returns the current state of all MCP clients
 98func GetStates() map[string]ClientInfo {
 99	return maps.Collect(states.Seq2())
100}
101
102// GetState returns the state of a specific MCP client
103func GetState(name string) (ClientInfo, bool) {
104	return states.Get(name)
105}
106
// Close closes every registered MCP session and then shuts down the event
// broker. It is meant to be called once, during application shutdown.
//
// Errors that are expected while tearing a client down are ignored: io.EOF,
// context.Canceled, and a child process reporting "signal: killed". Any other
// close errors are joined and returned.
func Close() error {
	expected := func(err error) bool {
		return errors.Is(err, io.EOF) ||
			errors.Is(err, context.Canceled) ||
			err.Error() == "signal: killed"
	}

	var errs []error
	for name, session := range sessions.Seq2() {
		err := session.Close()
		if err == nil || expected(err) {
			continue
		}
		errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
	}

	broker.Shutdown()
	return errors.Join(errs...)
}
121
// Initialize starts every MCP client configured in cfg.MCP and blocks until
// each one has either connected or failed.
//
// Disabled entries are recorded as StateDisabled and skipped. Every other
// entry is set to StateStarting and initialized on its own goroutine. Each
// one ends in StateConnected, with its tools and prompts registered, or in
// StateError. A panic while initializing one client is recovered and
// recorded as that client's error state.
func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
	var wg sync.WaitGroup
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateState(name, StateDisabled, nil, nil, Counts{})
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		// Set initial starting state
		updateState(name, StateStarting, nil, nil, Counts{})

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred first so it runs last. Any error state recorded by the
			// recover handler below is then in place before wg.Wait returns.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateState(name, StateError, err, nil, Counts{})
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			// startCtx bounds the whole startup (connect plus listing). It is
			// deliberately NOT passed to createSession. The context given to
			// createSession becomes the session's lifetime context, because
			// the stdio command is started with exec.CommandContext on it.
			// startCtx is cancelled as soon as this goroutine returns, so
			// passing it would kill the server right after a successful
			// start. createSession enforces the same timeout on its own.
			startCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
			defer cancel()

			session, err := createSession(ctx, name, m, cfg.Resolver())
			if err != nil {
				// createSession has already recorded the error state.
				return
			}

			tools, err := getTools(startCtx, name, permissions, session, cfg.WorkingDir())
			if err != nil {
				slog.Error("error listing tools", "error", err)
				updateState(name, StateError, err, nil, Counts{})
				session.Close()
				return
			}

			prompts, err := getPrompts(startCtx, session)
			if err != nil {
				slog.Error("error listing prompts", "error", err)
				updateState(name, StateError, err, nil, Counts{})
				session.Close()
				return
			}

			updateTools(name, tools)
			updatePrompts(name, prompts)
			sessions.Set(name, session)

			updateState(name, StateConnected, nil, session, Counts{
				Tools:   len(tools),
				Prompts: len(prompts),
			})
		}(name, m)
	}
	wg.Wait()
}
191
// getOrRenewClient returns the live session for the named MCP. If the cached
// session does not answer a ping, it reconnects.
//
// It returns an error if no session is registered under name. When the ping
// fails, the client is moved to StateError, the stale session is closed, and
// a new session is created. On success the client returns to StateConnected
// with the counts it had before.
func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
	sess, ok := sessions.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	cfg := config.Get()
	m := cfg.MCP[name]
	state, _ := states.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := sess.Ping(pingCtx, nil)
	if err == nil {
		return sess, nil
	}
	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)

	// updateState only drops the stale session from the registry. Close it so
	// its connection (and, for stdio, its child process) is released. This
	// runs in the background because closing an unresponsive server may block.
	go func(stale *mcp.ClientSession) {
		if err := stale.Close(); err != nil {
			slog.Debug("error closing stale mcp session", "name", name, "error", err)
		}
	}(sess)

	// The new session is cached globally and outlives this call, so it must
	// not be torn down when the caller's ctx is cancelled. createSession
	// still bounds the connect phase with its own timeout.
	sess, err = createSession(context.WithoutCancel(ctx), name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	updateState(name, StateConnected, nil, sess, state.Counts)
	sessions.Set(name, sess)
	return sess, nil
}
220
221// updateState updates the state of an MCP client and publishes an event
222func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
223	info := ClientInfo{
224		Name:   name,
225		State:  state,
226		Error:  err,
227		Client: client,
228		Counts: counts,
229	}
230	switch state {
231	case StateConnected:
232		info.ConnectedAt = time.Now()
233	case StateError:
234		updateTools(name, nil)
235		sessions.Del(name)
236	}
237	states.Set(name, info)
238
239	// Publish state change event
240	broker.Publish(pubsub.UpdatedEvent, Event{
241		Type:   EventStateChanged,
242		Name:   name,
243		State:  state,
244		Error:  err,
245		Counts: counts,
246	})
247}
248
// createSession connects to the MCP server described by m and returns the
// resulting client session.
//
// The connect phase is bounded by mcpTimeout(m). This uses a timer that
// cancels a child of ctx rather than context.WithTimeout, because that same
// child context is handed to the transport (for stdio it governs the child
// process) and must stay alive once the connection succeeds. On success only
// the timer is stopped.
//
// On failure the client is moved to StateError, the error is logged, the
// child context is cancelled, and the error is returned. Tool and prompt
// list-change notifications from the server are republished on the broker
// as EventToolsListChanged and EventPromptsListChanged.
func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
	timeout := mcpTimeout(m)
	sessionCtx, cancel := context.WithCancel(ctx)
	timer := time.AfterFunc(timeout, cancel)

	// fail records stateErr as the client's error state, logs logErr under
	// msg, releases the connect context, and returns logErr to the caller.
	fail := func(msg string, stateErr, logErr error) (*mcp.ClientSession, error) {
		updateState(name, StateError, stateErr, nil, Counts{})
		slog.Error(msg, "error", logErr, "name", name)
		cancel()
		timer.Stop()
		return nil, logErr
	}

	transport, err := createTransport(sessionCtx, m, resolver)
	if err != nil {
		return fail("error creating mcp client", err, err)
	}

	publish := func(kind EventType) {
		broker.Publish(pubsub.UpdatedEvent, Event{Type: kind, Name: name})
	}
	client := mcp.NewClient(
		&mcp.Implementation{
			Name:    "crush",
			Version: version.Version,
			Title:   "Crush",
		},
		&mcp.ClientOptions{
			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
				publish(EventToolsListChanged)
			},
			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
				publish(EventPromptsListChanged)
			},
			KeepAlive: 10 * time.Minute,
		},
	)

	session, err := client.Connect(sessionCtx, transport, nil)
	if err != nil {
		err = maybeStdioErr(err, transport)
		return fail("error starting mcp client", maybeTimeoutErr(err, timeout), err)
	}

	timer.Stop()
	slog.Info("Initialized mcp client", "name", name)
	return session, nil
}
300
// maybeStdioErr adds detail to an io.EOF connect error from a stdio MCP
// server.
//
// A stdio server that prints non-JSON output, typically a startup failure
// such as npx being unable to find node, makes the client fail to parse the
// stream and close it. That surfaces as io.EOF. In that case the command is
// re-run via stdioCheck to capture its output, and any resulting error is
// joined onto err. Other errors, and errors from non-command transports, are
// returned unchanged.
func maybeStdioErr(err error, transport mcp.Transport) error {
	ct, isCommand := transport.(*mcp.CommandTransport)
	if !isCommand || !errors.Is(err, io.EOF) {
		return err
	}
	if detail := stdioCheck(ct.Command); detail != nil {
		return errors.Join(err, detail)
	}
	return err
}
321
// maybeTimeoutErr replaces a context cancellation or deadline error with a
// clearer "timed out after <timeout>" error.
//
// Callers in this package bound operations in two ways:
//   - a timer that cancels the context (createSession), which yields
//     context.Canceled;
//   - context.WithTimeout (the ping in getOrRenewClient), which yields
//     context.DeadlineExceeded.
//
// Both are reported the same way. Any other error, including nil, is returned
// unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
328
// createTransport builds the MCP transport for m.
//
//   - config.MCPStdio: m.Command is resolved through resolver and must be
//     non-blank. The result is passed through home.Long and started with
//     exec.CommandContext(ctx, ...), so the process is tied to ctx. It
//     inherits os.Environ() plus m.ResolvedEnv().
//   - config.MCPHttp and config.MCPSSE: m.URL must be non-blank. Requests go
//     through an http.Client whose transport adds m.ResolvedHeaders().
//
// Any other m.Type yields an "unsupported mcp type" error.
func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
	// newHeaderClient is only invoked by the HTTP-based branches, so headers
	// are resolved only when they are actually needed.
	newHeaderClient := func() *http.Client {
		return &http.Client{
			Transport: &headerRoundTripper{headers: m.ResolvedHeaders()},
		}
	}

	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		if err != nil {
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		}
		if strings.TrimSpace(command) == "" {
			return nil, errors.New("mcp stdio config requires a non-empty 'command' field")
		}
		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
		return &mcp.CommandTransport{Command: cmd}, nil

	case config.MCPHttp:
		if strings.TrimSpace(m.URL) == "" {
			return nil, errors.New("mcp http config requires a non-empty 'url' field")
		}
		return &mcp.StreamableClientTransport{
			Endpoint:   m.URL,
			HTTPClient: newHeaderClient(),
		}, nil

	case config.MCPSSE:
		if strings.TrimSpace(m.URL) == "" {
			return nil, errors.New("mcp sse config requires a non-empty 'url' field")
		}
		return &mcp.SSEClientTransport{
			Endpoint:   m.URL,
			HTTPClient: newHeaderClient(),
		}, nil

	default:
		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
	}
}
374
375type headerRoundTripper struct {
376	headers map[string]string
377}
378
379func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
380	for k, v := range rt.headers {
381		req.Header.Set(k, v)
382	}
383	return http.DefaultTransport.RoundTrip(req)
384}
385
386func mcpTimeout(m config.MCPConfig) time.Duration {
387	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
388}
389
// stdioCheck re-runs the command described by old for at most five seconds.
// The re-run uses the same path, argv, environment and working directory.
//
// If the command exits with a failure, stdioCheck returns that error,
// annotated with the command's combined stdout/stderr when there is any.
// It returns nil in three cases: the command succeeds, it is still running
// at the deadline (so it did not fail fast), or old is nil.
func stdioCheck(old *exec.Cmd) error {
	if old == nil {
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// old.Args already starts with argv[0], so it is copied verbatim.
	// Passing it as extra arguments would repeat the program name as its
	// own first argument (e.g. "npx npx -y ...").
	cmd := exec.CommandContext(ctx, old.Path)
	cmd.Args = append([]string(nil), old.Args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir

	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	if msg := strings.TrimSpace(string(out)); msg != "" {
		return fmt.Errorf("%w: %s", err, msg)
	}
	return err
}