mcp-tools.go

  1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"maps"
 11	"slices"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/home"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/version"
 23	"github.com/mark3labs/mcp-go/client"
 24	"github.com/mark3labs/mcp-go/client/transport"
 25	"github.com/mark3labs/mcp-go/mcp"
 26)
 27
 28// MCPState represents the current state of an MCP client
 29type MCPState int
 30
 31const (
 32	MCPStateDisabled MCPState = iota
 33	MCPStateStarting
 34	MCPStateConnected
 35	MCPStateError
 36)
 37
 38func (s MCPState) String() string {
 39	switch s {
 40	case MCPStateDisabled:
 41		return "disabled"
 42	case MCPStateStarting:
 43		return "starting"
 44	case MCPStateConnected:
 45		return "connected"
 46	case MCPStateError:
 47		return "error"
 48	default:
 49		return "unknown"
 50	}
 51}
 52
// MCPEventType identifies the kind of MCPEvent published on the MCP broker.
type MCPEventType string

const (
	// MCPEventStateChanged is the event type published by updateMCPState
	// every time a client's state is recorded. It is the only event type
	// defined in this file.
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MCPEvent is the payload delivered to SubscribeMCPEvents subscribers when
// an MCP client's state is updated.
type MCPEvent struct {
	Type      MCPEventType // kind of event
	Name      string       // configured name of the MCP server
	State     MCPState     // state the client was moved to
	Error     error        // error recorded with the state, if any
	ToolCount int          // number of tools recorded for the client
}

// MCPClientInfo is the last recorded snapshot of an MCP client, as stored by
// updateMCPState and returned by GetMCPStates / GetMCPState.
type MCPClientInfo struct {
	Name        string         // configured name of the MCP server
	State       MCPState       // current lifecycle state
	Error       error          // error recorded with the state; nil if none was given
	Client      *client.Client // client recorded with the state; nil if none was given
	ToolCount   int            // number of tools recorded for the client
	ConnectedAt time.Time      // set to time.Now() only when State is MCPStateConnected; zero otherwise
}
 78
// Package-level MCP bookkeeping.
var (
	// mcpToolsOnce and mcpTools are not referenced in this file; presumably
	// they memoize the result of doGetMCPTools elsewhere in the package —
	// TODO confirm.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP server name to its active client.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP server name to its last recorded state.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker fans out MCPEvent notifications to SubscribeMCPEvents callers.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)

// McpTool adapts a single tool exposed by an MCP server to the
// tools.BaseTool interface used by the agent.
type McpTool struct {
	mcpName     string             // configured name of the MCP server that owns the tool
	tool        mcp.Tool           // tool definition as returned by the server's ListTools
	permissions permission.Service // asked for approval before every invocation
	workingDir  string             // reported as the Path of permission requests
}
 93
 94func (b *McpTool) Name() string {
 95	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 96}
 97
 98func (b *McpTool) Info() tools.ToolInfo {
 99	required := b.tool.InputSchema.Required
100	if required == nil {
101		required = make([]string, 0)
102	}
103	parameters := b.tool.InputSchema.Properties
104	if parameters == nil {
105		parameters = make(map[string]any)
106	}
107	return tools.ToolInfo{
108		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
109		Description: b.tool.Description,
110		Parameters:  parameters,
111		Required:    required,
112	}
113}
114
115func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
116	var args map[string]any
117	if err := json.Unmarshal([]byte(input), &args); err != nil {
118		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
119	}
120
121	c, err := getOrRenewClient(ctx, name)
122	if err != nil {
123		return tools.NewTextErrorResponse(err.Error()), nil
124	}
125	result, err := c.CallTool(ctx, mcp.CallToolRequest{
126		Params: mcp.CallToolParams{
127			Name:      toolName,
128			Arguments: args,
129		},
130	})
131	if err != nil {
132		return tools.NewTextErrorResponse(err.Error()), nil
133	}
134
135	output := make([]string, 0, len(result.Content))
136	for _, v := range result.Content {
137		if v, ok := v.(mcp.TextContent); ok {
138			output = append(output, v.Text)
139		} else {
140			output = append(output, fmt.Sprintf("%v", v))
141		}
142	}
143	return tools.NewTextResponse(strings.Join(output, "\n")), nil
144}
145
146func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
147	c, ok := mcpClients.Get(name)
148	if !ok {
149		return nil, fmt.Errorf("mcp '%s' not available", name)
150	}
151
152	m := config.Get().MCP[name]
153	state, _ := mcpStates.Get(name)
154
155	pingCtx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
156	defer cancel()
157	err := c.Ping(pingCtx)
158	if err == nil {
159		return c, nil
160	}
161	updateMCPState(name, MCPStateError, err, nil, state.ToolCount)
162
163	c, err = createAndInitializeClient(ctx, name, m)
164	if err != nil {
165		return nil, err
166	}
167
168	updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
169	mcpClients.Set(name, c)
170	return c, nil
171}
172
173func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
174	sessionID, messageID := tools.GetContextValues(ctx)
175	if sessionID == "" || messageID == "" {
176		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
177	}
178	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
179	p := b.permissions.Request(
180		permission.CreatePermissionRequest{
181			SessionID:   sessionID,
182			ToolCallID:  params.ID,
183			Path:        b.workingDir,
184			ToolName:    b.Info().Name,
185			Action:      "execute",
186			Description: permissionDescription,
187			Params:      params.Input,
188		},
189	)
190	if !p {
191		return tools.ToolResponse{}, permission.ErrorPermissionDenied
192	}
193
194	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
195}
196
197func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
198	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
199	if err != nil {
200		slog.Error("error listing tools", "error", err)
201		updateMCPState(name, MCPStateError, err, nil, 0)
202		c.Close()
203		mcpClients.Del(name)
204		return nil
205	}
206	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
207	for _, tool := range result.Tools {
208		mcpTools = append(mcpTools, &McpTool{
209			mcpName:     name,
210			tool:        tool,
211			permissions: permissions,
212			workingDir:  workingDir,
213		})
214	}
215	return mcpTools
216}
217
218// SubscribeMCPEvents returns a channel for MCP events
219func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
220	return mcpBroker.Subscribe(ctx)
221}
222
223// GetMCPStates returns the current state of all MCP clients
224func GetMCPStates() map[string]MCPClientInfo {
225	return maps.Collect(mcpStates.Seq2())
226}
227
228// GetMCPState returns the state of a specific MCP client
229func GetMCPState(name string) (MCPClientInfo, bool) {
230	return mcpStates.Get(name)
231}
232
233// updateMCPState updates the state of an MCP client and publishes an event
234func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
235	info := MCPClientInfo{
236		Name:      name,
237		State:     state,
238		Error:     err,
239		Client:    client,
240		ToolCount: toolCount,
241	}
242	if state == MCPStateConnected {
243		info.ConnectedAt = time.Now()
244	}
245	mcpStates.Set(name, info)
246
247	// Publish state change event
248	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
249		Type:      MCPEventStateChanged,
250		Name:      name,
251		State:     state,
252		Error:     err,
253		ToolCount: toolCount,
254	})
255}
256
257// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
258func CloseMCPClients() error {
259	var errs []error
260	for c := range mcpClients.Seq() {
261		if err := c.Close(); err != nil {
262			errs = append(errs, err)
263		}
264	}
265	mcpBroker.Shutdown()
266	return errors.Join(errs...)
267}
268
// mcpInitRequest is the initialize request that createAndInitializeClient
// sends to every MCP server after creating (and, if needed, starting) its
// client. It advertises mcp.LATEST_PROTOCOL_VERSION. It identifies this
// client as "Crush" at version.Version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
278
279func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
280	var wg sync.WaitGroup
281	result := csync.NewSlice[tools.BaseTool]()
282
283	// Initialize states for all configured MCPs
284	for name, m := range cfg.MCP {
285		if m.Disabled {
286			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
287			slog.Debug("skipping disabled mcp", "name", name)
288			continue
289		}
290
291		// Set initial starting state
292		updateMCPState(name, MCPStateStarting, nil, nil, 0)
293
294		wg.Add(1)
295		go func(name string, m config.MCPConfig) {
296			defer func() {
297				wg.Done()
298				if r := recover(); r != nil {
299					var err error
300					switch v := r.(type) {
301					case error:
302						err = v
303					case string:
304						err = fmt.Errorf("panic: %s", v)
305					default:
306						err = fmt.Errorf("panic: %v", v)
307					}
308					updateMCPState(name, MCPStateError, err, nil, 0)
309					slog.Error("panic in mcp client initialization", "error", err, "name", name)
310				}
311			}()
312
313			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
314			defer cancel()
315			c, err := createAndInitializeClient(ctx, name, m)
316			if err != nil {
317				return
318			}
319			mcpClients.Set(name, c)
320
321			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
322			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
323			result.Append(tools...)
324		}(name, m)
325	}
326	wg.Wait()
327	return slices.Collect(result.Seq())
328}
329
330func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
331	c, err := createMcpClient(name, m)
332	if err != nil {
333		updateMCPState(name, MCPStateError, err, nil, 0)
334		slog.Error("error creating mcp client", "error", err, "name", name)
335		return nil, err
336	}
337	// Only call Start() for non-stdio clients, as stdio clients auto-start
338	if m.Type != config.MCPStdio {
339		if err := c.Start(ctx); err != nil {
340			updateMCPState(name, MCPStateError, err, nil, 0)
341			slog.Error("error starting mcp client", "error", err, "name", name)
342			_ = c.Close()
343			return nil, err
344		}
345	}
346	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
347		updateMCPState(name, MCPStateError, err, nil, 0)
348		slog.Error("error initializing mcp client", "error", err, "name", name)
349		_ = c.Close()
350		return nil, err
351	}
352
353	slog.Info("Initialized mcp client", "name", name)
354	return c, nil
355}
356
357func createMcpClient(name string, m config.MCPConfig) (*client.Client, error) {
358	switch m.Type {
359	case config.MCPStdio:
360		if strings.TrimSpace(m.Command) == "" {
361			return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
362		}
363		return client.NewStdioMCPClientWithOptions(
364			home.Long(m.Command),
365			m.ResolvedEnv(),
366			m.Args,
367			transport.WithCommandLogger(mcpLogger{name: name}),
368		)
369	case config.MCPHttp:
370		if strings.TrimSpace(m.URL) == "" {
371			return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
372		}
373		return client.NewStreamableHttpClient(
374			m.URL,
375			transport.WithHTTPHeaders(m.ResolvedHeaders()),
376			transport.WithHTTPLogger(mcpLogger{name: name}),
377		)
378	case config.MCPSse:
379		if strings.TrimSpace(m.URL) == "" {
380			return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
381		}
382		return client.NewSSEMCPClient(
383			m.URL,
384			client.WithHeaders(m.ResolvedHeaders()),
385			transport.WithSSELogger(mcpLogger{name: name}),
386		)
387	default:
388		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
389	}
390}
391
// mcpLogger adapts slog to the Errorf/Infof logger used by the mcp-go
// transports (see createMcpClient). Every record is tagged with the MCP
// server's name.
type mcpLogger struct{ name string }

// Errorf formats the message with fmt.Sprintf and logs it at error level.
func (l mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg, "name", l.name)
}

// Infof formats the message with fmt.Sprintf and logs it at info level.
func (l mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg, "name", l.name)
}
402
403func mcpTimeout(m config.MCPConfig) time.Duration {
404	return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
405}