mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/llm/tools"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/modelcontextprotocol/go-sdk/mcp"
 27)
 28
 29// MCPState represents the current state of an MCP client
 30type MCPState int
 31
 32const (
 33	MCPStateDisabled MCPState = iota
 34	MCPStateStarting
 35	MCPStateConnected
 36	MCPStateError
 37)
 38
 39func (s MCPState) String() string {
 40	switch s {
 41	case MCPStateDisabled:
 42		return "disabled"
 43	case MCPStateStarting:
 44		return "starting"
 45	case MCPStateConnected:
 46		return "connected"
 47	case MCPStateError:
 48		return "error"
 49	default:
 50		return "unknown"
 51	}
 52}
 53
 54// MCPEventType represents the type of MCP event
 55type MCPEventType string
 56
 57const (
 58	MCPEventStateChanged     MCPEventType = "state_changed"
 59	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
 60)
 61
 62// MCPEvent represents an event in the MCP system
 63type MCPEvent struct {
 64	Type      MCPEventType
 65	Name      string
 66	State     MCPState
 67	Error     error
 68	ToolCount int
 69}
 70
 71// MCPClientInfo holds information about an MCP client's state
 72type MCPClientInfo struct {
 73	Name        string
 74	State       MCPState
 75	Error       error
 76	Client      *mcp.ClientSession
 77	ToolCount   int
 78	ConnectedAt time.Time
 79}
 80
 81var (
 82	mcpToolsOnce    sync.Once
 83	mcpTools        = csync.NewMap[string, tools.BaseTool]()
 84	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
 85	mcpClients      = csync.NewMap[string, *mcp.ClientSession]()
 86	mcpStates       = csync.NewMap[string, MCPClientInfo]()
 87	mcpBroker       = pubsub.NewBroker[MCPEvent]()
 88)
 89
 90type McpTool struct {
 91	mcpName     string
 92	tool        *mcp.Tool
 93	permissions permission.Service
 94	workingDir  string
 95}
 96
 97func (b *McpTool) Name() string {
 98	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 99}
100
101func (b *McpTool) Info() tools.ToolInfo {
102	parameters := make(map[string]any)
103	required := make([]string, 0)
104
105	if input, ok := b.tool.InputSchema.(map[string]any); ok {
106		if props, ok := input["properties"].(map[string]any); ok {
107			parameters = props
108		}
109		if req, ok := input["required"].([]any); ok {
110			// Convert []any -> []string when elements are strings
111			for _, v := range req {
112				if s, ok := v.(string); ok {
113					required = append(required, s)
114				}
115			}
116		} else if reqStr, ok := input["required"].([]string); ok {
117			// Handle case where it's already []string
118			required = reqStr
119		}
120	}
121
122	return tools.ToolInfo{
123		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
124		Description: b.tool.Description,
125		Parameters:  parameters,
126		Required:    required,
127	}
128}
129
// runTool calls toolName on the MCP server registered as name. It decodes
// input as a JSON object and passes it as the call arguments. The result's
// content items are flattened into one text response, one item per line.
//
// Every failure is reported to the model as an error response, never as a
// Go error. That covers invalid JSON, an unavailable server, a failed call,
// and a result the server flags with isError.
func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
	var args map[string]any
	if err := json.Unmarshal([]byte(input), &args); err != nil {
		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
	}

	c, err := getOrRenewClient(ctx, name)
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}
	result, err := c.CallTool(ctx, &mcp.CallToolParams{
		Name:      toolName,
		Arguments: args,
	})
	if err != nil {
		return tools.NewTextErrorResponse(err.Error()), nil
	}

	output := make([]string, 0, len(result.Content))
	for _, v := range result.Content {
		if text, ok := v.(*mcp.TextContent); ok {
			output = append(output, text.Text)
			continue
		}
		// Non-text content has no plain-text form; fall back to default formatting.
		output = append(output, fmt.Sprintf("%v", v))
	}
	text := strings.Join(output, "\n")

	// MCP servers report tool execution failures in-band with isError set,
	// not as protocol errors, so the flag must be checked explicitly.
	if result.IsError {
		return tools.NewTextErrorResponse(text), nil
	}
	return tools.NewTextResponse(text), nil
}
158
// getOrRenewClient returns a working session for the MCP client name. It
// first pings the current session, bounding the ping by the client's
// configured timeout. If the ping fails, it does four things:
//   - records the failure as MCPStateError,
//   - closes the dead session,
//   - connects a fresh one,
//   - restores the tools that were registered for the client.
//
// It returns an error if the client has no session or reconnecting fails.
func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
	sess, ok := mcpClients.Get(name)
	if !ok {
		return nil, fmt.Errorf("mcp '%s' not available", name)
	}

	cfg := config.Get()
	m := cfg.MCP[name]
	state, _ := mcpStates.Get(name)

	timeout := mcpTimeout(m)
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	err := sess.Ping(pingCtx, nil)
	if err == nil {
		return sess, nil
	}

	// Marking the client as errored below unregisters its tools. Remember
	// them so they can be re-registered once a new session is up. The tool
	// values stay valid because they route calls by client name.
	prevTools, _ := mcpClient2Tools.Get(name)

	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)

	// The stale session is no longer tracked anywhere. Close it so its
	// transport (e.g. a stdio child process) is released. It already failed a
	// ping, so a close error is expected and carries no useful information.
	_ = sess.Close()

	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
	if err != nil {
		return nil, err
	}

	mcpClients.Set(name, sess)
	updateMcpTools(name, prevTools)
	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
	return sess, nil
}
187
// Run executes the MCP tool for the given tool call. The context must carry
// a session ID and a message ID (read via tools.GetContextValues). The call
// is first submitted to the permission service for approval. If approval is
// denied, Run returns permission.ErrorPermissionDenied. Otherwise
// params.Input is forwarded to the MCP server.
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
	sessionID, messageID := tools.GetContextValues(ctx)
	if sessionID == "" || messageID == "" {
		return tools.ToolResponse{}, errors.New("session ID and message ID are required for running an MCP tool")
	}

	// Name() is all that is needed here; Info() would also rebuild the schema maps.
	toolName := b.Name()
	granted := b.permissions.Request(
		permission.CreatePermissionRequest{
			SessionID:   sessionID,
			ToolCallID:  params.ID,
			Path:        b.workingDir,
			ToolName:    toolName,
			Action:      "execute",
			Description: fmt.Sprintf("execute %s with the following parameters:", toolName),
			Params:      params.Input,
		},
	)
	if !granted {
		return tools.ToolResponse{}, permission.ErrorPermissionDenied
	}

	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
