mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/logging"
 11	"github.com/charmbracelet/crush/internal/permission"
 12	"github.com/charmbracelet/crush/internal/version"
 13
 14	"github.com/mark3labs/mcp-go/client"
 15	"github.com/mark3labs/mcp-go/client/transport"
 16	"github.com/mark3labs/mcp-go/mcp"
 17)
 18
 19type mcpTool struct {
 20	mcpName     string
 21	tool        mcp.Tool
 22	mcpConfig   config.MCPConfig
 23	permissions permission.Service
 24}
 25
 26type MCPClient interface {
 27	Initialize(
 28		ctx context.Context,
 29		request mcp.InitializeRequest,
 30	) (*mcp.InitializeResult, error)
 31	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 32	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 33	Close() error
 34}
 35
 36func (b *mcpTool) Name() string {
 37	return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
 38}
 39
 40func (b *mcpTool) Info() tools.ToolInfo {
 41	required := b.tool.InputSchema.Required
 42	if required == nil {
 43		required = make([]string, 0)
 44	}
 45	return tools.ToolInfo{
 46		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 47		Description: b.tool.Description,
 48		Parameters:  b.tool.InputSchema.Properties,
 49		Required:    required,
 50	}
 51}
 52
 53func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 54	defer c.Close()
 55	initRequest := mcp.InitializeRequest{}
 56	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 57	initRequest.Params.ClientInfo = mcp.Implementation{
 58		Name:    "Crush",
 59		Version: version.Version,
 60	}
 61
 62	_, err := c.Initialize(ctx, initRequest)
 63	if err != nil {
 64		return tools.NewTextErrorResponse(err.Error()), nil
 65	}
 66
 67	toolRequest := mcp.CallToolRequest{}
 68	toolRequest.Params.Name = toolName
 69	var args map[string]any
 70	if err = json.Unmarshal([]byte(input), &args); err != nil {
 71		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 72	}
 73	toolRequest.Params.Arguments = args
 74	result, err := c.CallTool(ctx, toolRequest)
 75	if err != nil {
 76		return tools.NewTextErrorResponse(err.Error()), nil
 77	}
 78
 79	output := ""
 80	for _, v := range result.Content {
 81		if v, ok := v.(mcp.TextContent); ok {
 82			output = v.Text
 83		} else {
 84			output = fmt.Sprintf("%v", v)
 85		}
 86	}
 87
 88	return tools.NewTextResponse(output), nil
 89}
 90
 91func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 92	sessionID, messageID := tools.GetContextValues(ctx)
 93	if sessionID == "" || messageID == "" {
 94		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 95	}
 96	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 97	p := b.permissions.Request(
 98		permission.CreatePermissionRequest{
 99			SessionID:   sessionID,
100			Path:        config.Get().WorkingDir(),
101			ToolName:    b.Info().Name,
102			Action:      "execute",
103			Description: permissionDescription,
104			Params:      params.Input,
105		},
106	)
107	if !p {
108		return tools.NewTextErrorResponse("permission denied"), nil
109	}
110
111	switch b.mcpConfig.Type {
112	case config.MCPStdio:
113		c, err := client.NewStdioMCPClient(
114			b.mcpConfig.Command,
115			b.mcpConfig.Env,
116			b.mcpConfig.Args...,
117		)
118		if err != nil {
119			return tools.NewTextErrorResponse(err.Error()), nil
120		}
121		return runTool(ctx, c, b.tool.Name, params.Input)
122	case config.MCPHttp:
123		c, err := client.NewStreamableHttpClient(
124			b.mcpConfig.URL,
125			transport.WithHTTPHeaders(b.mcpConfig.Headers),
126		)
127		if err != nil {
128			return tools.NewTextErrorResponse(err.Error()), nil
129		}
130		return runTool(ctx, c, b.tool.Name, params.Input)
131	case config.MCPSse:
132		c, err := client.NewSSEMCPClient(
133			b.mcpConfig.URL,
134			client.WithHeaders(b.mcpConfig.Headers),
135		)
136		if err != nil {
137			return tools.NewTextErrorResponse(err.Error()), nil
138		}
139		return runTool(ctx, c, b.tool.Name, params.Input)
140	}
141
142	return tools.NewTextErrorResponse("invalid mcp type"), nil
143}
144
145func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool {
146	return &mcpTool{
147		mcpName:     name,
148		tool:        tool,
149		mcpConfig:   mcpConfig,
150		permissions: permissions,
151	}
152}
153
154var mcpTools []tools.BaseTool
155
156func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
157	var stdioTools []tools.BaseTool
158	initRequest := mcp.InitializeRequest{}
159	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
160	initRequest.Params.ClientInfo = mcp.Implementation{
161		Name:    "Crush",
162		Version: version.Version,
163	}
164
165	_, err := c.Initialize(ctx, initRequest)
166	if err != nil {
167		logging.Error("error initializing mcp client", "error", err)
168		return stdioTools
169	}
170	toolsRequest := mcp.ListToolsRequest{}
171	tools, err := c.ListTools(ctx, toolsRequest)
172	if err != nil {
173		logging.Error("error listing tools", "error", err)
174		return stdioTools
175	}
176	for _, t := range tools.Tools {
177		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
178	}
179	defer c.Close()
180	return stdioTools
181}
182
183func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
184	if len(mcpTools) > 0 {
185		return mcpTools
186	}
187	for name, m := range config.Get().MCP {
188		switch m.Type {
189		case config.MCPStdio:
190			c, err := client.NewStdioMCPClient(
191				m.Command,
192				m.Env,
193				m.Args...,
194			)
195			if err != nil {
196				logging.Error("error creating mcp client", "error", err)
197				continue
198			}
199
200			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
201		case config.MCPHttp:
202			c, err := client.NewStreamableHttpClient(
203				m.URL,
204				transport.WithHTTPHeaders(m.Headers),
205			)
206			if err != nil {
207				logging.Error("error creating mcp client", "error", err)
208				continue
209			}
210			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
211		case config.MCPSse:
212			c, err := client.NewSSEMCPClient(
213				m.URL,
214				client.WithHeaders(m.Headers),
215			)
216			if err != nil {
217				logging.Error("error creating mcp client", "error", err)
218				continue
219			}
220			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
221		}
222	}
223
224	return mcpTools
225}