mcp-tools.go

  1package tools
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"slices"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/ai"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/version"
 21	"github.com/mark3labs/mcp-go/client"
 22	"github.com/mark3labs/mcp-go/client/transport"
 23	"github.com/mark3labs/mcp-go/mcp"
 24)
 25
 26// MCPState represents the current state of an MCP client
 27type MCPState int
 28
 29const (
 30	MCPStateDisabled MCPState = iota
 31	MCPStateStarting
 32	MCPStateConnected
 33	MCPStateError
 34)
 35
 36func (s MCPState) String() string {
 37	switch s {
 38	case MCPStateDisabled:
 39		return "disabled"
 40	case MCPStateStarting:
 41		return "starting"
 42	case MCPStateConnected:
 43		return "connected"
 44	case MCPStateError:
 45		return "error"
 46	default:
 47		return "unknown"
 48	}
 49}
 50
 51// MCPEventType represents the type of MCP event
 52type MCPEventType string
 53
 54const (
 55	MCPEventStateChanged MCPEventType = "state_changed"
 56)
 57
// MCPEvent is the payload published on the package's MCP event broker when
// something about an MCP client changes.
type MCPEvent struct {
	// Type is the kind of event being reported.
	Type      MCPEventType
	// Name is the name of the MCP client the event refers to.
	Name      string
	// State is the state the client was set to.
	State     MCPState
	// Error is the error recorded together with the state, if any.
	Error     error
	// ToolCount is the tool count recorded together with the state.
	ToolCount int
}
 66
