mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"sync"
 10
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/csync"
 13	"github.com/charmbracelet/crush/internal/llm/tools"
 14
 15	"github.com/charmbracelet/crush/internal/permission"
 16	"github.com/charmbracelet/crush/internal/version"
 17
 18	"github.com/mark3labs/mcp-go/client"
 19	"github.com/mark3labs/mcp-go/client/transport"
 20	"github.com/mark3labs/mcp-go/mcp"
 21)
 22
 23type mcpTool struct {
 24	mcpName     string
 25	tool        mcp.Tool
 26	mcpConfig   config.MCPConfig
 27	permissions permission.Service
 28	workingDir  string
 29}
 30
 31type MCPClient interface {
 32	Initialize(
 33		ctx context.Context,
 34		request mcp.InitializeRequest,
 35	) (*mcp.InitializeResult, error)
 36	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 37	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 38	Close() error
 39}
 40
 41func (b *mcpTool) Name() string {
 42	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 43}
 44
 45func (b *mcpTool) Info() tools.ToolInfo {
 46	required := b.tool.InputSchema.Required
 47	if required == nil {
 48		required = make([]string, 0)
 49	}
 50	return tools.ToolInfo{
 51		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 52		Description: b.tool.Description,
 53		Parameters:  b.tool.InputSchema.Properties,
 54		Required:    required,
 55	}
 56}
 57
 58func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 59	defer c.Close()
 60	initRequest := mcp.InitializeRequest{}
 61	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 62	initRequest.Params.ClientInfo = mcp.Implementation{
 63		Name:    "Crush",
 64		Version: version.Version,
 65	}
 66
 67	_, err := c.Initialize(ctx, initRequest)
 68	if err != nil {
 69		return tools.NewTextErrorResponse(err.Error()), nil
 70	}
 71
 72	toolRequest := mcp.CallToolRequest{}
 73	toolRequest.Params.Name = toolName
 74	var args map[string]any
 75	if err = json.Unmarshal([]byte(input), &args); err != nil {
 76		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 77	}
 78	toolRequest.Params.Arguments = args
 79	result, err := c.CallTool(ctx, toolRequest)
 80	if err != nil {
 81		return tools.NewTextErrorResponse(err.Error()), nil
 82	}
 83
 84	output := ""
 85	for _, v := range result.Content {
 86		if v, ok := v.(mcp.TextContent); ok {
 87			output = v.Text
 88		} else {
 89			output = fmt.Sprintf("%v", v)
 90		}
 91	}
 92
 93	return tools.NewTextResponse(output), nil
 94}
 95
 96func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 97	sessionID, messageID := tools.GetContextValues(ctx)
 98	if sessionID == "" || messageID == "" {
 99		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
100	}
101	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
102	p := b.permissions.Request(
103		permission.CreatePermissionRequest{
104			SessionID:   sessionID,
105			Path:        b.workingDir,
106			ToolName:    b.Info().Name,
107			Action:      "execute",
108			Description: permissionDescription,
109			Params:      params.Input,
110		},
111	)
112	if !p {
113		return tools.ToolResponse{}, permission.ErrorPermissionDenied
114	}
115
116	switch b.mcpConfig.Type {
117	case config.MCPStdio:
118		c, err := client.NewStdioMCPClient(
119			b.mcpConfig.Command,
120			b.mcpConfig.ResolvedEnv(),
121			b.mcpConfig.Args...,
122		)
123		if err != nil {
124			return tools.NewTextErrorResponse(err.Error()), nil
125		}
126		return runTool(ctx, c, b.tool.Name, params.Input)
127	case config.MCPHttp:
128		c, err := client.NewStreamableHttpClient(
129			b.mcpConfig.URL,
130			transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
131		)
132		if err != nil {
133			return tools.NewTextErrorResponse(err.Error()), nil
134		}
135		return runTool(ctx, c, b.tool.Name, params.Input)
136	case config.MCPSse:
137		c, err := client.NewSSEMCPClient(
138			b.mcpConfig.URL,
139			client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
140		)
141		if err != nil {
142			return tools.NewTextErrorResponse(err.Error()), nil
143		}
144		return runTool(ctx, c, b.tool.Name, params.Input)
145	}
146
147	return tools.NewTextErrorResponse("invalid mcp type"), nil
148}
149
150func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
151	return &mcpTool{
152		mcpName:     name,
153		tool:        tool,
154		mcpConfig:   mcpConfig,
155		permissions: permissions,
156		workingDir:  workingDir,
157	}
158}
159
160func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
161	var stdioTools []tools.BaseTool
162	initRequest := mcp.InitializeRequest{}
163	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
164	initRequest.Params.ClientInfo = mcp.Implementation{
165		Name:    "Crush",
166		Version: version.Version,
167	}
168
169	_, err := c.Initialize(ctx, initRequest)
170	if err != nil {
171		slog.Error("error initializing mcp client", "error", err)
172		return stdioTools
173	}
174	toolsRequest := mcp.ListToolsRequest{}
175	tools, err := c.ListTools(ctx, toolsRequest)
176	if err != nil {
177		slog.Error("error listing tools", "error", err)
178		return stdioTools
179	}
180	for _, t := range tools.Tools {
181		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
182	}
183	defer c.Close()
184	return stdioTools
185}
186
187var (
188	mcpToolsOnce sync.Once
189	mcpTools     []tools.BaseTool
190)
191
192func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
193	mcpToolsOnce.Do(func() {
194		mcpTools = doGetMCPTools(ctx, permissions, cfg)
195	})
196	return mcpTools
197}
198
199func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
200	var wg sync.WaitGroup
201	result := csync.NewSlice[tools.BaseTool]()
202	for name, m := range cfg.MCP {
203		if m.Disabled {
204			slog.Debug("skipping disabled mcp", "name", name)
205			continue
206		}
207		wg.Add(1)
208		go func(name string, m config.MCPConfig) {
209			defer wg.Done()
210			switch m.Type {
211			case config.MCPStdio:
212				c, err := client.NewStdioMCPClient(
213					m.Command,
214					m.ResolvedEnv(),
215					m.Args...,
216				)
217				if err != nil {
218					slog.Error("error creating mcp client", "error", err)
219					return
220				}
221
222				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
223			case config.MCPHttp:
224				c, err := client.NewStreamableHttpClient(
225					m.URL,
226					transport.WithHTTPHeaders(m.ResolvedHeaders()),
227				)
228				if err != nil {
229					slog.Error("error creating mcp client", "error", err)
230					return
231				}
232				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
233			case config.MCPSse:
234				c, err := client.NewSSEMCPClient(
235					m.URL,
236					client.WithHeaders(m.ResolvedHeaders()),
237				)
238				if err != nil {
239					slog.Error("error creating mcp client", "error", err)
240					return
241				}
242				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
243			}
244		}(name, m)
245	}
246	wg.Wait()
247	return slices.Collect(result.Seq())
248}