init.go

  1// Package mcp provides functionality for managing Model Context Protocol (MCP)
  2// clients within the Crush application.
  3package mcp
  4
  5import (
  6	"cmp"
  7	"context"
  8	"errors"
  9	"fmt"
 10	"io"
 11	"log/slog"
 12	"maps"
 13	"net/http"
 14	"os"
 15	"os/exec"
 16	"strings"
 17	"sync"
 18	"time"
 19
 20	"github.com/charmbracelet/crush/internal/config"
 21	"github.com/charmbracelet/crush/internal/csync"
 22	"github.com/charmbracelet/crush/internal/home"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/modelcontextprotocol/go-sdk/mcp"
 27)
 28
// Package-level registries shared by every function in this file.
var (
	// sessions holds the session of each MCP client, keyed by the MCP's
	// configured name. Entries are added once a client connects and removed
	// when its state becomes StateError.
	sessions = csync.NewMap[string, *mcp.ClientSession]()
	// states holds the most recently recorded ClientInfo for each MCP, keyed
	// by the MCP's configured name.
	states   = csync.NewMap[string, ClientInfo]()
	// broker publishes Event values to subscribers obtained through
	// SubscribeEvents.
	broker   = pubsub.NewBroker[Event]()
)
 34
 35// State represents the current state of an MCP client
 36type State int
 37
 38const (
 39	StateDisabled State = iota
 40	StateStarting
 41	StateConnected
 42	StateError
 43)
 44
 45func (s State) String() string {
 46	switch s {
 47	case StateDisabled:
 48		return "disabled"
 49	case StateStarting:
 50		return "starting"
 51	case StateConnected:
 52		return "connected"
 53	case StateError:
 54		return "error"
 55	default:
 56		return "unknown"
 57	}
 58}
 59
// EventType identifies what kind of MCP event is being published.
type EventType uint

const (
	// EventStateChanged is published by updateState whenever a client's
	// recorded state is replaced.
	EventStateChanged EventType = iota
	// EventToolsListChanged is published when a server notifies the client
	// that its tool list changed.
	EventToolsListChanged
	// EventPromptsListChanged is published when a server notifies the client
	// that its prompt list changed.
	EventPromptsListChanged
)
 68
// Event represents an event in the MCP system.
//
// Type and Name are always set. State, Error and Counts are only populated
// for EventStateChanged; the list-changed events published in this file leave
// them at their zero values.
type Event struct {
	// Type is the kind of event.
	Type   EventType
	// Name is the configured name of the MCP the event refers to.
	Name   string
	// State is the client's new state.
	State  State
	// Error is the error recorded with the new state, if any.
	Error  error
	// Counts is the tool/prompt count recorded with the new state.
	Counts Counts
}
 77
// Counts holds how many tools and prompts an MCP client has available.
type Counts struct {
	// Tools is the number of tools listed by the server.
	Tools   int
	// Prompts is the number of prompts listed by the server.
	Prompts int
}
 83
// ClientInfo holds information about an MCP client's state, as recorded by
// the most recent call to updateState for that client.
type ClientInfo struct {
	// Name is the configured name of the MCP.
	Name        string
	// State is the client's current lifecycle state.
	State       State
	// Error is the error recorded with the state, if any.
	Error       error
	// Client is the session recorded with the state. The callers in this file
	// only pass a non-nil session together with StateConnected.
	Client      *mcp.ClientSession
	// Counts is the number of tools and prompts recorded with the state.
	Counts      Counts
	// ConnectedAt is the time the state was set to StateConnected; it is the
	// zero time for every other state.
	ConnectedAt time.Time
}
 93
 94// SubscribeEvents returns a channel for MCP events
 95func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] {
 96	return broker.Subscribe(ctx)
 97}
 98
 99// GetStates returns the current state of all MCP clients