// MCPClientInfo is the per-client record kept in the package's state map and
// returned by GetMCPState and GetMCPStates.
type MCPClientInfo struct {
	// Name is the name of the MCP client this record belongs to.
	Name        string
	// State is the most recently recorded state.
	State       MCPState
	// Error is the error recorded with the state; nil when none was supplied.
	Error       error
	// Client is the client handle supplied with the state update; it may be nil.
	Client      *client.Client
	// ToolCount is the number of tools recorded for the client.
	ToolCount   int
	// ConnectedAt is set to the time of the update when State is
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
 76
// Package-level MCP bookkeeping shared by the functions in this file.
var (
	// mcpToolsOnce guards the one-time MCP initialization in GetMCPTools.
	mcpToolsOnce sync.Once
	// mcpTools is the package-level slice intended to hold the tools
	// discovered from the configured MCP servers.
	mcpTools     []ai.AgentTool
	// mcpClients maps an MCP name to the client currently registered for it.
	mcpClients   = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to its most recently recorded state.
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent values to subscribers.
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
 84
// McpTool exposes a single tool offered by an MCP server as an agent tool.
type McpTool struct {
	// mcpName is the name of the MCP server that provides the tool.
	mcpName     string
	// tool is the tool definition as listed by the server.
	tool        mcp.Tool
	// permissions is asked for approval before the tool is executed.
	permissions permission.Service
	// workingDir is reported as the path in permission requests.
	workingDir  string
}
 91
 92func (b *McpTool) Name() string {
 93	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 94}
 95
 96func (b *McpTool) Info() ai.ToolInfo {
 97	required := b.tool.InputSchema.Required
 98	if required == nil {
 99		required = make([]string, 0)
100	}
101	return ai.ToolInfo{
102		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
103		Description: b.tool.Description,
104		Parameters:  b.tool.InputSchema.Properties,
105		Required:    required,
106	}
107}
108
109func runTool(ctx context.Context, name, toolName string, input string) (ai.ToolResponse, error) {
110	var args map[string]any
111	if err := json.Unmarshal([]byte(input), &args); err != nil {
112		return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
113	}
114
115	c, err := getOrRenewClient(ctx, name)
116	if err != nil {
117		return ai.NewTextErrorResponse(err.Error()), nil
118	}
119	result, err := c.CallTool(ctx, mcp.CallToolRequest{
120		Params: mcp.CallToolParams{
121			Name:      toolName,
122			Arguments: args,
123		},
124	})
125	if err != nil {
126		return ai.NewTextErrorResponse(err.Error()), nil
127	}
128
129	output := make([]string, 0, len(result.Content))
130	for _, v := range result.Content {
131		if v, ok := v.(mcp.TextContent); ok {
132			output = append(output, v.Text)
133		} else {
134			output = append(output, fmt.Sprintf("%v", v))
135		}
136	}
137	return ai.NewTextResponse(strings.Join(output, "\n")), nil
138}
139
140func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
141	c, ok := mcpClients.Get(name)
142	if !ok {
143		return nil, fmt.Errorf("mcp '%s' not available", name)
144	}
145
146	m := config.Get().MCP[name]
147	state, _ := mcpStates.Get(name)
148
149	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
150	defer cancel()
151	err := c.Ping(pingCtx)
152	if err == nil {
153		return c, nil
154	}
155	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
156
157	c, err = createAndInitializeClient(ctx, name, m)
158	if err != nil {
159		return nil, err
160	}
161
162	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
163	mcpClients.Set(name, c)
164	return c, nil
165}
166
167func (b *McpTool) Run(ctx context.Context, params ai.ToolCall) (ai.ToolResponse, error) {
168	sessionID, messageID := GetContextValues(ctx)
169	if sessionID == "" || messageID == "" {
170		return ai.ToolResponse{}, fmt.Errorf("session ID and message ID are required for MCP tool execution")
171	}
172	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
173	p := b.permissions.Request(
174		permission.CreatePermissionRequest{
175			SessionID:   sessionID,
176			ToolCallID:  params.ID,
177			Path:        b.workingDir,
178			ToolName:    b.Info().Name,
179			Action:      "execute",
180			Description: permissionDescription,
181			Params:      params.Input,
182		},
183	)
184	if !p {
185		return ai.ToolResponse{}, permission.ErrorPermissionDenied
186	}
187
188	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
189}
190
191func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []ai.AgentTool {
192	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
193	if err != nil {
194		slog.Error("error listing tools", "error", err)
195		updateMCPState(name, MCPStateError, err, nil, 0)
196		c.Close()
197		mcpClients.Del(name)
198		return nil
199	}
200	mcpTools := make([]ai.AgentTool, 0, len(result.Tools))
201	for _, tool := range result.Tools {
202		mcpTools = append(mcpTools, &McpTool{
203			mcpName:     name,
204			tool:        tool,
205			permissions: permissions,
206			workingDir:  workingDir,
207		})
208	}
209	return mcpTools
210}
211
212// SubscribeMCPEvents returns a channel for MCP events
213func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
214	return mcpBroker.Subscribe(ctx)
215}
216
217// GetMCPStates returns the current state of all MCP clients
218func GetMCPStates() map[string]MCPClientInfo {
219	return maps.Collect(mcpStates.Seq2())
220}
221
222// GetMCPState returns the state of a specific MCP client
223func GetMCPState(name string) (MCPClientInfo, bool) {
224	return mcpStates.Get(name)
225}
226
227// updateMCPState updates the state of an MCP client and publishes an event
228func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
229	info := MCPClientInfo{
230		Name:      name,
231		State:     state,
232		Error:     err,
233		Client:    client,
234		ToolCount: toolCount,
235	}
236	if state == MCPStateConnected {
237		info.ConnectedAt = time.Now()
238	}
239	mcpStates.Set(name, info)
240
241	// Publish state change event
242	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
243		Type:      MCPEventStateChanged,
244		Name:      name,
245		State:     state,
246		Error:     err,
247		ToolCount: toolCount,
248	})
249}
250
251// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
252func CloseMCPClients() {
253	for c := range mcpClients.Seq() {
254		_ = c.Close()
255	}
256	mcpBroker.Shutdown()
257}
258
// mcpInitRequest is the initialize request sent to every MCP server. It
// advertises the latest protocol version known to the mcp package and
// identifies this client as "Crush" with the application's version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
268
269func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []ai.AgentTool {
270	var mcpTools []ai.AgentTool
271	mcpToolsOnce.Do(func() {
272		var wg sync.WaitGroup
273		result := csync.NewSlice[ai.AgentTool]()
274
275		// Initialize states for all configured MCPs
276		for name, m := range cfg.MCP {
277			if m.Disabled {
278				updateMCPState(name, MCPStateDisabled, nil, nil, 0)
279				slog.Debug("skipping disabled mcp", "name", name)
280				continue
281			}
282
283			// Set initial starting state
284			updateMCPState(name, MCPStateStarting, nil, nil, 0)
285
286			wg.Add(1)
287			go func(name string, m config.MCPConfig) {
288				defer func() {
289					wg.Done()
290					if r := recover(); r != nil {
291						var err error
292						switch v := r.(type) {
293						case error:
294							err = v
295						case string:
296							err = fmt.Errorf("panic: %s", v)
297						default:
298							err = fmt.Errorf("panic: %v", v)
299						}
300						updateMCPState(name, MCPStateError, err, nil, 0)
301						slog.Error("panic in mcp client initialization", "error", err, "name", name)
302					}
303				}()
304
305				mcpCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
306				defer cancel()
307				c, err := createAndInitializeClient(mcpCtx, name, m)
308				if err != nil {
309					return
310				}
311				mcpClients.Set(name, c)
312
313				tools := getTools(mcpCtx, name, permissions, c, cfg.WorkingDir())
314				updateMCPState(name, MCPStateConnected, nil, c, len(tools))
315				result.Append(tools...)
316			}(name, m)
317		}
318		wg.Wait()
319		mcpTools = slices.Collect(result.Seq())
320	})
321	return mcpTools
322}
323
324func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
325	c, err := createMcpClient(m)
326	if err != nil {
327		updateMCPState(name, MCPStateError, err, nil, 0)
328		slog.Error("error creating mcp client", "error", err, "name", name)
329		return nil, err
330	}
331	// Only call Start() for non-stdio clients, as stdio clients auto-start
332	if m.Type != config.MCPStdio {
333		if err := c.Start(ctx); err != nil {
334			updateMCPState(name, MCPStateError, err, nil, 0)
335			slog.Error("error starting mcp client", "error", err, "name", name)
336			_ = c.Close()
337			return nil, err
338		}
339	}
340	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
341		updateMCPState(name, MCPStateError, err, nil, 0)
342		slog.Error("error initializing mcp client", "error", err, "name", name)
343		_ = c.Close()
344		return nil, err
345	}
346
347	slog.Info("Initialized mcp client", "name", name)
348	return c, nil
349}
350
351func createMcpClient(m config.MCPConfig) (*client.Client, error) {
352	switch m.Type {
353	case config.MCPStdio:
354		return client.NewStdioMCPClientWithOptions(
355			m.Command,
356			m.ResolvedEnv(),
357			m.Args,
358			transport.WithCommandLogger(mcpLogger{}),
359		)
360	case config.MCPHttp:
361		return client.NewStreamableHttpClient(
362			m.URL,
363			transport.WithHTTPHeaders(m.ResolvedHeaders()),
364			transport.WithHTTPLogger(mcpLogger{}),
365		)
366	case config.MCPSse:
367		return client.NewSSEMCPClient(
368			m.URL,
369			client.WithHeaders(m.ResolvedHeaders()),
370			transport.WithSSELogger(mcpLogger{}),
371		)
372	default:
373		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
374	}
375}
376
// mcpLogger is the logger handed to the MCP clients; it forwards their
// printf-style messages to the default slog logger.
type mcpLogger struct{}

// Errorf formats the message and logs it at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message and logs it at info level.
func (mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
382
383func mcpTimeout(m config.MCPConfig) time.Duration {
384	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
385}