mcp-tools.go

  1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
 24
 25// MCPState represents the current state of an MCP client
 26type MCPState int
 27
 28const (
 29	MCPStateDisabled MCPState = iota
 30	MCPStateStarting
 31	MCPStateConnected
 32	MCPStateError
 33)
 34
 35func (s MCPState) String() string {
 36	switch s {
 37	case MCPStateDisabled:
 38		return "disabled"
 39	case MCPStateStarting:
 40		return "starting"
 41	case MCPStateConnected:
 42		return "connected"
 43	case MCPStateError:
 44		return "error"
 45	default:
 46		return "unknown"
 47	}
 48}
 49
// MCPEventType identifies the kind of change an MCPEvent reports.
type MCPEventType string

const (
	// MCPEventStateChanged is the type of every event published by
	// updateMCPState when an MCP client's state is recorded.
	MCPEventStateChanged MCPEventType = "state_changed"
)

// MCPEvent is the payload delivered to subscribers of the MCP event broker
// (see SubscribeMCPEvents). In this file every event is built by
// updateMCPState and mirrors the MCPClientInfo it has just stored.
type MCPEvent struct {
	Type      MCPEventType // currently always MCPEventStateChanged
	Name      string       // MCP server name as configured (key of cfg.MCP)
	State     MCPState     // state the server has just transitioned to
	Error     error        // failure cause accompanying the new state, if any
	ToolCount int          // number of tools the server exposes (0 unless connected)
}
 65
// MCPClientInfo is the last recorded status of one MCP server. Values are
// stored by updateMCPState and read back through GetMCPState and
// GetMCPStates.
type MCPClientInfo struct {
	Name        string         // MCP server name as configured
	State       MCPState       // current lifecycle state
	Error       error          // failure cause when State is MCPStateError
	Client      *client.Client // live client; callers in this file only set it for MCPStateConnected
	ToolCount   int            // number of tools listed from the server
	ConnectedAt time.Time      // stamped when State becomes MCPStateConnected; zero otherwise
}
 75
// Package-level MCP registry shared by the loader (doGetMCPTools), the
// McpTool wrappers and the state/event accessors in this file.
var (
	// mcpToolsOnce and mcpTools are not referenced in this part of the file;
	// presumably they cache the tool list built by doGetMCPTools — TODO confirm.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps MCP server name to its initialized client. Written by
	// doGetMCPTools, read by runTool and CloseMCPClients; getTools removes an
	// entry when listing that server's tools fails.
	mcpClients   = csync.NewMap[string, *client.Client]()
	// mcpStates holds the latest MCPClientInfo per server name, written by
	// updateMCPState from the concurrent loader goroutines.
	mcpStates    = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker fans MCPEvent state changes out to SubscribeMCPEvents
	// subscribers; CloseMCPClients shuts it down.
	mcpBroker    = pubsub.NewBroker[MCPEvent]()
)
 83
