mcp-tools.go

  1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"
	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
 20
 21var (
 22	mcpToolsOnce sync.Once
 23	mcpTools     []tools.BaseTool
 24	mcpClients   = csync.NewMap[string, *client.Client]()
 25)
 26
 27type McpTool struct {
 28	mcpName     string
 29	tool        mcp.Tool
 30	permissions permission.Service
 31	workingDir  string
 32}
 33
 34func (b *McpTool) Name() string {
 35	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 36}
 37
 38func (b *McpTool) Info() tools.ToolInfo {
 39	required := b.tool.InputSchema.Required
 40	if required == nil {
 41		required = make([]string, 0)
 42	}
 43	return tools.ToolInfo{
 44		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 45		Description: b.tool.Description,
 46		Parameters:  b.tool.InputSchema.Properties,
 47		Required:    required,
 48	}
 49}
 50
 51func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
 52	var args map[string]any
 53	if err := json.Unmarshal([]byte(input), &args); err != nil {
 54		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 55	}
 56	c, ok := mcpClients.Get(name)
 57	if !ok {
 58		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
 59	}
 60	result, err := c.CallTool(ctx, mcp.CallToolRequest{
 61		Params: mcp.CallToolParams{
 62			Name:      toolName,
 63			Arguments: args,
 64		},
 65	})
 66	if err != nil {
 67		return tools.NewTextErrorResponse(err.Error()), nil
 68	}
 69
 70	output := ""
 71	for _, v := range result.Content {
 72		if v, ok := v.(mcp.TextContent); ok {
 73			output = v.Text
 74		} else {
 75			output = fmt.Sprintf("%v", v)
 76		}
 77	}
 78
 79	return tools.NewTextResponse(output), nil
 80}
 81
 82func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 83	sessionID, messageID := tools.GetContextValues(ctx)
 84	if sessionID == "" || messageID == "" {
 85		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 86	}
 87	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 88	p := b.permissions.Request(
 89		permission.CreatePermissionRequest{
 90			SessionID:   sessionID,
 91			ToolCallID:  params.ID,
 92			Path:        b.workingDir,
 93			ToolName:    b.Info().Name,
 94			Action:      "execute",
 95			Description: permissionDescription,
 96			Params:      params.Input,
 97		},
 98	)
 99	if !p {
100		return tools.ToolResponse{}, permission.ErrorPermissionDenied
101	}
102
103	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
104}
105
106func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
107	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
108	if err != nil {
109		slog.Error("error listing tools", "error", err)
110		c.Close()
111		mcpClients.Del(name)
112		return nil
113	}
114	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
115	for _, tool := range result.Tools {
116		mcpTools = append(mcpTools, &McpTool{
117			mcpName:     name,
118			tool:        tool,
119			permissions: permissions,
120			workingDir:  workingDir,
121		})
122	}
123	return mcpTools
124}
125
126// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
127func CloseMCPClients() {
128	for c := range mcpClients.Seq() {
129		_ = c.Close()
130	}
131}
132
133var mcpInitRequest = mcp.InitializeRequest{
134	Params: mcp.InitializeParams{
135		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
136		ClientInfo: mcp.Implementation{
137			Name:    "Crush",
138			Version: version.Version,
139		},
140	},
141}
142
143func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
144	var wg sync.WaitGroup
145	result := csync.NewSlice[tools.BaseTool]()
146	for name, m := range cfg.MCP {
147		if m.Disabled {
148			slog.Debug("skipping disabled mcp", "name", name)
149			continue
150		}
151		wg.Add(1)
152		go func(name string, m config.MCPConfig) {
153			defer wg.Done()
154			c, err := createMcpClient(m)
155			if err != nil {
156				slog.Error("error creating mcp client", "error", err, "name", name)
157				return
158			}
159			if err := c.Start(ctx); err != nil {
160				slog.Error("error starting mcp client", "error", err, "name", name)
161				_ = c.Close()
162				return
163			}
164			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
165				slog.Error("error initializing mcp client", "error", err, "name", name)
166				_ = c.Close()
167				return
168			}
169
170			slog.Info("Initialized mcp client", "name", name)
171			mcpClients.Set(name, c)
172
173			result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
174		}(name, m)
175	}
176	wg.Wait()
177	return slices.Collect(result.Seq())
178}
179
180func createMcpClient(m config.MCPConfig) (*client.Client, error) {
181	switch m.Type {
182	case config.MCPStdio:
183		return client.NewStdioMCPClient(
184			m.Command,
185			m.ResolvedEnv(),
186			m.Args...,
187		)
188	case config.MCPHttp:
189		return client.NewStreamableHttpClient(
190			m.URL,
191			transport.WithHTTPHeaders(m.ResolvedHeaders()),
192			transport.WithLogger(mcpHTTPLogger{}),
193		)
194	case config.MCPSse:
195		return client.NewSSEMCPClient(
196			m.URL,
197			client.WithHeaders(m.ResolvedHeaders()),
198		)
199	default:
200		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
201	}
202}
203
// mcpHTTPLogger adapts mcp-go's streamable HTTP transport logging (installed
// with transport.WithLogger in createMcpClient) to the default slog logger.
type mcpHTTPLogger struct{}

// Errorf formats its arguments like fmt.Sprintf and logs the result at
// error level.
func (mcpHTTPLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats its arguments like fmt.Sprintf and logs the result at
// info level.
func (mcpHTTPLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}