mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/logging"
 11	"github.com/charmbracelet/crush/internal/permission"
 12	"github.com/charmbracelet/crush/internal/version"
 13
 14	"github.com/mark3labs/mcp-go/client"
 15	"github.com/mark3labs/mcp-go/mcp"
 16)
 17
 18type mcpTool struct {
 19	mcpName     string
 20	tool        mcp.Tool
 21	mcpConfig   config.MCP
 22	permissions permission.Service
 23}
 24
 25type MCPClient interface {
 26	Initialize(
 27		ctx context.Context,
 28		request mcp.InitializeRequest,
 29	) (*mcp.InitializeResult, error)
 30	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 31	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 32	Close() error
 33}
 34
 35func (b *mcpTool) Name() string {
 36	return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
 37}
 38
 39func (b *mcpTool) Info() tools.ToolInfo {
 40	required := b.tool.InputSchema.Required
 41	if required == nil {
 42		required = make([]string, 0)
 43	}
 44	return tools.ToolInfo{
 45		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 46		Description: b.tool.Description,
 47		Parameters:  b.tool.InputSchema.Properties,
 48		Required:    required,
 49	}
 50}
 51
 52func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 53	defer c.Close()
 54	initRequest := mcp.InitializeRequest{}
 55	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 56	initRequest.Params.ClientInfo = mcp.Implementation{
 57		Name:    "Crush",
 58		Version: version.Version,
 59	}
 60
 61	_, err := c.Initialize(ctx, initRequest)
 62	if err != nil {
 63		return tools.NewTextErrorResponse(err.Error()), nil
 64	}
 65
 66	toolRequest := mcp.CallToolRequest{}
 67	toolRequest.Params.Name = toolName
 68	var args map[string]any
 69	if err = json.Unmarshal([]byte(input), &args); err != nil {
 70		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 71	}
 72	toolRequest.Params.Arguments = args
 73	result, err := c.CallTool(ctx, toolRequest)
 74	if err != nil {
 75		return tools.NewTextErrorResponse(err.Error()), nil
 76	}
 77
 78	output := ""
 79	for _, v := range result.Content {
 80		if v, ok := v.(mcp.TextContent); ok {
 81			output = v.Text
 82		} else {
 83			output = fmt.Sprintf("%v", v)
 84		}
 85	}
 86
 87	return tools.NewTextResponse(output), nil
 88}
 89
 90func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 91	sessionID, messageID := tools.GetContextValues(ctx)
 92	if sessionID == "" || messageID == "" {
 93		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 94	}
 95	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 96	p := b.permissions.Request(
 97		permission.CreatePermissionRequest{
 98			SessionID:   sessionID,
 99			Path:        config.WorkingDirectory(),
100			ToolName:    b.Info().Name,
101			Action:      "execute",
102			Description: permissionDescription,
103			Params:      params.Input,
104		},
105	)
106	if !p {
107		return tools.NewTextErrorResponse("permission denied"), nil
108	}
109
110	switch b.mcpConfig.Type {
111	case config.MCPStdio:
112		c, err := client.NewStdioMCPClient(
113			b.mcpConfig.Command,
114			b.mcpConfig.Env,
115			b.mcpConfig.Args...,
116		)
117		if err != nil {
118			return tools.NewTextErrorResponse(err.Error()), nil
119		}
120		return runTool(ctx, c, b.tool.Name, params.Input)
121	case config.MCPSse:
122		c, err := client.NewSSEMCPClient(
123			b.mcpConfig.URL,
124			client.WithHeaders(b.mcpConfig.Headers),
125		)
126		if err != nil {
127			return tools.NewTextErrorResponse(err.Error()), nil
128		}
129		return runTool(ctx, c, b.tool.Name, params.Input)
130	}
131
132	return tools.NewTextErrorResponse("invalid mcp type"), nil
133}
134
135func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
136	return &mcpTool{
137		mcpName:     name,
138		tool:        tool,
139		mcpConfig:   mcpConfig,
140		permissions: permissions,
141	}
142}
143
144var mcpTools []tools.BaseTool
145
146func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
147	var stdioTools []tools.BaseTool
148	initRequest := mcp.InitializeRequest{}
149	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
150	initRequest.Params.ClientInfo = mcp.Implementation{
151		Name:    "Crush",
152		Version: version.Version,
153	}
154
155	_, err := c.Initialize(ctx, initRequest)
156	if err != nil {
157		logging.Error("error initializing mcp client", "error", err)
158		return stdioTools
159	}
160	toolsRequest := mcp.ListToolsRequest{}
161	tools, err := c.ListTools(ctx, toolsRequest)
162	if err != nil {
163		logging.Error("error listing tools", "error", err)
164		return stdioTools
165	}
166	for _, t := range tools.Tools {
167		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
168	}
169	defer c.Close()
170	return stdioTools
171}
172
173func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
174	if len(mcpTools) > 0 {
175		return mcpTools
176	}
177	for name, m := range config.Get().MCP {
178		switch m.Type {
179		case config.MCPStdio:
180			c, err := client.NewStdioMCPClient(
181				m.Command,
182				m.Env,
183				m.Args...,
184			)
185			if err != nil {
186				logging.Error("error creating mcp client", "error", err)
187				continue
188			}
189
190			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
191		case config.MCPSse:
192			c, err := client.NewSSEMCPClient(
193				m.URL,
194				client.WithHeaders(m.Headers),
195			)
196			if err != nil {
197				logging.Error("error creating mcp client", "error", err)
198				continue
199			}
200			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
201		}
202	}
203
204	return mcpTools
205}