mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"maps"
 10	"slices"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/pubsub"
 20	"github.com/charmbracelet/crush/internal/version"
 21	"github.com/mark3labs/mcp-go/client"
 22	"github.com/mark3labs/mcp-go/client/transport"
 23	"github.com/mark3labs/mcp-go/mcp"
 24)
 25
 26// MCPState represents the current state of an MCP client
 27type MCPState int
 28
 29const (
 30	MCPStateDisabled MCPState = iota
 31	MCPStateStarting
 32	MCPStateConnected
 33	MCPStateError
 34)
 35
 36func (s MCPState) String() string {
 37	switch s {
 38	case MCPStateDisabled:
 39		return "disabled"
 40	case MCPStateStarting:
 41		return "starting"
 42	case MCPStateConnected:
 43		return "connected"
 44	case MCPStateError:
 45		return "error"
 46	default:
 47		return "unknown"
 48	}
 49}
 50
// MCPEventType represents the type of MCP event.
type MCPEventType string

const (
	// MCPEventStateChanged is the type of the event that updateMCPState
	// publishes every time it records a new state for an MCP client.
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MCPEvent represents an event in the MCP system. It mirrors the fields of
// the MCPClientInfo recorded at the same time (without the client handle).
type MCPEvent struct {
	// Type identifies the kind of event.
	Type      MCPEventType
	// Name is the MCP server's configured name (a key of the config's MCP map).
	Name      string
	// State is the state the client moved to.
	State     MCPState
	// Error is the error that caused the state change, or nil.
	Error     error
	// ToolCount is the tool count recorded for the server with this state.
	ToolCount int
}
 66
// MCPClientInfo holds information about an MCP client's state. Values are
// written by updateMCPState and read back via GetMCPState/GetMCPStates.
type MCPClientInfo struct {
	// Name is the MCP server's configured name.
	Name        string
	// State is the client's current lifecycle state.
	State       MCPState
	// Error is the error behind an MCPStateError state; nil otherwise in the
	// callers visible in this file.
	Error       error
	// Client is the client handle; callers in this file set it only together
	// with MCPStateConnected and pass nil for every other state.
	Client      *client.Client
	// ToolCount is the number of tools recorded for the server.
	ToolCount   int
	// ConnectedAt is set to time.Now() when State is MCPStateConnected and is
	// the zero time otherwise.
	ConnectedAt time.Time
}
 76
var (
	// mcpToolsOnce and mcpTools are not referenced in the code shown here;
	// presumably they memoize the MCP tool list for callers elsewhere in
	// package agent. NOTE(review): confirm.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients holds the current client for each MCP server, keyed by the
	// server's configured name.
	mcpClients   = csync.NewMap[string, *client.Client]()
	// mcpStates holds the latest MCPClientInfo per server name; it is written
	// only through updateMCPState.
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvent values to SubscribeMCPEvents subscribers.
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
 84
// McpTool exposes a single tool of an MCP server as a tools.BaseTool.
// Running it first asks for permission and then forwards the call to the
// server through runTool.
type McpTool struct {
	// mcpName is the configured name of the server that owns the tool.
	mcpName     string
	// tool is the tool definition as returned by the server's ListTools.
	tool        mcp.Tool
	// permissions is used by Run to request approval before executing.
	permissions permission.Service
	// workingDir is sent as the Path of the permission request.
	workingDir  string
}
 91
 92func (b *McpTool) Name() string {
 93	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 94}
 95
 96func (b *McpTool) Info() tools.ToolInfo {
 97	required := b.tool.InputSchema.Required
 98	if required == nil {
 99		required = make([]string, 0)
100	}
101	return tools.ToolInfo{
102		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
103		Description: b.tool.Description,
104		Parameters:  b.tool.InputSchema.Properties,
105		Required:    required,
106	}
107}
108
109func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
110	var args map[string]any
111	if err := json.Unmarshal([]byte(input), &args); err != nil {
112		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
113	}
114
115	c, err := getOrRenewClient(ctx, name)
116	if err != nil {
117		return tools.NewTextErrorResponse(err.Error()), nil
118	}
119	result, err := c.CallTool(ctx, mcp.CallToolRequest{
120		Params: mcp.CallToolParams{
121			Name:      toolName,
122			Arguments: args,
123		},
124	})
125	if err != nil {
126		return tools.NewTextErrorResponse(err.Error()), nil
127	}
128
129	output := make([]string, 0, len(result.Content))
130	for _, v := range result.Content {
131		if v, ok := v.(mcp.TextContent); ok {
132			output = append(output, v.Text)
133		} else {
134			output = append(output, fmt.Sprintf("%v", v))
135		}
136	}
137	return tools.NewTextResponse(strings.Join(output, "\n")), nil
138}
139
140func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
141	c, ok := mcpClients.Get(name)
142	if !ok {
143		return nil, fmt.Errorf("mcp '%s' not available", name)
144	}
145
146	m := config.Get().MCP[name]
147	state, _ := mcpStates.Get(name)
148
149	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
150	defer cancel()
151	err := c.Ping(pingCtx)
152	if err == nil {
153		return c, nil
154	}
155	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
156
157	c, err = createAndInitializeClient(ctx, name, m)
158	if err != nil {
159		return nil, err
160	}
161
162	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
163	mcpClients.Set(name, c)
164	return c, nil
165}
166
167func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
168	sessionID, messageID := tools.GetContextValues(ctx)
169	if sessionID == "" || messageID == "" {
170		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
171	}
172	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
173	p := b.permissions.Request(
174		permission.CreatePermissionRequest{
175			SessionID:   sessionID,
176			ToolCallID:  params.ID,
177			Path:        b.workingDir,
178			ToolName:    b.Info().Name,
179			Action:      "execute",
180			Description: permissionDescription,
181			Params:      params.Input,
182		},
183	)
184	if !p {
185		return tools.ToolResponse{}, permission.ErrorPermissionDenied
186	}
187
188	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
189}
190
191func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
192	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
193	if err != nil {
194		slog.Error("error listing tools", "error", err)
195		updateMCPState(name, MCPStateError, err, nil, 0)
196		c.Close()
197		mcpClients.Del(name)
198		return nil
199	}
200	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
201	for _, tool := range result.Tools {
202		mcpTools = append(mcpTools, &McpTool{
203			mcpName:     name,
204			tool:        tool,
205			permissions: permissions,
206			workingDir:  workingDir,
207		})
208	}
209	return mcpTools
210}
211
212// SubscribeMCPEvents returns a channel for MCP events
213func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
214	return mcpBroker.Subscribe(ctx)
215}
216
217// GetMCPStates returns the current state of all MCP clients
218func GetMCPStates() map[string]MCPClientInfo {
219	return maps.Collect(mcpStates.Seq2())
220}
221
222// GetMCPState returns the state of a specific MCP client
223func GetMCPState(name string) (MCPClientInfo, bool) {
224	return mcpStates.Get(name)
225}
226
227// updateMCPState updates the state of an MCP client and publishes an event
228func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
229	info := MCPClientInfo{
230		Name:      name,
231		State:     state,
232		Error:     err,
233		Client:    client,
234		ToolCount: toolCount,
235	}
236	if state == MCPStateConnected {
237		info.ConnectedAt = time.Now()
238	}
239	mcpStates.Set(name, info)
240
241	// Publish state change event
242	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
243		Type:      MCPEventStateChanged,
244		Name:      name,
245		State:     state,
246		Error:     err,
247		ToolCount: toolCount,
248	})
249}
250
251// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
252func CloseMCPClients() {
253	for c := range mcpClients.Seq() {
254		_ = c.Close()
255	}
256	mcpBroker.Shutdown()
257}
258
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It advertises mcp-go's LATEST_PROTOCOL_VERSION
// and identifies this client as "Crush" at version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
268
269func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
270	var wg sync.WaitGroup
271	result := csync.NewSlice[tools.BaseTool]()
272
273	// Initialize states for all configured MCPs
274	for name, m := range cfg.MCP {
275		if m.Disabled {
276			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
277			slog.Debug("skipping disabled mcp", "name", name)
278			continue
279		}
280
281		// Set initial starting state
282		updateMCPState(name, MCPStateStarting, nil, nil, 0)
283
284		wg.Add(1)
285		go func(name string, m config.MCPConfig) {
286			defer func() {
287				wg.Done()
288				if r := recover(); r != nil {
289					var err error
290					switch v := r.(type) {
291					case error:
292						err = v
293					case string:
294						err = fmt.Errorf("panic: %s", v)
295					default:
296						err = fmt.Errorf("panic: %v", v)
297					}
298					updateMCPState(name, MCPStateError, err, nil, 0)
299					slog.Error("panic in mcp client initialization", "error", err, "name", name)
300				}
301			}()
302
303			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
304			defer cancel()
305			c, err := createAndInitializeClient(ctx, name, m)
306			if err != nil {
307				return
308			}
309			mcpClients.Set(name, c)
310
311			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
312			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
313			result.Append(tools...)
314		}(name, m)
315	}
316	wg.Wait()
317	return slices.Collect(result.Seq())
318}
319
320func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
321	c, err := createMcpClient(m)
322	if err != nil {
323		updateMCPState(name, MCPStateError, err, nil, 0)
324		slog.Error("error creating mcp client", "error", err, "name", name)
325		return nil, err
326	}
327	// Only call Start() for non-stdio clients, as stdio clients auto-start
328	if m.Type != config.MCPStdio {
329		if err := c.Start(ctx); err != nil {
330			updateMCPState(name, MCPStateError, err, nil, 0)
331			slog.Error("error starting mcp client", "error", err, "name", name)
332			_ = c.Close()
333			return nil, err
334		}
335	}
336	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
337		updateMCPState(name, MCPStateError, err, nil, 0)
338		slog.Error("error initializing mcp client", "error", err, "name", name)
339		_ = c.Close()
340		return nil, err
341	}
342
343	slog.Info("Initialized mcp client", "name", name)
344	return c, nil
345}
346
347func createMcpClient(m config.MCPConfig) (*client.Client, error) {
348	switch m.Type {
349	case config.MCPStdio:
350		return client.NewStdioMCPClientWithOptions(
351			m.Command,
352			m.ResolvedEnv(),
353			m.Args,
354			transport.WithCommandLogger(mcpLogger{}),
355		)
356	case config.MCPHttp:
357		return client.NewStreamableHttpClient(
358			m.URL,
359			transport.WithHTTPHeaders(m.ResolvedHeaders()),
360			transport.WithHTTPLogger(mcpLogger{}),
361		)
362	case config.MCPSse:
363		return client.NewSSEMCPClient(
364			m.URL,
365			client.WithHeaders(m.ResolvedHeaders()),
366			transport.WithSSELogger(mcpLogger{}),
367		)
368	default:
369		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
370	}
371}
372
// mcpLogger is the logger handed to the mcp-go transports (command, HTTP and
// SSE). It formats printf-style messages and forwards them to the default
// slog logger.
type mcpLogger struct{}

// Errorf formats its arguments and logs the result at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats its arguments and logs the result at info level.
func (mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}
378
379func mcpTimeout(m config.MCPConfig) time.Duration {
380	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
381}