mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/proto"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/version"
 24	"github.com/mark3labs/mcp-go/client"
 25	"github.com/mark3labs/mcp-go/client/transport"
 26	"github.com/mark3labs/mcp-go/mcp"
 27)
 28
// The MCP state and event types are aliases of the proto package's
// definitions, so values flow between this package and proto without
// conversion.
type (
	// MCPState is the lifecycle state recorded for an MCP client.
	MCPState     = proto.MCPState
	// MCPEventType identifies the kind of an MCPEvent.
	MCPEventType = proto.MCPEventType
	// MCPEvent is the payload published on the MCP event broker whenever an
	// MCP's recorded state changes.
	MCPEvent     = proto.MCPEvent
)
 34
// Aliases of the proto package's MCP state and event-type constants.
const (
	// MCPStateDisabled is recorded for an MCP that is disabled in the
	// configuration; no client is started for it.
	MCPStateDisabled  = proto.MCPStateDisabled
	// MCPStateStarting is recorded while an MCP's client is being created
	// and initialized.
	MCPStateStarting  = proto.MCPStateStarting
	// MCPStateConnected is recorded once an MCP's client has been
	// initialized (or successfully renewed).
	MCPStateConnected = proto.MCPStateConnected
	// MCPStateError is recorded when creating, starting, initializing,
	// listing tools from, or pinging an MCP client fails, or when its
	// initialization panics.
	MCPStateError     = proto.MCPStateError

	// MCPEventStateChanged is the event type published on every state update.
	MCPEventStateChanged = proto.MCPEventStateChanged
)
 43
// MCPClientInfo holds information about an MCP client's state. A fresh value
// is stored by updateMCPState on every state transition.
type MCPClientInfo struct {
	// Name is the MCP's configuration key.
	Name        string
	// State is the most recently recorded lifecycle state.
	State       MCPState
	// Error is the failure that accompanied MCPStateError; the callers in
	// this file pass nil for every other state.
	Error       error
	// Client is the live client; the callers in this file only set it
	// together with MCPStateConnected.
	Client      *client.Client
	// ToolCount is the number of tools the server exposed when it connected.
	ToolCount   int
	// ConnectedAt is when the state was last set to MCPStateConnected; it is
	// the zero time for every other state.
	ConnectedAt time.Time
}
 53
