mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"net/http"
 12	"os/exec"
 13	"strings"
 14	"sync"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/csync"
 19	"github.com/charmbracelet/crush/internal/home"
 20	"github.com/charmbracelet/crush/internal/llm/tools"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/modelcontextprotocol/go-sdk/mcp"
 25)
 26
 27// MCPState represents the current state of an MCP client
 28type MCPState int
 29
 30const (
 31	MCPStateDisabled MCPState = iota
 32	MCPStateStarting
 33	MCPStateConnected
 34	MCPStateError
 35)
 36
 37func (s MCPState) String() string {
 38	switch s {
 39	case MCPStateDisabled:
 40		return "disabled"
 41	case MCPStateStarting:
 42		return "starting"
 43	case MCPStateConnected:
 44		return "connected"
 45	case MCPStateError:
 46		return "error"
 47	default:
 48		return "unknown"
 49	}
 50}
 51
// MCPEventType identifies the kind of MCPEvent published on the MCP event broker.
type MCPEventType string

const (
	// MCPEventStateChanged is published every time updateMCPState records a new client state.
	MCPEventStateChanged     MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when an MCP server notifies that its tool list changed.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent is the payload published on the MCP event broker.
// Name and Type are always set; State, Error and ToolCount are only
// populated for MCPEventStateChanged events.
type MCPEvent struct {
	Type      MCPEventType
	Name      string // MCP server name (its key in the config's MCP map)
	State     MCPState
	Error     error
	ToolCount int
}

// MCPClientInfo holds the most recently recorded state of one MCP client.
type MCPClientInfo struct {
	Name        string
	State       MCPState
	Error       error              // error that caused MCPStateError, if any
	Client      *mcp.ClientSession // session passed to updateMCPState; may be nil
	ToolCount   int
	ConnectedAt time.Time // set only when State is MCPStateConnected
}
 78
var (
	// mcpToolsOnce is not referenced in this part of the file; presumably it
	// guards one-time tool discovery elsewhere — TODO confirm.
	mcpToolsOnce    sync.Once
	// mcpTools indexes every registered MCP tool by its prefixed tool name (McpTool.Name()).
	mcpTools        = csync.NewMap[string, tools.BaseTool]()
	// mcpClient2Tools maps an MCP server name to the tools it exposes.
	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
	// mcpClients holds the current session of each MCP server, keyed by server name.
	mcpClients      = csync.NewMap[string, *mcp.ClientSession]()
	// mcpStates holds the latest MCPClientInfo of each configured MCP server.
	mcpStates       = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker delivers MCPEvent notifications to SubscribeMCPEvents callers.
	mcpBroker       = pubsub.NewBroker[MCPEvent]()
)
 87
 88type McpTool struct {
 89	mcpName     string
 90	tool        *mcp.Tool
 91	permissions permission.Service
 92	workingDir  string
 93}
 94
 95func (b *McpTool) Name() string {
 96	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 97}
 98
 99func (b *McpTool) Info() tools.ToolInfo {
100	input := b.tool.InputSchema.(map[string]any)
101	required, _ := input["required"].([]string)
102	parameters, _ := input["properties"].(map[string]any)
103	return tools.ToolInfo{
104		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
105		Description: b.tool.Description,
106		Parameters:  parameters,
107		Required:    required,
108	}
109}
110
111func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
112	var args map[string]any
113	if err := json.Unmarshal([]byte(input), &args); err != nil {
114		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
115	}
116
117	c, err := getOrRenewClient(ctx, name)
118	if err != nil {
119		return tools.NewTextErrorResponse(err.Error()), nil
120	}
121	result, err := c.CallTool(ctx, &mcp.CallToolParams{
122		Name:      toolName,
123		Arguments: args,
124	})
125	if err != nil {
126		return tools.NewTextErrorResponse(err.Error()), nil
127	}
128
129	output := make([]string, 0, len(result.Content))
130	for _, v := range result.Content {
131		if v, ok := v.(*mcp.TextContent); ok {
132			output = append(output, v.Text)
133		} else {
134			output = append(output, fmt.Sprintf("%v", v))
135		}
136	}
137	return tools.NewTextResponse(strings.Join(output, "\n")), nil
138}
139
140func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
141	sess, ok := mcpClients.Get(name)
142	if !ok {
143		return nil, fmt.Errorf("mcp '%s' not available", name)
144	}
145
146	cfg := config.Get()
147	m := cfg.MCP[name]
148	state, _ := mcpStates.Get(name)
149
150	timeout := mcpTimeout(m)
151	pingCtx, cancel := context.WithTimeout(ctx, timeout)
152	defer cancel()
153	err := sess.Ping(pingCtx, nil)
154	if err == nil {
155		return sess, nil
156	}
157	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
158
159	sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
160	if err != nil {
161		return nil, err
162	}
163
164	updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
165	mcpClients.Set(name, sess)
166	return sess, nil
167}
168
169func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
170	sessionID, messageID := tools.GetContextValues(ctx)
171	if sessionID == "" || messageID == "" {
172		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
173	}
174	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", b.Info().Name)
175	p := b.permissions.Request(
176		permission.CreatePermissionRequest{
177			SessionID:   sessionID,
178			ToolCallID:  params.ID,
179			Path:        b.workingDir,
180			ToolName:    b.Info().Name,
181			Action:      "execute",
182			Description: permissionDescription,
183			Params:      params.Input,
184		},
185	)
186	if !p {
187		return tools.ToolResponse{}, permission.ErrorPermissionDenied
188	}
189
190	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
191}
192
193func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
194	result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
195	if err != nil {
196		return nil, err
197	}
198	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
199	for _, tool := range result.Tools {
200		mcpTools = append(mcpTools, &McpTool{
201			mcpName:     name,
202			tool:        tool,
203			permissions: permissions,
204			workingDir:  workingDir,
205		})
206	}
207	return mcpTools, nil
208}
209
210// SubscribeMCPEvents returns a channel for MCP events
211func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
212	return mcpBroker.Subscribe(ctx)
213}
214
215// GetMCPStates returns the current state of all MCP clients
216func GetMCPStates() map[string]MCPClientInfo {
217	return maps.Collect(mcpStates.Seq2())
218}
219
220// GetMCPState returns the state of a specific MCP client
221func GetMCPState(name string) (MCPClientInfo, bool) {
222	return mcpStates.Get(name)
223}
224
225// updateMCPState updates the state of an MCP client and publishes an event
226func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
227	info := MCPClientInfo{
228		Name:      name,
229		State:     state,
230		Error:     err,
231		Client:    client,
232		ToolCount: toolCount,
233	}
234	switch state {
235	case MCPStateConnected:
236		info.ConnectedAt = time.Now()
237	case MCPStateError:
238		updateMcpTools(name, nil)
239		mcpClients.Del(name)
240	}
241	mcpStates.Set(name, info)
242
243	// Publish state change event
244	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
245		Type:      MCPEventStateChanged,
246		Name:      name,
247		State:     state,
248		Error:     err,
249		ToolCount: toolCount,
250	})
251}
252
253// publishMCPEventToolsListChanged publishes a tool list changed event
254func publishMCPEventToolsListChanged(name string) {
255	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
256		Type: MCPEventToolsListChanged,
257		Name: name,
258	})
259}
260
261// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
262func CloseMCPClients() error {
263	var errs []error
264	for name, c := range mcpClients.Seq2() {
265		if err := c.Close(); err != nil {
266			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
267		}
268	}
269	mcpBroker.Shutdown()
270	return errors.Join(errs...)
271}
272
273func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
274	var wg sync.WaitGroup
275	// Initialize states for all configured MCPs
276	for name, m := range cfg.MCP {
277		if m.Disabled {
278			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
279			slog.Debug("skipping disabled mcp", "name", name)
280			continue
281		}
282
283		// Set initial starting state
284		updateMCPState(name, MCPStateStarting, nil, nil, 0)
285
286		wg.Add(1)
287		go func(name string, m config.MCPConfig) {
288			defer func() {
289				wg.Done()
290				if r := recover(); r != nil {
291					var err error
292					switch v := r.(type) {
293					case error:
294						err = v
295					case string:
296						err = fmt.Errorf("panic: %s", v)
297					default:
298						err = fmt.Errorf("panic: %v", v)
299					}
300					updateMCPState(name, MCPStateError, err, nil, 0)
301					slog.Error("panic in mcp client initialization", "error", err, "name", name)
302				}
303			}()
304
305			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
306			defer cancel()
307
308			c, err := createMCPSession(ctx, name, m, cfg.Resolver())
309			if err != nil {
310				return
311			}
312
313			mcpClients.Set(name, c)
314
315			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
316			if err != nil {
317				slog.Error("error listing tools", "error", err)
318				updateMCPState(name, MCPStateError, err, nil, 0)
319				c.Close()
320				return
321			}
322
323			updateMcpTools(name, tools)
324			mcpClients.Set(name, c)
325			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
326		}(name, m)
327	}
328	wg.Wait()
329}
330
331// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
332func updateMcpTools(mcpName string, tools []tools.BaseTool) {
333	if len(tools) == 0 {
334		mcpClient2Tools.Del(mcpName)
335	} else {
336		mcpClient2Tools.Set(mcpName, tools)
337	}
338	for _, tools := range mcpClient2Tools.Seq2() {
339		for _, t := range tools {
340			mcpTools.Set(t.Name(), t)
341		}
342	}
343}
344
345func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
346	transport, err := createMCPTransport(name, m, resolver)
347	if err != nil {
348		updateMCPState(name, MCPStateError, err, nil, 0)
349		slog.Error("error creating mcp client", "error", err, "name", name)
350		return nil, err
351	}
352
353	client := mcp.NewClient(&mcp.Implementation{
354		Name:    "crush",
355		Version: version.Version,
356		Title:   "Crush",
357	}, &mcp.ClientOptions{
358		ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
359			mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
360				Type: MCPEventToolsListChanged,
361				Name: name,
362			})
363		},
364	})
365
366	// XXX: ideally we should be able to use context.WithTimeout here, but,
367	// the SSE MCP client will start failing once that context is canceled.
368	timeout := mcpTimeout(m)
369	mcpCtx, cancel := context.WithCancel(ctx)
370	cancelTimer := time.AfterFunc(timeout, cancel)
371
372	session, err := client.Connect(mcpCtx, transport, nil)
373	if err != nil {
374		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
375		slog.Error("error starting mcp client", "error", err, "name", name)
376		_ = session.Close()
377		cancel()
378		return nil, err
379	}
380
381	cancelTimer.Stop()
382	slog.Info("Initialized mcp client", "name", name)
383	return session, nil
384}
385
// maybeTimeoutErr replaces a context cancellation or deadline error with a
// readable "timed out after <timeout>" error. Any other error, including nil,
// is returned unchanged.
//
// Both sentinels are recognized:
//   - context.Canceled comes from timers that cancel a context (as in
//     createMCPSession).
//   - context.DeadlineExceeded comes from context.WithTimeout-based calls
//     (as with the ping in getOrRenewClient).
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
392
393func createMCPTransport(name string, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
394	switch m.Type {
395	case config.MCPStdio:
396		command, err := resolver.ResolveValue(m.Command)
397		if err != nil {
398			return nil, fmt.Errorf("invalid mcp command: %w", err)
399		}
400		if strings.TrimSpace(command) == "" {
401			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
402		}
403		cmd := exec.Command(home.Long(command), m.Args...)
404		cmd.Env = m.ResolvedEnv()
405		return &mcp.CommandTransport{
406			Command: cmd,
407		}, nil
408	case config.MCPHttp:
409		if strings.TrimSpace(m.URL) == "" {
410			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
411		}
412		client := &http.Client{
413			Transport: &headerRoundTripper{
414				headers: m.ResolvedHeaders(),
415			},
416		}
417		return &mcp.StreamableClientTransport{
418			Endpoint:   m.URL,
419			HTTPClient: client,
420		}, nil
421	case config.MCPSSE:
422		if strings.TrimSpace(m.URL) == "" {
423			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
424		}
425		client := &http.Client{
426			Transport: &headerRoundTripper{
427				headers: m.ResolvedHeaders(),
428			},
429		}
430		return &mcp.SSEClientTransport{
431			Endpoint:   m.URL,
432			HTTPClient: client,
433		}, nil
434	default:
435		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
436	}
437}
438
// headerRoundTripper is an http.RoundTripper that adds a fixed set of headers
// to every outgoing request, then sends it with http.DefaultTransport.
type headerRoundTripper struct {
	headers map[string]string
}

// RoundTrip implements http.RoundTripper. The RoundTripper contract forbids
// modifying the caller's request, so the headers are set on a clone. With no
// headers configured, the request is forwarded as-is.
func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	if len(rt.headers) > 0 {
		req = req.Clone(req.Context())
		if req.Header == nil {
			req.Header = make(http.Header)
		}
		for k, v := range rt.headers {
			req.Header.Set(k, v)
		}
	}
	return http.DefaultTransport.RoundTrip(req)
}
449
// mcpLogger offers printf-style Errorf and Infof methods backed by the default
// slog logger. Every record is tagged with the MCP server name.
type mcpLogger struct{ name string }

// Errorf formats the message with fmt.Sprintf and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message with fmt.Sprintf and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
460
461func mcpTimeout(m config.MCPConfig) time.Duration {
462	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
463}