mcp-tools.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/pubsub"
 21	"github.com/charmbracelet/crush/internal/version"
 22	"github.com/charmbracelet/fantasy/ai"
 23	"github.com/mark3labs/mcp-go/client"
 24	"github.com/mark3labs/mcp-go/client/transport"
 25	"github.com/mark3labs/mcp-go/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
// MCPEventType identifies the kind of MCPEvent published on the MCP event
// broker (see SubscribeMCPEvents).
type MCPEventType string

const (
	// MCPEventStateChanged is published every time an MCP client's state is
	// recorded by updateMCPState.
	MCPEventStateChanged     MCPEventType = "state_changed"
	// MCPEventToolsListChanged is published when an MCP server sends a
	// "notifications/tools/list_changed" notification.
	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)

// MCPEvent is a notification about an MCP client, delivered to subscribers
// of SubscribeMCPEvents.
type MCPEvent struct {
	// Type says which kind of event this is.
	Type      MCPEventType
	// Name is the MCP's name (its key in the configuration's MCP map).
	Name      string
	// State is the new state; only populated for MCPEventStateChanged.
	State     MCPState
	// Error is the error that caused MCPStateError, if any.
	Error     error
	// ToolCount is the number of tools reported with the state change; it is
	// left zero for MCPEventToolsListChanged.
	ToolCount int
}

// MCPClientInfo is the most recently recorded state of one MCP client, as
// returned by GetMCPStates and GetMCPState.
type MCPClientInfo struct {
	// Name is the MCP's name (its key in the configuration's MCP map).
	Name        string
	// State is the client's lifecycle state.
	State       MCPState
	// Error is the error that moved the client into MCPStateError; callers in
	// this file pass nil for every other state.
	Error       error
	// Client is the live client; callers in this file only set it together
	// with MCPStateConnected.
	Client      *client.Client
	// ToolCount is the number of tools the MCP exposed when the state was set.
	ToolCount   int
	// ConnectedAt is when the client entered MCPStateConnected; zero otherwise.
	ConnectedAt time.Time
}
 79
// Package-level registry of MCP clients, their tools and their states. The
// maps are csync.Map values and are accessed from the per-MCP goroutines
// started by GetMCPTools.
var (
	// mcpToolsOnce makes the client initialization in GetMCPTools run once.
	mcpToolsOnce    sync.Once
	// mcpTools holds every registered MCP tool keyed by its prefixed tool
	// name ("mcp_<mcp>_<tool>", see McpTool.Name).
	mcpTools        = csync.NewMap[string, *McpTool]()
	// mcpClient2Tools holds each MCP's tool list keyed by MCP name.
	mcpClient2Tools = csync.NewMap[string, []*McpTool]()
	// mcpClients holds the live client of each MCP keyed by MCP name.
	mcpClients      = csync.NewMap[string, *client.Client]()
	// mcpStates holds the latest MCPClientInfo of each configured MCP.
	mcpStates       = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvents to SubscribeMCPEvents subscribers.
	mcpBroker       = pubsub.NewBroker[MCPEvent]()
)

// McpTool exposes a single tool advertised by an MCP server as a tool the
// model can call. Calls are executed through Run.
type McpTool struct {
	// mcpName is the name of the MCP server providing the tool.
	mcpName         string
	// tool is the tool definition as returned by the server's tools list.
	tool            mcp.Tool
	// permissions is asked for approval before every execution.
	permissions     permission.Service
	// workingDir is sent as the Path of permission requests.
	workingDir      string
	// providerOptions are opaque options attached via SetProviderOptions.
	providerOptions ai.ProviderOptions
}
 96
// SetProviderOptions stores provider-specific options for this tool.
func (m *McpTool) SetProviderOptions(opts ai.ProviderOptions) {
	m.providerOptions = opts
}

// ProviderOptions returns the options last stored by SetProviderOptions.
func (m *McpTool) ProviderOptions() ai.ProviderOptions {
	return m.providerOptions
}

// Name returns the tool name exposed to the model, "mcp_<mcp name>_<tool name>".
func (m *McpTool) Name() string {
	return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name)
}

// MCP returns the name of the MCP server that provides this tool.
func (m *McpTool) MCP() string {
	return m.mcpName
}