var (
	// mcpToolsOnce and mcpTools are not referenced in this section;
	// presumably they memoize a one-time tool discovery — verify at the use site.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP name to its registered client.
	mcpClients   = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to its most recently recorded state.
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes an MCPEvent on every call to updateMCPState.
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
 61
// McpTool exposes a single tool advertised by an MCP server as a
// tools.BaseTool. Instances are built by getTools from the server's
// ListTools response.
type McpTool struct {
	// mcpName is the MCP's key in cfg.MCP and in the client registry.
	mcpName     string
	// tool is the tool definition returned by the server.
	tool        mcp.Tool
	// permissions is consulted for user approval before every execution.
	permissions permission.Service
	cfg         *config.Config
}
 68
 69func (b *McpTool) Name() string {
 70	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 71}
 72
 73func (b *McpTool) Info() tools.ToolInfo {
 74	required := b.tool.InputSchema.Required
 75	if required == nil {
 76		required = make([]string, 0)
 77	}
 78	parameters := b.tool.InputSchema.Properties
 79	if parameters == nil {
 80		parameters = make(map[string]any)
 81	}
 82	return tools.ToolInfo{
 83		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 84		Description: b.tool.Description,
 85		Parameters:  parameters,
 86		Required:    required,
 87	}
 88}
 89
 90func runTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (tools.ToolResponse, error) {
 91	var args map[string]any
 92	if err := json.Unmarshal([]byte(input), &args); err != nil {
 93		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 94	}
 95
 96	c, err := getOrRenewClient(ctx, cfg, name)
 97	if err != nil {
 98		return tools.NewTextErrorResponse(err.Error()), nil
 99	}
100	result, err := c.CallTool(ctx, mcp.CallToolRequest{
101		Params: mcp.CallToolParams{
102			Name:      toolName,
103			Arguments: args,
104		},
105	})
106	if err != nil {
107		return tools.NewTextErrorResponse(err.Error()), nil
108	}
109
110	output := make([]string, 0, len(result.Content))
111	for _, v := range result.Content {
112		if v, ok := v.(mcp.TextContent); ok {
113			output = append(output, v.Text)
114		} else {
115			output = append(output, fmt.Sprintf("%v", v))
116		}
117	}
118	return tools.NewTextResponse(strings.Join(output, "\n")), nil
119}
120
121func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*client.Client, error) {
122	c, ok := mcpClients.Get(name)
123	if !ok {
124		return nil, fmt.Errorf("mcp '%s' not available", name)
125	}
126
127	m := cfg.MCP[name]
128	state, _ := mcpStates.Get(name)
129
130	timeout := mcpTimeout(m)
131	pingCtx, cancel := context.WithTimeout(ctx, timeout)
132	defer cancel()
133	err := c.Ping(pingCtx)
134	if err == nil {
135		return c, nil
136	}
137	updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
138
139	c, err = createAndInitializeClient(ctx, name, m)
140	if err != nil {
141		return nil, err
142	}
143
144	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
145	mcpClients.Set(name, c)
146	return c, nil
147}
148
149func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
150	sessionID, messageID := tools.GetContextValues(ctx)
151	if sessionID == "" || messageID == "" {
152		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
153	}
154	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
155	p := b.permissions.Request(
156		permission.CreatePermissionRequest{
157			SessionID:   sessionID,
158			ToolCallID:  params.ID,
159			Path:        b.cfg.WorkingDir(),
160			ToolName:    b.Info().Name,
161			Action:      "execute",
162			Description: permissionDescription,
163			Params:      params.Input,
164		},
165	)
166	if !p {
167		return tools.ToolResponse{}, permission.ErrorPermissionDenied
168	}
169
170	return runTool(ctx, b.cfg, b.mcpName, b.tool.Name, params.Input)
171}
172
173func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, cfg *config.Config) []tools.BaseTool {
174	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
175	if err != nil {
176		slog.Error("error listing tools", "error", err)
177		updateMCPState(name, MCPStateError, err, nil, 0)
178		c.Close()
179		mcpClients.Del(name)
180		return nil
181	}
182	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
183	for _, tool := range result.Tools {
184		mcpTools = append(mcpTools, &McpTool{
185			mcpName:     name,
186			tool:        tool,
187			permissions: permissions,
188			cfg:         cfg,
189		})
190	}
191	return mcpTools
192}
193
194// SubscribeMCPEvents returns a channel for MCP events
195func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
196	return mcpBroker.Subscribe(ctx)
197}
198
199// GetMCPStates returns the current state of all MCP clients
200func GetMCPStates() map[string]MCPClientInfo {
201	return maps.Collect(mcpStates.Seq2())
202}
203
204// GetMCPState returns the state of a specific MCP client
205func GetMCPState(name string) (MCPClientInfo, bool) {
206	return mcpStates.Get(name)
207}
208
209// updateMCPState updates the state of an MCP client and publishes an event
210func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
211	info := MCPClientInfo{
212		Name:      name,
213		State:     state,
214		Error:     err,
215		Client:    client,
216		ToolCount: toolCount,
217	}
218	if state == MCPStateConnected {
219		info.ConnectedAt = time.Now()
220	}
221	mcpStates.Set(name, info)
222
223	// Publish state change event
224	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
225		Type:      MCPEventStateChanged,
226		Name:      name,
227		State:     state,
228		Error:     err,
229		ToolCount: toolCount,
230	})
231}
232
233// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
234func CloseMCPClients() error {
235	var errs []error
236	for c := range mcpClients.Seq() {
237		if err := c.Close(); err != nil {
238			errs = append(errs, err)
239		}
240	}
241	mcpBroker.Shutdown()
242	return errors.Join(errs...)
243}
244
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server. It identifies the client as "Crush" at the
// current build version (version.Version) and requests protocol version
// mcp.LATEST_PROTOCOL_VERSION.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
254
255func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
256	var wg sync.WaitGroup
257	result := csync.NewSlice[tools.BaseTool]()
258
259	// Initialize states for all configured MCPs
260	for name, m := range cfg.MCP {
261		if m.Disabled {
262			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
263			slog.Debug("skipping disabled mcp", "name", name)
264			continue
265		}
266
267		// Set initial starting state
268		updateMCPState(name, MCPStateStarting, nil, nil, 0)
269
270		wg.Add(1)
271		go func(name string, m config.MCPConfig) {
272			defer func() {
273				wg.Done()
274				if r := recover(); r != nil {
275					var err error
276					switch v := r.(type) {
277					case error:
278						err = v
279					case string:
280						err = fmt.Errorf("panic: %s", v)
281					default:
282						err = fmt.Errorf("panic: %v", v)
283					}
284					updateMCPState(name, MCPStateError, err, nil, 0)
285					slog.Error("panic in mcp client initialization", "error", err, "name", name)
286				}
287			}()
288
289			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
290			defer cancel()
291			c, err := createAndInitializeClient(ctx, name, m)
292			if err != nil {
293				return
294			}
295			mcpClients.Set(name, c)
296
297			tools := getTools(ctx, name, permissions, c, cfg)
298			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
299			result.Append(tools...)
300		}(name, m)
301	}
302	wg.Wait()
303	return slices.Collect(result.Seq())
304}
305
306func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
307	c, err := createMcpClient(name, m)
308	if err != nil {
309		updateMCPState(name, MCPStateError, err, nil, 0)
310		slog.Error("error creating mcp client", "error", err, "name", name)
311		return nil, err
312	}
313
314	timeout := mcpTimeout(m)
315	initCtx, cancel := context.WithTimeout(ctx, timeout)
316	defer cancel()
317
318	// Only call Start() for non-stdio clients, as stdio clients auto-start
319	if m.Type != config.MCPStdio {
320		if err := c.Start(initCtx); err != nil {
321			updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
322			slog.Error("error starting mcp client", "error", err, "name", name)
323			_ = c.Close()
324			return nil, err
325		}
326	}
327	if _, err := c.Initialize(initCtx, mcpInitRequest); err != nil {
328		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
329		slog.Error("error initializing mcp client", "error", err, "name", name)
330		_ = c.Close()
331		return nil, err
332	}
333
334	slog.Info("Initialized mcp client", "name", name)
335	return c, nil
336}
337
// maybeTimeoutErr makes deadline failures readable. If err is (or wraps)
// context.DeadlineExceeded, it returns a new error stating the timeout that
// elapsed. Any other error, including nil, is returned unchanged.
func maybeTimeoutErr(err error, timeout time.Duration) error {
	if !errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	return fmt.Errorf("timed out after %s", timeout)
}
344
345func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
346	switch m.Type {
347	case config.MCPStdio:
348		if strings.TrimSpace(m.Command) == "" {
349			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
350		}
351		return client.NewStdioMCPClientWithOptions(
352			home.Long(m.Command),
353			m.ResolvedEnv(),
354			m.Args,
355			transport.WithCommandLogger(mcpLogger{name: name}),
356		)
357	case config.MCPHttp:
358		if strings.TrimSpace(m.URL) == "" {
359			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
360		}
361		return client.NewStreamableHttpClient(
362			m.URL,
363			transport.WithHTTPHeaders(m.ResolvedHeaders()),
364			transport.WithHTTPLogger(mcpLogger{name: name}),
365		)
366	case config.MCPSse:
367		if strings.TrimSpace(m.URL) == "" {
368			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
369		}
370		return client.NewSSEMCPClient(
371			m.URL,
372			client.WithHeaders(m.ResolvedHeaders()),
373			transport.WithSSELogger(mcpLogger{name: name}),
374		)
375	default:
376		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
377	}
378}
379
// mcpLogger adapts log/slog to the printf-style logger accepted by the mcp-go
// transport options. Every record is tagged with the MCP's name.
type mcpLogger struct {
	name string
}

// Errorf formats the message and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
390
391func mcpTimeout(m config.MCPConfig) time.Duration {
392	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
393}