mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"io"
 10	"log/slog"
 11	"maps"
 12	"net/http"
 13	"os/exec"
 14	"strings"
 15	"sync"
 16	"time"
 17
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/home"
 21	"github.com/charmbracelet/crush/internal/llm/tools"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/version"
 25	"github.com/modelcontextprotocol/go-sdk/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
 53// MCPEventType represents the type of MCP event
 54type MCPEventType string
 55
 56const (
 57	MCPEventStateChanged     MCPEventType = "state_changed"
 58	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
 59)
 60
 61// MCPEvent represents an event in the MCP system
 62type MCPEvent struct {
 63	Type      MCPEventType
 64	Name      string
 65	State     MCPState
 66	Error     error
 67	ToolCount int
 68}
 69
 70// MCPClientInfo holds information about an MCP client's state
 71type MCPClientInfo struct {
 72	Name        string
 73	State       MCPState
 74	Error       error
 75	Client      *mcp.ClientSession
 76	ToolCount   int
 77	ConnectedAt time.Time
 78}
 79
 80var (
 81	mcpToolsOnce    sync.Once
 82	mcpTools        = csync.NewMap[string, tools.BaseTool]()
 83	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
 84	mcpClients      = csync.NewMap[string, *mcp.ClientSession]()
 85	mcpStates       = csync.NewMap[string, MCPClientInfo]()
 86	mcpBroker       = pubsub.NewBroker[MCPEvent]()
 87)
 88
 89type McpTool struct {
 90	mcpName     string
 91	tool        *mcp.Tool
 92	permissions permission.Service
 93	workingDir  string
 94}
 95
 96func (b *McpTool) Name() string {
 97	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 98}
 99
