init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
// Package-level MCP client bookkeeping shared by the functions in this file.
var (
	// sessions holds the live session for each connected MCP client, keyed by
	// the client name from the configuration. Entries are removed when a
	// client enters StateError.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states holds the most recent ClientInfo recorded for each configured
	// MCP client, including disabled and failed ones.
	states   = csync.NewMap[string, ClientInfo]()
	// broker publishes Event values to listeners registered through
	// AddEventListener.
	broker   = pubsub.NewBroker[Event]()
	// initOnce guarantees initDone is closed at most once, even if
	// Initialize is called repeatedly.
	initOnce sync.Once
	// initDone is closed when Initialize finishes; WaitForInit blocks on it.
	initDone = make(chan struct{})
)
 35
 36// State represents the current state of an MCP client
 37type State int
 38
 39const (
 40	StateDisabled State = iota
 41	StateStarting
 42	StateConnected
 43	StateError
 44)
 45
 46func (s State) String() string {
 47	switch s {
 48	case StateDisabled:
 49		return "disabled"
 50	case StateStarting:
 51		return "starting"
 52	case StateConnected:
 53		return "connected"
 54	case StateError:
 55		return "error"
 56	default:
 57		return "unknown"
 58	}
 59}
 60
// EventType identifies which kind of MCP event was published on the broker.
type EventType uint

const (
	// EventStateChanged is published by updateState each time a client's
	// State is recorded.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a server sends a tool list
	// changed notification.
	EventToolsListChanged
	// EventPromptsListChanged is published when a server sends a prompt list
	// changed notification.
	EventPromptsListChanged
)
 69
// Event is the payload published on the MCP event broker.
//
// EventStateChanged events populate every field; the tool/prompt
// list-changed events set only Type and Name.
type Event struct {
	Type   EventType
	Name   string // MCP client name from the configuration
	State  State
	Error  error
	Counts Counts
}
 78
// Counts holds the number of tools and prompts recorded for an MCP client.
type Counts struct {
	Tools   int
	Prompts int
}
 84
// ClientInfo is the snapshot of an MCP client's state that updateState
// stores and GetState/GetStates return.
type ClientInfo struct {
	Name        string
	State       State
	Error       error              // failure reason, passed for StateError entries
	Client      *mcp.ClientSession // live session, passed for StateConnected entries
	Counts      Counts
	ConnectedAt time.Time // only set when State is StateConnected
}
 94
 95// AddEventListener registers a callback for MCP events.
 96func AddEventListener(fn func(pubsub.Event[Event])) {
 97	broker.AddListener(fn)
 98}
 99
