mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/permission"
 17	"github.com/charmbracelet/crush/internal/pubsub"
 18	"github.com/charmbracelet/crush/internal/version"
 19	"github.com/mark3labs/mcp-go/client"
 20	"github.com/mark3labs/mcp-go/client/transport"
 21	"github.com/mark3labs/mcp-go/mcp"
 22)
 23
 24// MCPState represents the current state of an MCP client
 25type MCPState int
 26
 27const (
 28	MCPStateDisabled MCPState = iota
 29	MCPStateStarting
 30	MCPStateConnected
 31	MCPStateError
 32)
 33
 34func (s MCPState) String() string {
 35	switch s {
 36	case MCPStateDisabled:
 37		return "disabled"
 38	case MCPStateStarting:
 39		return "starting"
 40	case MCPStateConnected:
 41		return "connected"
 42	case MCPStateError:
 43		return "error"
 44	default:
 45		return "unknown"
 46	}
 47}
 48
// MCPEventType represents the type of MCP event carried in MCPEvent.Type.
type MCPEventType string

const (
	// MCPEventStateChanged is the event type updateMCPState publishes each
	// time it records a client's state.
	MCPEventStateChanged MCPEventType = "state_changed"
)
 55
// MCPEvent represents an event in the MCP system. updateMCPState publishes one
// on every recorded state change; receive them via SubscribeMCPEvents.
type MCPEvent struct {
	// Type is the kind of event; the only publisher in this file always sets
	// MCPEventStateChanged.
	Type MCPEventType
	// Name is the configured name of the MCP server the event concerns.
	Name string
	// State is the state just recorded for that server.
	State MCPState
	// Error is the failure cause when State is MCPStateError; the callers in
	// this file pass nil for every other state.
	Error error
	// ToolCount is the number of tools the server exposes; the callers in this
	// file report 0 unless State is MCPStateConnected.
	ToolCount int
}
 64
// MCPClientInfo holds information about an MCP client's state, as stored by
// updateMCPState and returned by GetMCPStates and GetMCPState.
type MCPClientInfo struct {
	// Name is the MCP server's key in the configuration.
	Name string
	// State is the most recently recorded lifecycle state.
	State MCPState
	// Error is the failure cause when State is MCPStateError, otherwise nil.
	Error error
	// Client is the live client; the callers in this file set it only for
	// MCPStateConnected and pass nil otherwise.
	Client *client.Client
	// ToolCount is the number of tools listed from the server when connected.
	ToolCount int
	// ConnectedAt is set to the time the state was recorded, only when State
	// is MCPStateConnected; it is the zero time otherwise.
	ConnectedAt time.Time
}
 74