100func (b *McpTool) Info() tools.ToolInfo {
101	var parameters map[string]any
102	var required []string
103
104	if input, ok := b.tool.InputSchema.(map[string]any); ok {
105		if props, ok := input["properties"].(map[string]any); ok {
106			parameters = props
107		}
108		if req, ok := input["required"].([]any); ok {
109			// Convert []any -> []string when elements are strings
110			for _, v := range req {
111				if s, ok := v.(string); ok {
112					required = append(required, s)
113				}
114			}
115		} else if reqStr, ok := input["required"].([]string); ok {
116			// Handle case where it's already []string
117			required = reqStr
118		}
119	}
120
121	return tools.ToolInfo{
122		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
123		Description: b.tool.Description,
124		Parameters:  parameters,
125		Required:    required,
126	}
127}
128
129func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
130	var args map[string]any
131	if err := json.Unmarshal([]byte(input), &args); err != nil {
132		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
133	}
134
135	c, err := getOrRenewClient(ctx, name)
136	if err != nil {
137		return tools.NewTextErrorResponse(err.Error()), nil
138	}
139	result, err := c.CallTool(ctx, &mcp.CallToolParams{
140		Name:      toolName,
141		Arguments: args,
142	})
143	if err != nil {
144		return tools.NewTextErrorResponse(err.Error()), nil
145	}
146
147	output := make([]string, 0, len(result.Content))
148	for _, v := range result.Content {
149		if vv, ok := v.(*mcp.TextContent); ok {
150			output = append(output, vv.Text)
151		} else {
152			output = append(output, fmt.Sprintf("%v", v))
153		}
154	}
155	return tools.NewTextResponse(strings.Join(output, "\n")), nil
156}
157
158func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
159	sess, ok := mcpClients.Get(name)
160	if !ok {
161		return nil, fmt.Errorf("mcp '%s' not available", name)
162	}
163
164	cfg := config.Get()
165	m := cfg.MCP[name]
166	state, _ := mcpStates.Get(name)
167
168	timeout := mcpTimeout(m)
169	pingCtx, cancel := context.WithTimeout(ctx, timeout)
170	defer cancel()
171	err := sess.Ping(pingCtx, nil)
172	if err == nil {
173		return sess, nil
174	}
175	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
176
177	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
178	if err != nil {
179		return nil, err
180	}
181
182	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
183	mcpClients.Set(name, sess)
184	return sess, nil
185}
186
187func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
188	sessionID, messageID := tools.GetContextValues(ctx)
189	if sessionID == "" || messageID == "" {
190		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
191	}
192	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
193	p := b.permissions.Request(
194		permission.CreatePermissionRequest{
195			SessionID:   sessionID,
196			ToolCallID:  params.ID,
197			Path:        b.workingDir,
198			ToolName:    b.Info().Name,
199			Action:      "execute",
200			Description: permissionDescription,
201			Params:      params.Input,
202		},
203	)
204	if !p {
205		return tools.ToolResponse{}, permission.ErrorPermissionDenied
206	}
207
208	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
209}
210
211func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
212	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
213	if err != nil {
214		return nil, err
215	}
216	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
217	for _, tool := range result.Tools {
218		mcpTools = append(mcpTools, &McpTool{
219			mcpName:     name,
220			tool:        tool,
221			permissions: permissions,
222			workingDir:  workingDir,
223		})
224	}
225	return mcpTools, nil
226}
227
228// SubscribeMCPEvents returns a channel for MCP events
229func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
230	return mcpBroker.Subscribe(ctx)
231}
232
233// GetMCPStates returns the current state of all MCP clients
234func GetMCPStates() map[string]MCPClientInfo {
235	return maps.Collect(mcpStates.Seq2())
236}
237
238// GetMCPState returns the state of a specific MCP client
239func GetMCPState(name string) (MCPClientInfo, bool) {
240	return mcpStates.Get(name)
241}
242
243// updateMCPState updates the state of an MCP client and publishes an event
244func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
245	info := MCPClientInfo{
246		Name:      name,
247		State:     state,
248		Error:     err,
249		Client:    client,
250		ToolCount: toolCount,
251	}
252	switch state {
253	case MCPStateConnected:
254		info.ConnectedAt = time.Now()
255	case MCPStateError:
256		updateMcpTools(name, nil)
257		mcpClients.Del(name)
258	}
259	mcpStates.Set(name, info)
260
261	// Publish state change event
262	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
263		Type:      MCPEventStateChanged,
264		Name:      name,
265		State:     state,
266		Error:     err,
267		ToolCount: toolCount,
268	})
269}
270
271// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
272func CloseMCPClients() error {
273	var errs []error
274	for name, c := range mcpClients.Seq2() {
275		if err := c.Close(); err != nil &&
276			!errors.Is(err, io.EOF) &&
277			!errors.Is(err, context.Canceled) &&
278			err.Error() != "signal: killed" {
279			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
280		}
281	}
282	mcpBroker.Shutdown()
283	return errors.Join(errs...)
284}
285
286func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
287	var wg sync.WaitGroup
288	// Initialize states for all configured MCPs
289	for name, m := range cfg.MCP {
290		if m.Disabled {
291			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
292			slog.Debug("skipping disabled mcp", "name", name)
293			continue
294		}
295
296		// Set initial starting state
297		updateMCPState(name, MCPStateStarting, nil, nil, 0)
298
299		wg.Add(1)
300		go func(name string, m config.MCPConfig) {
301			defer func() {
302				wg.Done()
303				if r := recover(); r != nil {
304					var err error
305					switch v := r.(type) {
306					case error:
307						err = v
308					case string:
309						err = fmt.Errorf("panic: %s", v)
310					default:
311						err = fmt.Errorf("panic: %v", v)
312					}
313					updateMCPState(name, MCPStateError, err, nil, 0)
314					slog.Error("panic in mcp client initialization", "error", err, "name", name)
315				}
316			}()
317
318			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
319			defer cancel()
320
321			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
322			if err != nil {
323				return
324			}
325
326			mcpClients.Set(name, c)
327
328			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
329			if err != nil {
330				slog.Error("error listing tools", "error", err)
331				updateMCPState(name, MCPStateError, err, nil, 0)
332				c.Close()
333				return
334			}
335
336			updateMcpTools(name, tools)
337			mcpClients.Set(name, c)
338			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
339		}(name, m)
340	}
341	wg.Wait()
342}
343
344// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
345func updateMcpTools(mcpName string, tools []tools.BaseTool) {
346	if len(tools) == 0 {
347		mcpClient2Tools.Del(mcpName)
348	} else {
349		mcpClient2Tools.Set(mcpName, tools)
350	}
351	for _, tools := range mcpClient2Tools.Seq2() {
352		for _, t := range tools {
353			mcpTools.Set(t.Name(), t)
354		}
355	}
356}
357
358func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
359	timeout := mcpTimeout(m)
360	mcpCtx, cancel := context.WithCancel(ctx)
361	cancelTimer := time.AfterFunc(timeout, cancel)
362
363	transport, err := createMCPTransport(mcpCtx, m, resolver)
364	if err != nil {
365		updateMCPState(name, MCPStateError, err, nil, 0)
366		slog.Error("error creating mcp client", "error", err, "name", name)
367		return nil, err
368	}
369
370	client := mcp.NewClient(
371		&mcp.Implementation{
372			Name:    "crush",
373			Version: version.Version,
374			Title:   "Crush",
375		},
376		&mcp.ClientOptions{
377			ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
378				mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
379					Type: MCPEventToolsListChanged,
380					Name: name,
381				})
382			},
383			KeepAlive: time.Minute * 10,
384		},
385	)
386
387	session, err := client.Connect(mcpCtx, transport, nil)
388	if err != nil {
389		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
390		slog.Error("error starting mcp client", "error", err, "name", name)
391		cancel()
392		return nil, err
393	}
394
395	cancelTimer.Stop()
396	slog.Info("Initialized mcp client", "name", name)
397	return session, nil
398}
399
// maybeTimeoutErr replaces a context cancellation or deadline error with a
// human-readable "timed out after <timeout>" error. Any other error,
// including nil, is returned unchanged.
//
// Both sentinels must be checked. The connection timer in createMCPSession
// cancels its context (context.Canceled), whereas the ping in
// getOrRenewClient uses context.WithTimeout (context.DeadlineExceeded).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
406
407func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
408	switch m.Type {
409	case config.MCPStdio:
410		command, err := resolver.ResolveValue(m.Command)
411		if err != nil {
412			return nil, fmt.Errorf("invalid mcp command: %w", err)
413		}
414		if strings.TrimSpace(command) == "" {
415			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
416		}
417		cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
418		cmd.Env = m.ResolvedEnv()
419		return &mcp.CommandTransport{
420			Command: cmd,
421		}, nil
422	case config.MCPHttp:
423		if strings.TrimSpace(m.URL) == "" {
424			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
425		}
426		client := &http.Client{
427			Transport: &headerRoundTripper{
428				headers: m.ResolvedHeaders(),
429			},
430		}
431		return &mcp.StreamableClientTransport{
432			Endpoint:   m.URL,
433			HTTPClient: client,
434		}, nil
435	case config.MCPSSE:
436		if strings.TrimSpace(m.URL) == "" {
437			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
438		}
439		client := &http.Client{
440			Transport: &headerRoundTripper{
441				headers: m.ResolvedHeaders(),
442			},
443		}
444		return &mcp.SSEClientTransport{
445			Endpoint:   m.URL,
446			HTTPClient: client,
447		}, nil
448	default:
449		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
450	}
451}
452
// headerRoundTripper is an http.RoundTripper that adds a fixed set of
// headers to every outgoing request before delegating to base.
type headerRoundTripper struct {
	// headers are set on each request, replacing existing values for the
	// same key.
	headers map[string]string
	// base performs the actual round trip; http.DefaultTransport when nil.
	base http.RoundTripper
}

// RoundTrip implements http.RoundTripper. The http.RoundTripper contract
// forbids modifying the caller's request, so headers are applied to a clone.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	base := rt.base
	if base == nil {
		base = http.DefaultTransport
	}
	if len(rt.headers) == 0 {
		return base.RoundTrip(req)
	}

	out := req.Clone(req.Context())
	if out.Header == nil {
		out.Header = make(http.Header)
	}
	for k, v := range rt.headers {
		out.Header.Set(k, v)
	}
	return base.RoundTrip(out)
}
463
464func mcpTimeout(m config.MCPConfig) time.Duration {
465	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
466}