// MCPToolName returns the tool's own name as advertised by the MCP server,
// without the "mcp_<mcp name>_" prefix used by Name.
func (m *McpTool) MCPToolName() string {
	return m.tool.Name
}
116
117func (m *McpTool) Info() ai.ToolInfo {
118	required := m.tool.InputSchema.Required
119	if required == nil {
120		required = make([]string, 0)
121	}
122	parameters := m.tool.InputSchema.Properties
123	if parameters == nil {
124		parameters = make(map[string]any)
125	}
126	return ai.ToolInfo{
127		Name:        fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name),
128		Description: m.tool.Description,
129		Parameters:  parameters,
130		Required:    required,
131	}
132}
133
134func runTool(ctx context.Context, name, toolName string, input string) (ai.ToolResponse, error) {
135	var args map[string]any
136	if err := json.Unmarshal([]byte(input), &args); err != nil {
137		return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
138	}
139
140	c, err := getOrRenewClient(ctx, name)
141	if err != nil {
142		return ai.NewTextErrorResponse(err.Error()), nil
143	}
144	result, err := c.CallTool(ctx, mcp.CallToolRequest{
145		Params: mcp.CallToolParams{
146			Name:      toolName,
147			Arguments: args,
148		},
149	})
150	if err != nil {
151		return ai.NewTextErrorResponse(err.Error()), nil
152	}
153
154	output := make([]string, 0, len(result.Content))
155	for _, v := range result.Content {
156		if v, ok := v.(mcp.TextContent); ok {
157			output = append(output, v.Text)
158		} else {
159			output = append(output, fmt.Sprintf("%v", v))
160		}
161	}
162	return ai.NewTextResponse(strings.Join(output, "\n")), nil
163}
164
165func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
166	c, ok := mcpClients.Get(name)
167	if !ok {
168		return nil, fmt.Errorf("mcp '%s' not available", name)
169	}
170
171	cfg := config.Get()
172	m := cfg.MCP[name]
173	state, _ := mcpStates.Get(name)
174
175	timeout := mcpTimeout(m)
176	pingCtx, cancel := context.WithTimeout(ctx, timeout)
177	defer cancel()
178	err := c.Ping(pingCtx)
179	if err == nil {
180		return c, nil
181	}
182	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
183
184	c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
185	if err != nil {
186		return nil, err
187	}
188
189	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
190	mcpClients.Set(name, c)
191	return c, nil
192}
193
194func (m *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) {
195	sessionID := GetSessionFromContext(ctx)
196	if sessionID == "" {
197		return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
198	}
199	permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name)
200	p := m.permissions.Request(
201		permission.CreatePermissionRequest{
202			SessionID:   sessionID,
203			ToolCallID:  params.ID,
204			Path:        m.workingDir,
205			ToolName:    m.Info().Name,
206			Action:      "execute",
207			Description: permissionDescription,
208			Params:      params.Input,
209		},
210	)
211	if !p {
212		return ai.ToolResponse{}, permission.ErrorPermissionDenied
213	}
214
215	return runTool(ctx, m.mcpName, m.tool.Name, params.Input)
216}
217
218func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]*McpTool, error) {
219	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
220	if err != nil {
221		return nil, err
222	}
223	mcpTools := make([]*McpTool, 0, len(result.Tools))
224	for _, tool := range result.Tools {
225		mcpTools = append(mcpTools, &McpTool{
226			mcpName:     name,
227			tool:        tool,
228			permissions: permissions,
229			workingDir:  workingDir,
230		})
231	}
232	return mcpTools, nil
233}
234
235// SubscribeMCPEvents returns a channel for MCP events
236func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
237	return mcpBroker.Subscribe(ctx)
238}
239
240// GetMCPStates returns the current state of all MCP clients
241func GetMCPStates() map[string]MCPClientInfo {
242	return maps.Collect(mcpStates.Seq2())
243}
244
245// GetMCPState returns the state of a specific MCP client
246func GetMCPState(name string) (MCPClientInfo, bool) {
247	return mcpStates.Get(name)
248}
249
250// updateMCPState updates the state of an MCP client and publishes an event
251func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
252	info := MCPClientInfo{
253		Name:      name,
254		State:     state,
255		Error:     err,
256		Client:    client,
257		ToolCount: toolCount,
258	}
259	switch state {
260	case MCPStateConnected:
261		info.ConnectedAt = time.Now()
262	case MCPStateError:
263		updateMcpTools(name, nil)
264		mcpClients.Del(name)
265	}
266	mcpStates.Set(name, info)
267
268	// Publish state change event
269	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
270		Type:      MCPEventStateChanged,
271		Name:      name,
272		State:     state,
273		Error:     err,
274		ToolCount: toolCount,
275	})
276}
277
278// publishMCPEventToolsListChanged publishes a tool list changed event
279func publishMCPEventToolsListChanged(name string) {
280	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
281		Type: MCPEventToolsListChanged,
282		Name: name,
283	})
284}
285
286// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
287func CloseMCPClients() error {
288	var errs []error
289	for name, c := range mcpClients.Seq2() {
290		if err := c.Close(); err != nil {
291			errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
292		}
293	}
294	mcpBroker.Shutdown()
295	return errors.Join(errs...)
296}
297
// mcpInitRequest is the initialize request sent to every MCP server. It
// requests protocol version mcp.LATEST_PROTOCOL_VERSION and identifies this
// client as "Crush" at the current build's version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
307
308func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool {
309	mcpToolsOnce.Do(func() {
310		var wg sync.WaitGroup
311		// Initialize states for all configured MCPs
312		for name, m := range cfg.MCP {
313			if m.Disabled {
314				updateMCPState(name, MCPStateDisabled, nil, nil, 0)
315				slog.Debug("skipping disabled mcp", "name", name)
316				continue
317			}
318
319			// Set initial starting state
320			updateMCPState(name, MCPStateStarting, nil, nil, 0)
321
322			wg.Add(1)
323			go func(name string, m config.MCPConfig) {
324				defer func() {
325					wg.Done()
326					if r := recover(); r != nil {
327						var err error
328						switch v := r.(type) {
329						case error:
330							err = v
331						case string:
332							err = fmt.Errorf("panic: %s", v)
333						default:
334							err = fmt.Errorf("panic: %v", v)
335						}
336						updateMCPState(name, MCPStateError, err, nil, 0)
337						slog.Error("panic in mcp client initialization", "error", err, "name", name)
338					}
339				}()
340
341				mcpCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
342				defer cancel()
343
344				c, err := createAndInitializeClient(mcpCtx, name, m, cfg.Resolver())
345				if err != nil {
346					return
347				}
348
349				mcpClients.Set(name, c)
350
351				tools, err := getTools(mcpCtx, name, permissions, c, cfg.WorkingDir())
352				if err != nil {
353					slog.Error("error listing tools", "error", err)
354					updateMCPState(name, MCPStateError, err, nil, 0)
355					c.Close()
356					return
357				}
358
359				updateMcpTools(name, tools)
360				mcpClients.Set(name, c)
361				updateMCPState(name, MCPStateConnected, nil, c, len(tools))
362			}(name, m)
363		}
364		wg.Wait()
365	})
366	return slices.Collect(mcpTools.Seq())
367}
368
369// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
370func updateMcpTools(mcpName string, tools []*McpTool) {
371	if len(tools) == 0 {
372		mcpClient2Tools.Del(mcpName)
373	} else {
374		mcpClient2Tools.Set(mcpName, tools)
375	}
376	for _, tools := range mcpClient2Tools.Seq2() {
377		for _, t := range tools {
378			mcpTools.Set(t.Info().Name, t)
379		}
380	}
381}
382
383func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
384	c, err := createMcpClient(name, m, resolver)
385	if err != nil {
386		updateMCPState(name, MCPStateError, err, nil, 0)
387		slog.Error("error creating mcp client", "error", err, "name", name)
388		return nil, err
389	}
390
391	c.OnNotification(func(n mcp.JSONRPCNotification) {
392		slog.Debug("Received MCP notification", "name", name, "notification", n)
393		switch n.Method {
394		case "notifications/tools/list_changed":
395			publishMCPEventToolsListChanged(name)
396		default:
397			slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
398		}
399	})
400
401	// XXX: ideally we should be able to use context.WithTimeout here, but,
402	// the SSE MCP client will start failing once that context is canceled.
403	timeout := mcpTimeout(m)
404	mcpCtx, cancel := context.WithCancel(ctx)
405	cancelTimer := time.AfterFunc(timeout, cancel)
406
407	if err := c.Start(mcpCtx); err != nil {
408		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
409		slog.Error("error starting mcp client", "error", err, "name", name)
410		_ = c.Close()
411		cancel()
412		return nil, err
413	}
414
415	if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
416		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
417		slog.Error("error initializing mcp client", "error", err, "name", name)
418		_ = c.Close()
419		cancel()
420		return nil, err
421	}
422
423	cancelTimer.Stop()
424	slog.Info("Initialized mcp client", "name", name)
425	return c, nil
426}
427
// maybeTimeoutErr turns a context cancellation or deadline error (possibly
// wrapped) from an operation bounded by timeout into a readable
// "timed out after <timeout>" error. Any other error, including nil, is
// returned unchanged.
//
// Callers cancel their contexts only on timeout, so a canceled context is
// attributed to the timeout as well; cancellation of a parent context cannot
// be told apart here.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return fmt.Errorf("timed out after %s", timeout)
	}
	return err
}
434
435func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
436	switch m.Type {
437	case config.MCPStdio:
438		command, err := resolver.ResolveValue(m.Command)
439		if err != nil {
440			return nil, fmt.Errorf("invalid mcp command: %w", err)
441		}
442		if strings.TrimSpace(command) == "" {
443			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
444		}
445		return client.NewStdioMCPClientWithOptions(
446			home.Long(command),
447			m.ResolvedEnv(),
448			m.Args,
449			transport.WithCommandLogger(mcpLogger{name: name}),
450		)
451	case config.MCPHttp:
452		if strings.TrimSpace(m.URL) == "" {
453			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
454		}
455		return client.NewStreamableHttpClient(
456			m.URL,
457			transport.WithHTTPHeaders(m.ResolvedHeaders()),
458			transport.WithHTTPLogger(mcpLogger{name: name}),
459		)
460	case config.MCPSse:
461		if strings.TrimSpace(m.URL) == "" {
462			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
463		}
464		return client.NewSSEMCPClient(
465			m.URL,
466			client.WithHeaders(m.ResolvedHeaders()),
467			transport.WithSSELogger(mcpLogger{name: name}),
468		)
469	default:
470		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
471	}
472}
473
// mcpLogger forwards log output from the mcp-go transports (it is passed to
// transport.WithCommandLogger, WithHTTPLogger and WithSSELogger) to slog,
// tagging every record with the MCP's name.
type mcpLogger struct{ name string }

// Errorf formats the message printf-style and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message printf-style and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
484
485func mcpTimeout(m config.MCPConfig) time.Duration {
486	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
487}