var (
	// mcpToolsOnce and mcpTools are not referenced in the code shown here;
	// presumably they cache a one-time doGetMCPTools result elsewhere in the
	// package. TODO(review): confirm.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps a configured MCP server name to its initialized client.
	// doGetMCPTools adds entries (from concurrent goroutines), getTools removes
	// one when listing tools fails, runTool looks clients up by name.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates holds the latest MCPClientInfo per server name, written by
	// updateMCPState.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker publishes MCPEvents (see updateMCPState) to the channels
	// returned by SubscribeMCPEvents; CloseMCPClients shuts it down.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
 82
// McpTool wraps one tool reported by an MCP server so it can be used as a
// tools.BaseTool (getTools returns *McpTool values in a []tools.BaseTool).
type McpTool struct {
	// mcpName is the configured server name; runTool uses it to look up the
	// client in mcpClients.
	mcpName string
	// tool is the tool definition returned by the server's ListTools.
	tool mcp.Tool
	// permissions is asked for approval before every call in Run.
	permissions permission.Service
	// workingDir is sent as the Path of the permission request.
	workingDir string
}
 89
 90func (b *McpTool) Name() string {
 91	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 92}
 93
 94func (b *McpTool) Info() tools.ToolInfo {
 95	required := b.tool.InputSchema.Required
 96	if required == nil {
 97		required = make([]string, 0)
 98	}
 99	return tools.ToolInfo{
100		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
101		Description: b.tool.Description,
102		Parameters:  b.tool.InputSchema.Properties,
103		Required:    required,
104	}
105}
106
107func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
108	var args map[string]any
109	if err := json.Unmarshal([]byte(input), &args); err != nil {
110		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
111	}
112	c, ok := mcpClients.Get(name)
113	if !ok {
114		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
115	}
116	result, err := c.CallTool(ctx, mcp.CallToolRequest{
117		Params: mcp.CallToolParams{
118			Name:      toolName,
119			Arguments: args,
120		},
121	})
122	if err != nil {
123		return tools.NewTextErrorResponse(err.Error()), nil
124	}
125
126	output := make([]string, 0, len(result.Content))
127	for _, v := range result.Content {
128		if v, ok := v.(mcp.TextContent); ok {
129			output = append(output, v.Text)
130		} else {
131			output = append(output, fmt.Sprintf("%v", v))
132		}
133	}
134	return tools.NewTextResponse(strings.Join(output, "\n")), nil
135}
136
137func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
138	sessionID, messageID := tools.GetContextValues(ctx)
139	if sessionID == "" || messageID == "" {
140		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
141	}
142	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
143	p := b.permissions.Request(
144		permission.CreatePermissionRequest{
145			SessionID:   sessionID,
146			ToolCallID:  params.ID,
147			Path:        b.workingDir,
148			ToolName:    b.Info().Name,
149			Action:      "execute",
150			Description: permissionDescription,
151			Params:      params.Input,
152		},
153	)
154	if !p {
155		return tools.ToolResponse{}, permission.ErrorPermissionDenied
156	}
157
158	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
159}
160
161func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
162	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
163	if err != nil {
164		slog.Error("error listing tools", "error", err)
165		updateMCPState(name, MCPStateError, err, nil, 0)
166		c.Close()
167		mcpClients.Del(name)
168		return nil
169	}
170	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
171	for _, tool := range result.Tools {
172		mcpTools = append(mcpTools, &McpTool{
173			mcpName:     name,
174			tool:        tool,
175			permissions: permissions,
176			workingDir:  workingDir,
177		})
178	}
179	return mcpTools
180}
181
// SubscribeMCPEvents subscribes to the MCP event broker and returns the
// channel on which MCPEvents, such as the state changes published by
// updateMCPState, are delivered. ctx is handed to the broker's Subscribe
// (presumably ending the subscription when cancelled; confirm in pubsub).
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	events := mcpBroker.Subscribe(ctx)
	return events
}
186
187// GetMCPStates returns the current state of all MCP clients
188func GetMCPStates() map[string]MCPClientInfo {
189	states := make(map[string]MCPClientInfo)
190	for name, info := range mcpStates.Seq2() {
191		states[name] = info
192	}
193	return states
194}
195
196// GetMCPState returns the state of a specific MCP client
197func GetMCPState(name string) (MCPClientInfo, bool) {
198	return mcpStates.Get(name)
199}
200
201// updateMCPState updates the state of an MCP client and publishes an event
202func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
203	info := MCPClientInfo{
204		Name:      name,
205		State:     state,
206		Error:     err,
207		Client:    client,
208		ToolCount: toolCount,
209	}
210	if state == MCPStateConnected {
211		info.ConnectedAt = time.Now()
212	}
213	mcpStates.Set(name, info)
214
215	// Publish state change event
216	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
217		Type:      MCPEventStateChanged,
218		Name:      name,
219		State:     state,
220		Error:     err,
221		ToolCount: toolCount,
222	})
223}
224
225// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
226func CloseMCPClients() {
227	for c := range mcpClients.Seq() {
228		_ = c.Close()
229	}
230	mcpBroker.Shutdown()
231}
232
// mcpInitRequest is the initialize request doGetMCPTools sends to every MCP
// server after creating (and, for non-stdio transports, starting) its client.
// It identifies the client as "Crush" at the build's version.Version and
// advertises mcp-go's LATEST_PROTOCOL_VERSION.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
242
// doGetMCPTools connects concurrently to every enabled MCP server in cfg and
// returns the tools they expose, wrapped as McpTool values.
//
// Disabled servers are recorded as MCPStateDisabled and skipped. Every other
// server is recorded as MCPStateStarting. It then becomes MCPStateConnected
// once its tools have been listed, or MCPStateError if creating, starting,
// initializing or listing fails (or panics). Failed servers contribute no
// tools. Start-up of each server (start, initialize, list tools) is bounded by
// a 10-second deadline. The call returns after every server has either
// connected or failed.
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
	var wg sync.WaitGroup
	result := csync.NewSlice[tools.BaseTool]()

	// Initialize states for all configured MCPs
	for name, m := range cfg.MCP {
		if m.Disabled {
			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
			slog.Debug("skipping disabled mcp", "name", name)
			continue
		}

		// Set initial starting state
		updateMCPState(name, MCPStateStarting, nil, nil, 0)

		wg.Add(1)
		go func(name string, m config.MCPConfig) {
			// Deferred calls run last-in, first-out: the recover handler below
			// runs before wg.Done, so a panic's error state is recorded before
			// wg.Wait can return.
			defer wg.Done()
			defer func() {
				r := recover()
				if r == nil {
					return
				}
				var err error
				switch v := r.(type) {
				case error:
					err = v
				case string:
					err = fmt.Errorf("panic: %s", v)
				default:
					err = fmt.Errorf("panic: %v", v)
				}
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("panic in mcp client initialization", "error", err, "name", name)
			}()

			// Start-up is bounded by a timer that cancels clientCtx rather than
			// by context.WithTimeout plus a deferred cancel. clientCtx is also
			// handed to Start, and a transport may keep using that context for
			// the whole connection (e.g. an SSE stream). It must therefore
			// outlive this goroutine once start-up succeeds, and it is
			// cancelled only on failure or when the deadline fires.
			clientCtx, cancel := context.WithCancel(ctx)
			deadline := time.AfterFunc(10*time.Second, cancel)
			// abort releases clientCtx on a failure path.
			abort := func() {
				deadline.Stop()
				cancel()
			}

			c, err := createMcpClient(m)
			if err != nil {
				abort()
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error creating mcp client", "error", err, "name", name)
				return
			}
			// Only call Start() for non-stdio clients, as stdio clients auto-start
			if m.Type != config.MCPStdio {
				if err := c.Start(clientCtx); err != nil {
					abort()
					updateMCPState(name, MCPStateError, err, nil, 0)
					slog.Error("error starting mcp client", "error", err, "name", name)
					_ = c.Close()
					return
				}
			}
			if _, err := c.Initialize(clientCtx, mcpInitRequest); err != nil {
				abort()
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("error initializing mcp client", "error", err, "name", name)
				_ = c.Close()
				return
			}

			slog.Info("Initialized mcp client", "name", name)
			mcpClients.Set(name, c)

			clientTools := getTools(clientCtx, name, permissions, c, cfg.WorkingDir())
			if clientTools == nil {
				// getTools returns nil only on failure. By then it has recorded
				// MCPStateError, closed c and removed it from mcpClients, so
				// the server must not be reported as connected.
				abort()
				return
			}
			if !deadline.Stop() {
				// The deadline fired just after the last start-up call returned,
				// so clientCtx is already cancelled; treat it like any other
				// time-out.
				err := context.DeadlineExceeded
				updateMCPState(name, MCPStateError, err, nil, 0)
				slog.Error("mcp client start-up timed out", "error", err, "name", name)
				_ = c.Close()
				mcpClients.Del(name)
				return
			}
			updateMCPState(name, MCPStateConnected, nil, c, len(clientTools))
			result.Append(clientTools...)
		}(name, m)
	}
	wg.Wait()
	return slices.Collect(result.Seq())
}
312
// createMcpClient constructs an mcp-go client for m without initializing it.
// The transport is chosen by m.Type:
//   - stdio: runs m.Command with m.Args and m.ResolvedEnv()
//   - http: a streamable HTTP client for m.URL
//   - sse: a server-sent-events client for m.URL
//
// The HTTP transports send m.ResolvedHeaders(), and every transport logs
// through mcpLogger. Any other type yields an "unsupported mcp type" error.
func createMcpClient(m config.MCPConfig) (*client.Client, error) {
	logger := mcpLogger{}
	switch m.Type {
	case config.MCPStdio:
		env := m.ResolvedEnv()
		return client.NewStdioMCPClientWithOptions(m.Command, env, m.Args, transport.WithCommandLogger(logger))
	case config.MCPHttp:
		headers := m.ResolvedHeaders()
		return client.NewStreamableHttpClient(m.URL, transport.WithHTTPHeaders(headers), transport.WithHTTPLogger(logger))
	case config.MCPSse:
		headers := m.ResolvedHeaders()
		return client.NewSSEMCPClient(m.URL, client.WithHeaders(headers), transport.WithSSELogger(logger))
	}
	return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
338
339// for MCP's clients.
340type mcpLogger struct{}
341
342func (l mcpLogger) Errorf(format string, v ...any) { slog.Error(fmt.Sprintf(format, v...)) }
343func (l mcpLogger) Infof(format string, v ...any)  { slog.Info(fmt.Sprintf(format, v...)) }