100// GetStates returns the current state of all MCP clients
101func GetStates() map[string]ClientInfo {
102	return states.Copy()
103}
104
105// GetState returns the state of a specific MCP client
106func GetState(name string) (ClientInfo, bool) {
107	return states.Get(name)
108}
109
110// Close closes all MCP clients. This should be called during application shutdown.
111func Close() error {
112	var wg sync.WaitGroup
113	done := make(chan struct{}, 1)
114	go func() {
115		for name, session := range sessions.Seq2() {
116			wg.Go(func() {
117				if err := session.Close(); err != nil &&
118					!errors.Is(err, io.EOF) &&
119					!errors.Is(err, context.Canceled) &&
120					err.Error() != "signal: killed" {
121					slog.Warn("Failed to shutdown MCP client", "name", name, "error", err)
122				}
123			})
124		}
125		wg.Wait()
126		done <- struct{}{}
127	}()
128	select {
129	case <-done:
130	case <-time.After(5 * time.Second):
131	}
132	return nil
133}
134
135// Initialize initializes MCP clients based on the provided configuration.
136func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
137	var wg sync.WaitGroup
138	// Initialize states for all configured MCPs
139	for name, m := range cfg.MCP {
140		if m.Disabled {
141			updateState(ctx, name, StateDisabled, nil, nil, Counts{})
142			slog.Debug("skipping disabled mcp", "name", name)
143			continue
144		}
145
146		// Set initial starting state
147		updateState(ctx, name, StateStarting, nil, nil, Counts{})
148
149		wg.Add(1)
150		go func(name string, m config.MCPConfig) {
151			defer func() {
152				wg.Done()
153				if r := recover(); r != nil {
154					var err error
155					switch v := r.(type) {
156					case error:
157						err = v
158					case string:
159						err = fmt.Errorf("panic: %s", v)
160					default:
161						err = fmt.Errorf("panic: %v", v)
162					}
163					updateState(ctx, name, StateError, err, nil, Counts{})
164					slog.Error("panic in mcp client initialization", "error", err, "name", name)
165				}
166			}()
167
168			// createSession handles its own timeout internally.
169			session, err := createSession(ctx, name, m, cfg.Resolver())
170			if err != nil {
171				return
172			}
173
174			tools, err := getTools(ctx, session)
175			if err != nil {
176				slog.Error("error listing tools", "error", err)
177				updateState(ctx, name, StateError, err, nil, Counts{})
178				session.Close()
179				return
180			}
181
182			prompts, err := getPrompts(ctx, session)
183			if err != nil {
184				slog.Error("error listing prompts", "error", err)
185				updateState(ctx, name, StateError, err, nil, Counts{})
186				session.Close()
187				return
188			}
189
190			toolCount := updateTools(name, tools)
191			updatePrompts(name, prompts)
192			sessions.Set(name, session)
193
194			updateState(ctx, name, StateConnected, nil, session, Counts{
195				Tools:   toolCount,
196				Prompts: len(prompts),
197			})
198		}(name, m)
199	}
200	wg.Wait()
201	initOnce.Do(func() { close(initDone) })
202}
203
204// WaitForInit blocks until MCP initialization is complete.
205// If Initialize was never called, this returns immediately.
206func WaitForInit(ctx context.Context) error {
207	select {
208	case <-initDone:
209		return nil
210	case <-ctx.Done():
211		return ctx.Err()
212	}
213}
214
215func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
216	sess, ok := sessions.Get(name)
217	if !ok {
218		return nil, fmt.Errorf("mcp '%s' not available", name)
219	}
220
221	cfg := config.Get()
222	m := cfg.MCP[name]
223	state, _ := states.Get(name)
224
225	timeout := mcpTimeout(m)
226	pingCtx, cancel := context.WithTimeout(ctx, timeout)
227	defer cancel()
228	err := sess.Ping(pingCtx, nil)
229	if err == nil {
230		return sess, nil
231	}
232	updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
233
234	sess, err = createSession(ctx, name, m, cfg.Resolver())
235	if err != nil {
236		return nil, err
237	}
238
239	updateState(ctx, name, StateConnected, nil, sess, state.Counts)
240	sessions.Set(name, sess)
241	return sess, nil
242}
243
244// updateState updates the state of an MCP client and publishes an event
245func updateState(ctx context.Context, name string, state State, err error, client *mcp.ClientSession, counts Counts) {
246	info := ClientInfo{
247		Name:   name,
248		State:  state,
249		Error:  err,
250		Client: client,
251		Counts: counts,
252	}
253	switch state {
254	case StateConnected:
255		info.ConnectedAt = time.Now()
256	case StateError:
257		sessions.Del(name)
258	}
259	states.Set(name, info)
260
261	// Publish state change event
262	broker.Publish(ctx, pubsub.UpdatedEvent, Event{
263		Type:   EventStateChanged,
264		Name:   name,
265		State:  state,
266		Error:  err,
267		Counts: counts,
268	})
269}
270
271func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
272	timeout := mcpTimeout(m)
273	mcpCtx, cancel := context.WithCancel(ctx)
274	cancelTimer := time.AfterFunc(timeout, cancel)
275
276	transport, err := createTransport(mcpCtx, m, resolver)
277	if err != nil {
278		updateState(ctx, name, StateError, err, nil, Counts{})
279		slog.Error("error creating mcp client", "error", err, "name", name)
280		cancel()
281		cancelTimer.Stop()
282		return nil, err
283	}
284
285	client := mcp.NewClient(
286		&mcp.Implementation{
287			Name:    "crush",
288			Version: version.Version,
289			Title:   "Crush",
290		},
291		&mcp.ClientOptions{
292			ToolListChangedHandler: func(ctx context.Context, _ *mcp.ToolListChangedRequest) {
293				broker.Publish(ctx, pubsub.UpdatedEvent, Event{
294					Type: EventToolsListChanged,
295					Name: name,
296				})
297			},
298			PromptListChangedHandler: func(ctx context.Context, _ *mcp.PromptListChangedRequest) {
299				broker.Publish(ctx, pubsub.UpdatedEvent, Event{
300					Type: EventPromptsListChanged,
301					Name: name,
302				})
303			},
304			LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
305				slog.Info("MCP log", "name", name, "data", req.Params.Data)
306			},
307		},
308	)
309
310	session, err := client.Connect(mcpCtx, transport, nil)
311	if err != nil {
312		err = maybeStdioErr(err, transport)
313		updateState(ctx, name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
314		slog.Error("MCP client failed to initialize", "error", err, "name", name)
315		cancel()
316		cancelTimer.Stop()
317		return nil, err
318	}
319
320	cancelTimer.Stop()
321	slog.Info("MCP client initialized", "name", name)
322	return session, nil
323}
324
325// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
326// to parse, and the cli will then close it, causing the EOF error.
327// so, if we got an EOF err, and the transport is STDIO, we try to exec it
328// again with a timeout and collect the output so we can add details to the
329// error.
330// this happens particularly when starting things with npx, e.g. if node can't
331// be found or some other error like that.
332func maybeStdioErr(err error, transport mcp.Transport) error {
333	if !errors.Is(err, io.EOF) {
334		return err
335	}
336	ct, ok := transport.(*mcp.CommandTransport)
337	if !ok {
338		return err
339	}
340	if err2 := stdioCheck(ct.Command); err2 != nil {
341		err = errors.Join(err, err2)
342	}
343	return err
344}
345
// maybeTimeoutErr replaces a context cancellation or deadline error with a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// Both sentinels are checked because operations in this file are bounded in
// two ways:
//   - createSession cancels its context from a timer, which yields
//     context.Canceled;
//   - getOrRenewClient pings under context.WithTimeout, which yields
//     context.DeadlineExceeded.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
352
353func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
354	switch m.Type {
355	case config.MCPStdio:
356		command, err := resolver.ResolveValue(m.Command)
357		if err != nil {
358			return nil, fmt.Errorf("invalid mcp command: %w", err)
359		}
360		if strings.TrimSpace(command) == "" {
361			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
362		}
363		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
364		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
365		return &mcp.CommandTransport{
366			Command: cmd,
367		}, nil
368	case config.MCPHttp:
369		if strings.TrimSpace(m.URL) == "" {
370			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
371		}
372		client := &http.Client{
373			Transport: &headerRoundTripper{
374				headers: m.ResolvedHeaders(),
375			},
376		}
377		return &mcp.StreamableClientTransport{
378			Endpoint:   m.URL,
379			HTTPClient: client,
380		}, nil
381	case config.MCPSSE:
382		if strings.TrimSpace(m.URL) == "" {
383			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
384		}
385		client := &http.Client{
386			Transport: &headerRoundTripper{
387				headers: m.ResolvedHeaders(),
388			},
389		}
390		return &mcp.SSEClientTransport{
391			Endpoint:   m.URL,
392			HTTPClient: client,
393		}, nil
394	default:
395		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
396	}
397}
398
// headerRoundTripper is an http.RoundTripper that sets a fixed set of headers
// on every outgoing request before delegating to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req with rt.headers applied, replacing any existing values
// for the same keys.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone rather than on req itself.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}
	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
409
410func mcpTimeout(m config.MCPConfig) time.Duration {
411	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
412}
413
// stdioCheck re-runs the program described by old (same path, arguments,
// environment and working directory) to capture its diagnostic output.
//
// It returns nil when the re-run succeeds, or when it is still running after
// the 5-second probe timeout (it did not fail fast). Otherwise it returns the
// run error annotated with the command's combined stdout/stderr.
func stdioCheck(old *exec.Cmd) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	// exec.Cmd.Args holds the program name in Args[0]. CommandContext
	// prepends the name itself, so passing all of Args would duplicate it
	// (running e.g. "npx npx -y pkg") and capture the wrong error.
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}
	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}