mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os"
 14	"os/exec"
 15	"strings"
 16	"sync"
 17	"time"
 18
 19	"github.com/charmbracelet/crush/internal/config"
 20	"github.com/charmbracelet/crush/internal/csync"
 21	"github.com/charmbracelet/crush/internal/home"
 22	"github.com/charmbracelet/crush/internal/llm/tools"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/version"
 26	"github.com/modelcontextprotocol/go-sdk/mcp"
 27)
 28
 29// MCPState represents the current state of an MCP client
 30type MCPState int
 31
 32const (
 33	MCPStateDisabled MCPState = iota
 34	MCPStateStarting
 35	MCPStateConnected
 36	MCPStateError
 37)
 38
 39func (s MCPState) String() string {
 40	switch s {
 41	case MCPStateDisabled:
 42		return "disabled"
 43	case MCPStateStarting:
 44		return "starting"
 45	case MCPStateConnected:
 46		return "connected"
 47	case MCPStateError:
 48		return "error"
 49	default:
 50		return "unknown"
 51	}
 52}
 53
// MCPEventType identifies the kind of event carried by an MCPEvent.
type MCPEventType string

// Event kinds published on the MCP broker in this file:
//   - MCPEventStateChanged is published by updateMCPState whenever a client's
//     recorded state is written.
//   - MCPEventToolsListChanged is published from the ToolListChangedHandler
//     installed on every client created by createMCPSession.
const (
	MCPEventStateChanged     MCPEventType = "state_changed"
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
 61
// MCPEvent is the payload published on the MCP broker.
//
// Type says what happened and Name is the MCP's name as used throughout this
// file. State, Error and ToolCount are filled in for state-change events (see
// updateMCPState); the tools-list-changed event published by createMCPSession
// sets only Type and Name and leaves the rest at their zero values.
type MCPEvent struct {
	Type      MCPEventType
	Name      string
	State     MCPState
	Error     error
	ToolCount int
}
 70
// MCPClientInfo is the per-client record stored in mcpStates and returned by
// GetMCPState / GetMCPStates.
//
// updateMCPState is the only writer in this file: it copies its arguments into
// Name, State, Error, Client and ToolCount, and sets ConnectedAt to the current
// time only when the new state is MCPStateConnected (otherwise ConnectedAt is
// the zero time). Client is whatever session the caller passed, which is nil
// for every non-connected update made in this file.
type MCPClientInfo struct {
	Name        string
	State       MCPState
	Error       error
	Client      *mcp.ClientSession
	ToolCount   int
	ConnectedAt time.Time
}
 80
// Package-level MCP registries shared by every function in this file.
//
//   - mcpToolsOnce: not referenced in the visible part of this file;
//     presumably guards one-time tool discovery elsewhere — TODO confirm.
//   - mcpTools: every registered tool, keyed by the tool's Name().
//   - mcpClient2Tools: the tools contributed by each MCP, keyed by MCP name.
//   - mcpClients: the live client session for each MCP, keyed by MCP name.
//   - mcpStates: the last recorded MCPClientInfo for each MCP, keyed by MCP name.
//   - mcpBroker: the broker on which MCPEvent values are published.
var (
	mcpToolsOnce    sync.Once
	mcpTools        = csync.NewMap[string, tools.BaseTool]()
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	mcpClients      = csync.NewMap[string, *mcp.ClientSession]()
	mcpStates       = csync.NewMap[string, MCPClientInfo]()
	mcpBroker       = pubsub.NewBroker[MCPEvent]()
)
 89
// McpTool exposes a single tool offered by an MCP server through the methods
// defined on it in this file (Name, Info, Run).
//
// mcpName is the name of the MCP the tool belongs to, tool is the descriptor
// returned by the server's tool listing, permissions is asked for approval
// before every Run, and workingDir is reported as the Path of that permission
// request.
type McpTool struct {
	mcpName     string
	tool        *mcp.Tool
	permissions permission.Service
	workingDir  string
}
 96
 97func (b *McpTool) Name() string {
 98	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 99}
100
101func (b *McpTool) Info() tools.ToolInfo {
102	var parameters map[string]any
103	var required []string
104
105	if input, ok := b.tool.InputSchema.(map[string]any); ok {
106		if props, ok := input["properties"].(map[string]any); ok {
107			parameters = props
108		}
109		if req, ok := input["required"].([]any); ok {
110			// Convert []any -> []string when elements are strings
111			for _, v := range req {
112				if s, ok := v.(string); ok {
113					required = append(required, s)
114				}
115			}
116		} else if reqStr, ok := input["required"].([]string); ok {
117			// Handle case where it's already []string
118			required = reqStr
119		}
120	}
121
122	return tools.ToolInfo{
123		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
124		Description: b.tool.Description,
125		Parameters:  parameters,
126		Required:    required,
127	}
128}
129
130func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
131	var args map[string]any
132	if err := json.Unmarshal([]byte(input), &args); err != nil {
133		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
134	}
135
136	c, err := getOrRenewClient(ctx, name)
137	if err != nil {
138		return tools.NewTextErrorResponse(err.Error()), nil
139	}
140	result, err := c.CallTool(ctx, &mcp.CallToolParams{
141		Name:      toolName,
142		Arguments: args,
143	})
144	if err != nil {
145		return tools.NewTextErrorResponse(err.Error()), nil
146	}
147
148	output := make([]string, 0, len(result.Content))
149	for _, v := range result.Content {
150		if vv, ok := v.(*mcp.TextContent); ok {
151			output = append(output, vv.Text)
152		} else {
153			output = append(output, fmt.Sprintf("%v", v))
154		}
155	}
156	return tools.NewTextResponse(strings.Join(output, "\n")), nil
157}
158
159func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
160	sess, ok := mcpClients.Get(name)
161	if !ok {
162		return nil, fmt.Errorf("mcp '%s' not available", name)
163	}
164
165	cfg := config.Get()
166	m := cfg.MCP[name]
167	state, _ := mcpStates.Get(name)
168
169	timeout := mcpTimeout(m)
170	pingCtx, cancel := context.WithTimeout(ctx, timeout)
171	defer cancel()
172	err := sess.Ping(pingCtx, nil)
173	if err == nil {
174		return sess, nil
175	}
176	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
177
178	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
179	if err != nil {
180		return nil, err
181	}
182
183	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
184	mcpClients.Set(name, sess)
185	return sess, nil
186}
187
188func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
189	sessionID, messageID := tools.GetContextValues(ctx)
190	if sessionID == "" || messageID == "" {
191		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
192	}
193	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
194	p := b.permissions.Request(
195		permission.CreatePermissionRequest{
196			SessionID:   sessionID,
197			ToolCallID:  params.ID,
198			Path:        b.workingDir,
199			ToolName:    b.Info().Name,
200			Action:      "execute",
201			Description: permissionDescription,
202			Params:      params.Input,
203		},
204	)
205	if !p {
206		return tools.ToolResponse{}, permission.ErrorPermissionDenied
207	}
208
209	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
210}
211
212func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
213	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
214	if err != nil {
215		return nil, err
216	}
217	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
218	for _, tool := range result.Tools {
219		mcpTools = append(mcpTools, &McpTool{
220			mcpName:     name,
221			tool:        tool,
222			permissions: permissions,
223			workingDir:  workingDir,
224		})
225	}
226	return mcpTools, nil
227}
228
229// SubscribeMCPEvents returns a channel for MCP events
230func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
231	return mcpBroker.Subscribe(ctx)
232}
233
234// GetMCPStates returns the current state of all MCP clients
235func GetMCPStates() map[string]MCPClientInfo {
236	return maps.Collect(mcpStates.Seq2())
237}
238
239// GetMCPState returns the state of a specific MCP client
240func GetMCPState(name string) (MCPClientInfo, bool) {
241	return mcpStates.Get(name)
242}
243
244// updateMCPState updates the state of an MCP client and publishes an event
245func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
246	info := MCPClientInfo{
247		Name:      name,
248		State:     state,
249		Error:     err,
250		Client:    client,
251		ToolCount: toolCount,
252	}
253	switch state {
254	case MCPStateConnected:
255		info.ConnectedAt = time.Now()
256	case MCPStateError:
257		updateMcpTools(name, nil)
258		mcpClients.Del(name)
259	}
260	mcpStates.Set(name, info)
261
262	// Publish state change event
263	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
264		Type:      MCPEventStateChanged,
265		Name:      name,
266		State:     state,
267		Error:     err,
268		ToolCount: toolCount,
269	})
270}
271
272// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
273func CloseMCPClients() error {
274	var errs []error
275	for name, c := range mcpClients.Seq2() {
276		if err := c.Close(); err != nil &&
277			!errors.Is(err, io.EOF) &&
278			!errors.Is(err, context.Canceled) &&
279			err.Error() != "signal: killed" {
280			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
281		}
282	}
283	mcpBroker.Shutdown()
284	return errors.Join(errs...)
285}
286
287func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
288	var wg sync.WaitGroup
289	// Initialize states for all configured MCPs
290	for name, m := range cfg.MCP {
291		if m.Disabled {
292			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
293			slog.Debug("skipping disabled mcp", "name", name)
294			continue
295		}
296
297		// Set initial starting state
298		updateMCPState(name, MCPStateStarting, nil, nil, 0)
299
300		wg.Add(1)
301		go func(name string, m config.MCPConfig) {
302			defer func() {
303				wg.Done()
304				if r := recover(); r != nil {
305					var err error
306					switch v := r.(type) {
307					case error:
308						err = v
309					case string:
310						err = fmt.Errorf("panic: %s", v)
311					default:
312						err = fmt.Errorf("panic: %v", v)
313					}
314					updateMCPState(name, MCPStateError, err, nil, 0)
315					slog.Error("panic in mcp client initialization", "error", err, "name", name)
316				}
317			}()
318
319			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
320			defer cancel()
321
322			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
323			if err != nil {
324				return
325			}
326
327			mcpClients.Set(name, c)
328
329			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
330			if err != nil {
331				slog.Error("error listing tools", "error", err)
332				updateMCPState(name, MCPStateError, err, nil, 0)
333				c.Close()
334				return
335			}
336
337			updateMcpTools(name, tools)
338			mcpClients.Set(name, c)
339			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
340		}(name, m)
341	}
342	wg.Wait()
343}
344
345// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
346func updateMcpTools(mcpName string, tools []tools.BaseTool) {
347	if len(tools) == 0 {
348		mcpClient2Tools.Del(mcpName)
349	} else {
350		mcpClient2Tools.Set(mcpName, tools)
351	}
352	for _, tools := range mcpClient2Tools.Seq2() {
353		for _, t := range tools {
354			mcpTools.Set(t.Name(), t)
355		}
356	}
357}
358
359func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
360	timeout := mcpTimeout(m)
361	mcpCtx, cancel := context.WithCancel(ctx)
362	cancelTimer := time.AfterFunc(timeout, cancel)
363
364	transport, err := createMCPTransport(mcpCtx, m, resolver)
365	if err != nil {
366		updateMCPState(name, MCPStateError, err, nil, 0)
367		slog.Error("error creating mcp client", "error", err, "name", name)
368		return nil, err
369	}
370
371	client := mcp.NewClient(
372		&mcp.Implementation{
373			Name:    "crush",
374			Version: version.Version,
375			Title:   "Crush",
376		},
377		&mcp.ClientOptions{
378			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
379				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
380					Type: MCPEventToolsListChanged,
381					Name: name,
382				})
383			},
384			KeepAlive: time.Minute * 10,
385		},
386	)
387
388	session, err := client.Connect(mcpCtx, transport, nil)
389	if err != nil {
390		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
391		slog.Error("error starting mcp client", "error", err, "name", name)
392		cancel()
393		return nil, err
394	}
395
396	cancelTimer.Stop()
397	slog.Info("Initialized mcp client", "name", name)
398	return session, nil
399}
400
// maybeTimeoutErr rewrites context-expiry errors into a readable timeout
// message.
//
// If err is, or wraps, context.Canceled or context.DeadlineExceeded, the
// result is a new error reading "timed out after <timeout>". Any other error,
// including nil, is returned unchanged.
//
// Both sentinels are covered because this file bounds operations in two ways:
// with a timer that cancels the context (yielding Canceled) and with
// context.WithTimeout (yielding DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
407
408func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
409	switch m.Type {
410	case config.MCPStdio:
411		command, err := resolver.ResolveValue(m.Command)
412		if err != nil {
413			return nil, fmt.Errorf("invalid mcp command: %w", err)
414		}
415		if strings.TrimSpace(command) == "" {
416			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
417		}
418		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
419		cmd.Env = append(os.Environ(), m.ResolvedEnv()...)
420		return &mcp.CommandTransport{
421			Command: cmd,
422		}, nil
423	case config.MCPHttp:
424		if strings.TrimSpace(m.URL) == "" {
425			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
426		}
427		client := &http.Client{
428			Transport: &headerRoundTripper{
429				headers: m.ResolvedHeaders(),
430			},
431		}
432		return &mcp.StreamableClientTransport{
433			Endpoint:   m.URL,
434			HTTPClient: client,
435		}, nil
436	case config.MCPSSE:
437		if strings.TrimSpace(m.URL) == "" {
438			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
439		}
440		client := &http.Client{
441			Transport: &headerRoundTripper{
442				headers: m.ResolvedHeaders(),
443			},
444		}
445		return &mcp.SSEClientTransport{
446			Endpoint:   m.URL,
447			HTTPClient: client,
448		}, nil
449	default:
450		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
451	}
452}
453
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request before handing it to http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip sends req through http.DefaultTransport with the configured
// headers set, overriding any header of the same name already on the request.
//
// The headers are applied to a clone: the http.RoundTripper contract forbids
// modifying the caller's request, which may be reused or retried. When no
// headers are configured the request is forwarded untouched.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) == 0 {
		return http.DefaultTransport.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		// Clone keeps a nil header map nil; Set needs a real map.
		out.Header = make(http.Header, len(rt.headers))
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return http.DefaultTransport.RoundTrip(out)
}
464
465func mcpTimeout(m config.MCPConfig) time.Duration {
466	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
467}