100func GetStates() map[string]ClientInfo {
101	return maps.Collect(states.Seq2())
102}
103
104// GetState returns the state of a specific MCP client
105func GetState(name string) (ClientInfo, bool) {
106	return states.Get(name)
107}
108
109// Close closes all MCP clients. This should be called during application shutdown.
110func Close() error {
111	var errs []error
112	var wg sync.WaitGroup
113	for name, session := range sessions.Seq2() {
114		wg.Go(func() {
115			done := make(chan bool, 1)
116			go func() {
117				if err := session.Close(); err != nil &&
118					!errors.Is(err, io.EOF) &&
119					!errors.Is(err, context.Canceled) &&
120					err.Error() != "signal: killed" {
121					errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
122				}
123				done <- true
124			}()
125			select {
126			case <-done:
127			case <-time.After(time.Millisecond * 250):
128			}
129		})
130	}
131	wg.Wait()
132	broker.Shutdown()
133	return errors.Join(errs...)
134}
135
136// Initialize initializes MCP clients based on the provided configuration.
137func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) {
138	var wg sync.WaitGroup
139	// Initialize states for all configured MCPs
140	for name, m := range cfg.MCP {
141		if m.Disabled {
142			updateState(name, StateDisabled, nil, nil, Counts{})
143			slog.Debug("skipping disabled mcp", "name", name)
144			continue
145		}
146
147		// Set initial starting state
148		updateState(name, StateStarting, nil, nil, Counts{})
149
150		wg.Add(1)
151		go func(name string, m config.MCPConfig) {
152			defer func() {
153				wg.Done()
154				if r := recover(); r != nil {
155					var err error
156					switch v := r.(type) {
157					case error:
158						err = v
159					case string:
160						err = fmt.Errorf("panic: %s", v)
161					default:
162						err = fmt.Errorf("panic: %v", v)
163					}
164					updateState(name, StateError, err, nil, Counts{})
165					slog.Error("panic in mcp client initialization", "error", err, "name", name)
166				}
167			}()
168
169			// createSession handles its own timeout internally.
170			session, err := createSession(ctx, name, m, cfg.Resolver())
171			if err != nil {
172				return
173			}
174
175			tools, err := getTools(ctx, session)
176			if err != nil {
177				slog.Error("error listing tools", "error", err)
178				updateState(name, StateError, err, nil, Counts{})
179				session.Close()
180				return
181			}
182
183			prompts, err := getPrompts(ctx, session)
184			if err != nil {
185				slog.Error("error listing prompts", "error", err)
186				updateState(name, StateError, err, nil, Counts{})
187				session.Close()
188				return
189			}
190
191			updateTools(name, tools)
192			updatePrompts(name, prompts)
193			sessions.Set(name, session)
194
195			updateState(name, StateConnected, nil, session, Counts{
196				Tools:   len(tools),
197				Prompts: len(prompts),
198			})
199		}(name, m)
200	}
201	wg.Wait()
202}
203
204func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
205	sess, ok := sessions.Get(name)
206	if !ok {
207		return nil, fmt.Errorf("mcp '%s' not available", name)
208	}
209
210	cfg := config.Get()
211	m := cfg.MCP[name]
212	state, _ := states.Get(name)
213
214	timeout := mcpTimeout(m)
215	pingCtx, cancel := context.WithTimeout(ctx, timeout)
216	defer cancel()
217	err := sess.Ping(pingCtx, nil)
218	if err == nil {
219		return sess, nil
220	}
221	updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts)
222
223	sess, err = createSession(ctx, name, m, cfg.Resolver())
224	if err != nil {
225		return nil, err
226	}
227
228	updateState(name, StateConnected, nil, sess, state.Counts)
229	sessions.Set(name, sess)
230	return sess, nil
231}
232
233// updateState updates the state of an MCP client and publishes an event
234func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) {
235	info := ClientInfo{
236		Name:   name,
237		State:  state,
238		Error:  err,
239		Client: client,
240		Counts: counts,
241	}
242	switch state {
243	case StateConnected:
244		info.ConnectedAt = time.Now()
245	case StateError:
246		sessions.Del(name)
247	}
248	states.Set(name, info)
249
250	// Publish state change event
251	broker.Publish(pubsub.UpdatedEvent, Event{
252		Type:   EventStateChanged,
253		Name:   name,
254		State:  state,
255		Error:  err,
256		Counts: counts,
257	})
258}
259
260func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
261	timeout := mcpTimeout(m)
262	mcpCtx, cancel := context.WithCancel(ctx)
263	cancelTimer := time.AfterFunc(timeout, cancel)
264
265	transport, err := createTransport(mcpCtx, m, resolver)
266	if err != nil {
267		updateState(name, StateError, err, nil, Counts{})
268		slog.Error("error creating mcp client", "error", err, "name", name)
269		cancel()
270		cancelTimer.Stop()
271		return nil, err
272	}
273
274	client := mcp.NewClient(
275		&mcp.Implementation{
276			Name:    "crush",
277			Version: version.Version,
278			Title:   "Crush",
279		},
280		&mcp.ClientOptions{
281			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
282				broker.Publish(pubsub.UpdatedEvent, Event{
283					Type: EventToolsListChanged,
284					Name: name,
285				})
286			},
287			PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) {
288				broker.Publish(pubsub.UpdatedEvent, Event{
289					Type: EventPromptsListChanged,
290					Name: name,
291				})
292			},
293			LoggingMessageHandler: func(_ context.Context, req *mcp.LoggingMessageRequest) {
294				slog.Info("MCP log", "name", name, "data", req.Params.Data)
295			},
296		},
297	)
298
299	session, err := client.Connect(mcpCtx, transport, nil)
300	if err != nil {
301		err = maybeStdioErr(err, transport)
302		updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{})
303		slog.Error("MCP client failed to initialize", "error", err, "name", name)
304		cancel()
305		cancelTimer.Stop()
306		return nil, err
307	}
308
309	cancelTimer.Stop()
310	slog.Info("MCP client initialized", "name", name)
311	return session, nil
312}
313
314// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail
315// to parse, and the cli will then close it, causing the EOF error.
316// so, if we got an EOF err, and the transport is STDIO, we try to exec it
317// again with a timeout and collect the output so we can add details to the
318// error.
319// this happens particularly when starting things with npx, e.g. if node can't
320// be found or some other error like that.
321func maybeStdioErr(err error, transport mcp.Transport) error {
322	if !errors.Is(err, io.EOF) {
323		return err
324	}
325	ct, ok := transport.(*mcp.CommandTransport)
326	if !ok {
327		return err
328	}
329	if err2 := stdioCheck(ct.Command); err2 != nil {
330		err = errors.Join(err, err2)
331	}
332	return err
333}
334
// maybeTimeoutErr rewrites context-expiry errors into a readable
// "timed out after <timeout>" error and returns every other error (including
// nil) unchanged.
//
// Both context.Canceled and context.DeadlineExceeded are treated as timeouts:
// createSession enforces its limit by cancelling a context from a timer
// (Canceled), while getOrRenewClient bounds its ping with context.WithTimeout
// (DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
341
342func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
343	switch m.Type {
344	case config.MCPStdio:
345		command, err := resolver.ResolveValue(m.Command)
346		if err != nil {
347			return nil, fmt.Errorf("invalid mcp command: %w", err)
348		}
349		if strings.TrimSpace(command) == "" {
350			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
351		}
352		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
353		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
354		return &mcp.CommandTransport{
355			Command: cmd,
356		}, nil
357	case config.MCPHttp:
358		if strings.TrimSpace(m.URL) == "" {
359			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
360		}
361		client := &http.Client{
362			Transport: &headerRoundTripper{
363				headers: m.ResolvedHeaders(),
364			},
365		}
366		return &mcp.StreamableClientTransport{
367			Endpoint:   m.URL,
368			HTTPClient: client,
369		}, nil
370	case config.MCPSSE:
371		if strings.TrimSpace(m.URL) == "" {
372			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
373		}
374		client := &http.Client{
375			Transport: &headerRoundTripper{
376				headers: m.ResolvedHeaders(),
377			},
378		}
379		return &mcp.SSEClientTransport{
380			Endpoint:   m.URL,
381			HTTPClient: client,
382		}, nil
383	default:
384		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
385	}
386}
387
// headerRoundTripper is an http.RoundTripper that sets a fixed collection of
// headers on every outgoing request and then delegates to
// http.DefaultTransport.
type headerRoundTripper struct {
	// headers maps a header name to the value set on each request.
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers applied, overriding any same-named headers already present.
//
// The http.RoundTripper contract forbids modifying the caller's request, so
// the headers are set on a clone and req itself is left untouched. When there
// are no headers to add, req is forwarded as-is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		// A hand-built request may carry no header map; Set would panic on nil.
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
398
399func mcpTimeout(m config.MCPConfig) time.Duration {
400	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
401}
402
// stdioCheck re-runs the command described by old for up to five seconds and
// reports why it fails.
//
// It returns nil when the command exits successfully or is still running when
// the five seconds are up (i.e. it did not fail fast). Otherwise it returns
// the execution error wrapped together with the command's combined
// stdout/stderr output.
func stdioCheck(old *exec.Cmd) error {
	// old.Args holds the full argv, including the command itself as Args[0].
	// exec.CommandContext prepends the command name again, so only the real
	// arguments may be passed on; otherwise the re-run would execute
	// "cmd cmd args..." instead of "cmd args...".
	var args []string
	if len(old.Args) > 1 {
		args = old.Args[1:]
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, old.Path, args...)
	cmd.Env = old.Env
	cmd.Dir = old.Dir
	out, err := cmd.CombinedOutput()
	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("%w: %s", err, string(out))
}
412	return fmt.Errorf("%w: %s", err, string(out))
413}