// McpTool exposes a single tool advertised by an MCP server as a
// tools.BaseTool; getTools wraps every listed tool in one of these.
type McpTool struct {
	mcpName     string             // configured MCP server name; key into mcpClients
	tool        mcp.Tool           // tool definition as returned by the server's ListTools
	permissions permission.Service // asked for approval before every Run
	workingDir  string             // reported as the Path of permission requests
}
 90
 91func (b *McpTool) Name() string {
 92	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 93}
 94
 95func (b *McpTool) Info() tools.ToolInfo {
 96	required := b.tool.InputSchema.Required
 97	if required == nil {
 98		required = make([]string, 0)
 99	}
100	return tools.ToolInfo{
101		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
102		Description: b.tool.Description,
103		Parameters:  b.tool.InputSchema.Properties,
104		Required:    required,
105	}
106}
107
108func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
109	var args map[string]any
110	if err := json.Unmarshal([]byte(input), &args); err != nil {
111		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
112	}
113	c, ok := mcpClients.Get(name)
114	if !ok {
115		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
116	}
117	result, err := c.CallTool(ctx, mcp.CallToolRequest{
118		Params: mcp.CallToolParams{
119			Name:      toolName,
120			Arguments: args,
121		},
122	})
123	if err != nil {
124		return tools.NewTextErrorResponse(err.Error()), nil
125	}
126
127	output := ""
128	for _, v := range result.Content {
129		if v, ok := v.(mcp.TextContent); ok {
130			output = v.Text
131		} else {
132			output = fmt.Sprintf("%v", v)
133		}
134	}
135
136	return tools.NewTextResponse(output), nil
137}
138
139func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
140	sessionID, messageID := tools.GetContextValues(ctx)
141	if sessionID == "" || messageID == "" {
142		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
143	}
144	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
145	p := b.permissions.Request(
146		permission.CreatePermissionRequest{
147			SessionID:   sessionID,
148			ToolCallID:  params.ID,
149			Path:        b.workingDir,
150			ToolName:    b.Info().Name,
151			Action:      "execute",
152			Description: permissionDescription,
153			Params:      params.Input,
154		},
155	)
156	if !p {
157		return tools.ToolResponse{}, permission.ErrorPermissionDenied
158	}
159
160	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
161}
162
163func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
164	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
165	if err != nil {
166		slog.Error("error listing tools", "error", err)
167		updateMCPState(name, MCPStateError, err, nil, 0)
168		c.Close()
169		mcpClients.Del(name)
170		return nil
171	}
172	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
173	for _, tool := range result.Tools {
174		mcpTools = append(mcpTools, &McpTool{
175			mcpName:     name,
176			tool:        tool,
177			permissions: permissions,
178			workingDir:  workingDir,
179		})
180	}
181	return mcpTools
182}
183
184// SubscribeMCPEvents returns a channel for MCP events
185func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
186	return mcpBroker.Subscribe(ctx)
187}
188
189// GetMCPStates returns the current state of all MCP clients
190func GetMCPStates() map[string]MCPClientInfo {
191	states := make(map[string]MCPClientInfo)
192	for name, info := range mcpStates.Seq2() {
193		states[name] = info
194	}
195	return states
196}
197
198// GetMCPState returns the state of a specific MCP client
199func GetMCPState(name string) (MCPClientInfo, bool) {
200	return mcpStates.Get(name)
201}
202
203// updateMCPState updates the state of an MCP client and publishes an event
204func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
205	info := MCPClientInfo{
206		Name:      name,
207		State:     state,
208		Error:     err,
209		Client:    client,
210		ToolCount: toolCount,
211	}
212	if state == MCPStateConnected {
213		info.ConnectedAt = time.Now()
214	}
215	mcpStates.Set(name, info)
216
217	// Publish state change event
218	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
219		Type:      MCPEventStateChanged,
220		Name:      name,
221		State:     state,
222		Error:     err,
223		ToolCount: toolCount,
224	})
225}
226
227// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
228func CloseMCPClients() {
229	for c := range mcpClients.Seq() {
230		_ = c.Close()
231	}
232	mcpBroker.Shutdown()
233}
234
// mcpInitRequest is the initialize request sent to every MCP server once its
// client has started. It advertises mcp-go's LATEST_PROTOCOL_VERSION and
// identifies this client as "Crush" at the running build's version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
244
245func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
246	var wg sync.WaitGroup
247	result := csync.NewSlice[tools.BaseTool]()
248
249	// Initialize states for all configured MCPs
250	for name, m := range cfg.MCP {
251		if m.Disabled {
252			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
253			slog.Debug("skipping disabled mcp", "name", name)
254			continue
255		}
256
257		// Set initial starting state
258		updateMCPState(name, MCPStateStarting, nil, nil, 0)
259
260		wg.Add(1)
261		go func(name string, m config.MCPConfig) {
262			defer func() {
263				wg.Done()
264				if r := recover(); r != nil {
265					var err error
266					switch v := r.(type) {
267					case error:
268						err = v
269					case string:
270						err = fmt.Errorf("panic: %s", v)
271					default:
272						err = fmt.Errorf("panic: %v", v)
273					}
274					updateMCPState(name, MCPStateError, err, nil, 0)
275					slog.Error("panic in mcp client initialization", "error", err, "name", name)
276				}
277			}()
278
279			c, err := createMcpClient(m)
280			if err != nil {
281				updateMCPState(name, MCPStateError, err, nil, 0)
282				slog.Error("error creating mcp client", "error", err, "name", name)
283				return
284			}
285			if err := c.Start(ctx); err != nil {
286				updateMCPState(name, MCPStateError, err, nil, 0)
287				slog.Error("error starting mcp client", "error", err, "name", name)
288				_ = c.Close()
289				return
290			}
291			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
292				updateMCPState(name, MCPStateError, err, nil, 0)
293				slog.Error("error initializing mcp client", "error", err, "name", name)
294				_ = c.Close()
295				return
296			}
297
298			slog.Info("Initialized mcp client", "name", name)
299			mcpClients.Set(name, c)
300
301			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
302			toolCount := len(tools)
303			updateMCPState(name, MCPStateConnected, nil, c, toolCount)
304			result.Append(tools...)
305		}(name, m)
306	}
307	wg.Wait()
308	return slices.Collect(result.Seq())
309}
310
311func createMcpClient(m config.MCPConfig) (*client.Client, error) {
312	switch m.Type {
313	case config.MCPStdio:
314		return client.NewStdioMCPClient(
315			m.Command,
316			m.ResolvedEnv(),
317			m.Args...,
318		)
319	case config.MCPHttp:
320		return client.NewStreamableHttpClient(
321			m.URL,
322			transport.WithHTTPHeaders(m.ResolvedHeaders()),
323		)
324	case config.MCPSse:
325		return client.NewSSEMCPClient(
326			m.URL,
327			client.WithHeaders(m.ResolvedHeaders()),
328		)
329	default:
330		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
331	}
332}