211
// getTools lists every tool offered by the MCP session c and wraps each one
// in an McpTool tied to the client name, permission service and working
// directory. Servers may split the listing into pages (MCP cursor-based
// pagination), and all pages are fetched. The result is empty but non-nil
// when the server offers no tools.
func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
	// Not named mcpTools, to avoid shadowing the package-level registry.
	found := make([]tools.BaseTool, 0)
	params := &mcp.ListToolsParams{}
	for {
		result, err := c.ListTools(ctx, params)
		if err != nil {
			return nil, err
		}
		for _, tool := range result.Tools {
			found = append(found, &McpTool{
				mcpName:     name,
				tool:        tool,
				permissions: permissions,
				workingDir:  workingDir,
			})
		}
		// An empty cursor ends the listing. A cursor that repeats would loop
		// forever, so treat it as the end as well.
		if result.NextCursor == "" || result.NextCursor == params.Cursor {
			return found, nil
		}
		params.Cursor = result.NextCursor
	}
}
228
// SubscribeMCPEvents subscribes to MCP client events (state changes and
// tool-list changes) and returns the channel they are delivered on. ctx is
// handed to the broker's Subscribe; presumably the subscription ends when
// ctx is done — confirm against pubsub.Broker.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}
233
234// GetMCPStates returns the current state of all MCP clients
235func GetMCPStates() map[string]MCPClientInfo {
236	return maps.Collect(mcpStates.Seq2())
237}
238
239// GetMCPState returns the state of a specific MCP client
240func GetMCPState(name string) (MCPClientInfo, bool) {
241	return mcpStates.Get(name)
242}
243
244// updateMCPState updates the state of an MCP client and publishes an event
245func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
246	info := MCPClientInfo{
247		Name:      name,
248		State:     state,
249		Error:     err,
250		Client:    client,
251		ToolCount: toolCount,
252	}
253	switch state {
254	case MCPStateConnected:
255		info.ConnectedAt = time.Now()
256	case MCPStateError:
257		updateMcpTools(name, nil)
258		mcpClients.Del(name)
259	}
260	mcpStates.Set(name, info)
261
262	// Publish state change event
263	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
264		Type:      MCPEventStateChanged,
265		Name:      name,
266		State:     state,
267		Error:     err,
268		ToolCount: toolCount,
269	})
270}
271
// CloseMCPClients closes every tracked MCP session and then shuts down the
// MCP event broker. Call it during application shutdown.
//
// Some close errors only mean the session was already going away: io.EOF,
// context.Canceled, or a child process that reports "signal: killed". These
// are ignored. All other errors are joined into the returned error.
func CloseMCPClients() error {
	ignorable := func(err error) bool {
		return err == nil ||
			errors.Is(err, io.EOF) ||
			errors.Is(err, context.Canceled) ||
			err.Error() == "signal: killed"
	}

	var errs []error
	for name, session := range mcpClients.Seq2() {
		err := session.Close()
		if ignorable(err) {
			continue
		}
		errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
	}

	mcpBroker.Shutdown()
	return errors.Join(errs...)
}
286
// doGetMCPTools connects to every enabled MCP server in cfg concurrently and
// registers their tools. It records each client's state as it goes:
//   - disabled servers are marked MCPStateDisabled and skipped,
//   - the others are marked MCPStateStarting before connecting,
//   - each then ends in MCPStateConnected or MCPStateError.
//
// Each connection attempt, including listing its tools, is bounded by the
// server's configured timeout. A panic during one attempt is recovered and
// recorded as that client's error. The function returns only after every
// attempt has finished and its final state has been recorded.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
	var wg sync.WaitGroup
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred calls run last-in-first-out. Registering wg.Done first
			// makes it run after a recovered panic has been recorded, so
			// wg.Wait cannot return before the error state is stored.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
			defer cancel()

			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
			if err != nil {
				// createMCPSession has already logged the failure and recorded the error state.
				return
			}
			mcpClients.Set(name, c)

			// Not named "tools", to avoid shadowing the tools package.
			clientTools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
			if err != nil {
				slog.Error("error listing tools", "error", err, "name", name)
				updateMCPState(name, MCPStateError, err, nil, 0)
				// Best effort: the session is being discarded either way.
				_ = c.Close()
				return
			}

			updateMcpTools(name, clientTools)
			updateMCPState(name, MCPStateConnected, nil, c, len(clientTools))
		}(name, m)
	}
	wg.Wait()
}
344
345// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
346func updateMcpTools(mcpName string, tools []tools.BaseTool) {
347	if len(tools) == 0 {
348		mcpClient2Tools.Del(mcpName)
349	} else {
350		mcpClient2Tools.Set(mcpName, tools)
351	}
352	for _, tools := range mcpClient2Tools.Seq2() {
353		for _, t := range tools {
354			mcpTools.Set(t.Name(), t)
355		}
356	}
357}
358
// createMCPSession starts the MCP server described by m, completes the MCP
// handshake with it, and returns the live session. The handshake is aborted
// when the server's configured timeout (mcpTimeout) expires. It is also
// aborted if ctx is cancelled first.
//
// The returned session does not depend on ctx. Callers pass contexts scoped
// to start-up or to a single tool call. The transport is built on the
// session's own context, and for stdio servers that context also governs
// the child process (exec.CommandContext). Tying it to ctx would kill a
// healthy server as soon as the caller's context ended. Close the session to
// stop the server.
//
// On failure the client is recorded in MCPStateError and the error is
// logged before being returned.
func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
	timeout := mcpTimeout(m)

	// Detach from ctx's cancellation (keeping its values) so the session can
	// outlive the caller. Then re-attach cancellation for the handshake only.
	mcpCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
	cancelTimer := time.AfterFunc(timeout, cancel)
	stopParentWatch := context.AfterFunc(ctx, cancel)
	// disarm stops both watchers. It reports false if either has already
	// fired, which means mcpCtx is, or is about to be, cancelled.
	disarm := func() bool {
		timerStopped := cancelTimer.Stop()
		parentStopped := stopParentWatch()
		return timerStopped && parentStopped
	}

	transport, err := createMCPTransport(mcpCtx, m, resolver)
	if err != nil {
		disarm()
		cancel()
		updateMCPState(name, MCPStateError, err, nil, 0)
		slog.Error("error creating mcp client", "error", err, "name", name)
		return nil, err
	}

	client := mcp.NewClient(
		&mcp.Implementation{
			Name:    "crush",
			Version: version.Version,
			Title:   "Crush",
		},
		&mcp.ClientOptions{
			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
					Type: MCPEventToolsListChanged,
					Name: name,
				})
			},
			KeepAlive: time.Minute * 10,
		},
	)

	session, err := client.Connect(mcpCtx, transport, nil)
	if err != nil {
		disarm()
		err = maybeStdioErr(err, transport)
		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
		slog.Error("error starting mcp client", "error", err, "name", name)
		cancel()
		return nil, err
	}

	if !disarm() {
		// The timeout or ctx fired just as the handshake completed. mcpCtx is
		// being cancelled underneath the session, so it cannot be kept.
		_ = session.Close()
		cancel()
		err = maybeTimeoutErr(context.Canceled, timeout)
		updateMCPState(name, MCPStateError, err, nil, 0)
		slog.Error("error starting mcp client", "error", err, "name", name)
		return nil, err
	}

	slog.Info("Initialized mcp client", "name", name)
	return session, nil
}
401
// maybeStdioErr adds diagnostic output to a failed stdio MCP start-up.
//
// A stdio MCP server may print something other than JSON-RPC, for example
// when an npx launcher cannot find node. The client then fails to parse the
// output and closes the connection, and the failure surfaces only as a bare
// io.EOF.
//
// When err wraps io.EOF and transport is a *mcp.CommandTransport, the command
// is run once more via stdioMCPCheck. If that run fails, its error and output
// are joined to err. In every other case err is returned unchanged.
func maybeStdioErr(err error, transport mcp.Transport) error {
	ct, isCommand := transport.(*mcp.CommandTransport)
	if !isCommand || !errors.Is(err, io.EOF) {
		return err
	}

	checkErr := stdioMCPCheck(ct.Command)
	if checkErr == nil {
		return err
	}
	return errors.Join(err, checkErr)
}
422
// maybeTimeoutErr rewrites an error caused by a cancelled or expired context
// as "timed out after <timeout>", so users can see which limit was hit.
// Callers pass errors from operations bounded by timeout, such as a
// context.WithTimeout ping or a connection cancelled by a timer. Both
// context.Canceled and context.DeadlineExceeded are treated as timeouts. Any
// other error, and nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
429
// createMCPTransport builds the client transport for the MCP server
// described by m:
//   - stdio: the command is resolved through resolver and passed through
//     home.Long. It is started with m.Args and the process environment plus
//     m's resolved env vars. It is created with exec.CommandContext(ctx, ...),
//     so the process is killed when ctx is done.
//   - http: a streamable HTTP transport to m.URL.
//   - sse: an SSE transport to m.URL.
//
// Both HTTP-based transports send m's resolved headers with every request.
// An unknown type, an empty command, or an empty URL is an error.
func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
	headerClient := func() *http.Client {
		return &http.Client{
			Transport: &headerRoundTripper{headers: m.ResolvedHeaders()},
		}
	}
	missingURL := func() bool { return strings.TrimSpace(m.URL) == "" }

	switch m.Type {
	case config.MCPStdio:
		command, err := resolver.ResolveValue(m.Command)
		if err != nil {
			return nil, fmt.Errorf("invalid mcp command: %w", err)
		}
		if strings.TrimSpace(command) == "" {
			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
		}
		proc := exec.CommandContext(ctx, home.Long(command), m.Args...)
		proc.Env = append(os.Environ(), m.ResolvedEnv()...)
		return &mcp.CommandTransport{Command: proc}, nil

	case config.MCPHttp:
		if missingURL() {
			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
		}
		return &mcp.StreamableClientTransport{Endpoint: m.URL, HTTPClient: headerClient()}, nil

	case config.MCPSSE:
		if missingURL() {
			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
		}
		return &mcp.SSEClientTransport{Endpoint: m.URL, HTTPClient: headerClient()}, nil
	}

	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
475
476type headerRoundTripper struct {
477	headers map[string]string
478}
479
480func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
481	for k, v := range rt.headers {
482		req.Header.Set(k, v)
483	}
484	return http.DefaultTransport.RoundTrip(req)
485}
486
// mcpTimeout returns the timeout for MCP server m. m.Timeout is read as whole
// seconds, and a zero (unset) value falls back to 15 seconds.
func mcpTimeout(m config.MCPConfig) time.Duration {
	seconds := m.Timeout
	if seconds == 0 {
		seconds = 15
	}
	return time.Duration(seconds) * time.Second
}
490
491func stdioMCPCheck(old *exec.Cmd) error {
492	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
493	defer cancel()
494	cmd := exec.CommandContext(ctx, old.Path, old.Args...)
495	cmd.Env = old.Env
496	out, err := cmd.CombinedOutput()
497	if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) {
498		return nil
499	}
500	return fmt.Errorf("%w: %s", err, string(out))
501}