mcp-tools.go

  1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
 22
 23var (
 24	mcpToolsOnce sync.Once
 25	mcpTools     []tools.BaseTool
 26	mcpClients   = csync.NewMap[string, *client.Client]()
 27)
 28
 29type McpTool struct {
 30	mcpName     string
 31	tool        mcp.Tool
 32	permissions permission.Service
 33	workingDir  string
 34}
 35
 36func (b *McpTool) Name() string {
 37	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 38}
 39
 40func (b *McpTool) Info() tools.ToolInfo {
 41	required := b.tool.InputSchema.Required
 42	if required == nil {
 43		required = make([]string, 0)
 44	}
 45	return tools.ToolInfo{
 46		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 47		Description: b.tool.Description,
 48		Parameters:  b.tool.InputSchema.Properties,
 49		Required:    required,
 50	}
 51}
 52
 53func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
 54	var args map[string]any
 55	if err := json.Unmarshal([]byte(input), &args); err != nil {
 56		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 57	}
 58	c, ok := mcpClients.Get(name)
 59	if !ok {
 60		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
 61	}
 62	result, err := c.CallTool(ctx, mcp.CallToolRequest{
 63		Params: mcp.CallToolParams{
 64			Name:      toolName,
 65			Arguments: args,
 66		},
 67	})
 68	if err != nil {
 69		return tools.NewTextErrorResponse(err.Error()), nil
 70	}
 71
 72	output := ""
 73	for _, v := range result.Content {
 74		if v, ok := v.(mcp.TextContent); ok {
 75			output = v.Text
 76		} else {
 77			output = fmt.Sprintf("%v", v)
 78		}
 79	}
 80
 81	return tools.NewTextResponse(output), nil
 82}
 83
 84func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 85	sessionID, messageID := tools.GetContextValues(ctx)
 86	if sessionID == "" || messageID == "" {
 87		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 88	}
 89	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 90	p := b.permissions.Request(
 91		permission.CreatePermissionRequest{
 92			SessionID:   sessionID,
 93			ToolCallID:  params.ID,
 94			Path:        b.workingDir,
 95			ToolName:    b.Info().Name,
 96			Action:      "execute",
 97			Description: permissionDescription,
 98			Params:      params.Input,
 99		},
100	)
101	if !p {
102		return tools.ToolResponse{}, permission.ErrorPermissionDenied
103	}
104
105	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
106}
107
108func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
109	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
110	if err != nil {
111		slog.Error("error listing tools", "error", err)
112		c.Close()
113		mcpClients.Del(name)
114		return nil
115	}
116	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
117	for _, tool := range result.Tools {
118		mcpTools = append(mcpTools, &McpTool{
119			mcpName:     name,
120			tool:        tool,
121			permissions: permissions,
122			workingDir:  workingDir,
123		})
124	}
125	return mcpTools
126}
127
128// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
129func CloseMCPClients() {
130	for c := range mcpClients.Seq() {
131		_ = c.Close()
132	}
133}
134
135var mcpInitRequest = mcp.InitializeRequest{
136	Params: mcp.InitializeParams{
137		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
138		ClientInfo: mcp.Implementation{
139			Name:    "Crush",
140			Version: version.Version,
141		},
142	},
143}
144
145func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
146	var wg sync.WaitGroup
147	result := csync.NewSlice[tools.BaseTool]()
148	for name, m := range cfg.MCP {
149		if m.Disabled {
150			slog.Debug("skipping disabled mcp", "name", name)
151			continue
152		}
153		wg.Add(1)
154		go func(name string, m config.MCPConfig) {
155			defer wg.Done()
156			c, err := createMcpClient(m)
157			if err != nil {
158				slog.Error("error creating mcp client", "error", err, "name", name)
159				return
160			}
161			if err := c.Start(ctx); err != nil {
162				slog.Error("error starting mcp client", "error", err, "name", name)
163				_ = c.Close()
164				return
165			}
166			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
167				slog.Error("error initializing mcp client", "error", err, "name", name)
168				_ = c.Close()
169				return
170			}
171
172			slog.Info("Initialized mcp client", "name", name)
173			mcpClients.Set(name, c)
174
175			result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
176		}(name, m)
177	}
178	wg.Wait()
179	return slices.Collect(result.Seq())
180}
181
182func createMcpClient(m config.MCPConfig) (*client.Client, error) {
183	switch m.Type {
184	case config.MCPStdio:
185		return client.NewStdioMCPClient(
186			m.Command,
187			m.ResolvedEnv(),
188			m.Args...,
189		)
190	case config.MCPHttp:
191		return client.NewStreamableHttpClient(
192			m.URL,
193			transport.WithHTTPHeaders(m.ResolvedHeaders()),
194		)
195	case config.MCPSse:
196		return client.NewSSEMCPClient(
197			m.URL,
198			client.WithHeaders(m.ResolvedHeaders()),
199		)
200	default:
201		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
202	}
203}