mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7
  8	"github.com/charmbracelet/crush/internal/config"
  9	"github.com/charmbracelet/crush/internal/llm/tools"
 10	"github.com/charmbracelet/crush/internal/logging"
 11	"github.com/charmbracelet/crush/internal/permission"
 12	"github.com/charmbracelet/crush/internal/version"
 13
 14	"github.com/mark3labs/mcp-go/client"
 15	"github.com/mark3labs/mcp-go/mcp"
 16)
 17
 18type mcpTool struct {
 19	mcpName     string
 20	tool        mcp.Tool
 21	mcpConfig   config.MCPServer
 22	permissions permission.Service
 23}
 24
 25type MCPClient interface {
 26	Initialize(
 27		ctx context.Context,
 28		request mcp.InitializeRequest,
 29	) (*mcp.InitializeResult, error)
 30	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 31	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 32	Close() error
 33}
 34
 35func (b *mcpTool) Info() tools.ToolInfo {
 36	return tools.ToolInfo{
 37		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 38		Description: b.tool.Description,
 39		Parameters:  b.tool.InputSchema.Properties,
 40		Required:    b.tool.InputSchema.Required,
 41	}
 42}
 43
 44func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 45	defer c.Close()
 46	initRequest := mcp.InitializeRequest{}
 47	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 48	initRequest.Params.ClientInfo = mcp.Implementation{
 49		Name:    "Crush",
 50		Version: version.Version,
 51	}
 52
 53	_, err := c.Initialize(ctx, initRequest)
 54	if err != nil {
 55		return tools.NewTextErrorResponse(err.Error()), nil
 56	}
 57
 58	toolRequest := mcp.CallToolRequest{}
 59	toolRequest.Params.Name = toolName
 60	var args map[string]any
 61	if err = json.Unmarshal([]byte(input), &args); err != nil {
 62		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 63	}
 64	toolRequest.Params.Arguments = args
 65	result, err := c.CallTool(ctx, toolRequest)
 66	if err != nil {
 67		return tools.NewTextErrorResponse(err.Error()), nil
 68	}
 69
 70	output := ""
 71	for _, v := range result.Content {
 72		if v, ok := v.(mcp.TextContent); ok {
 73			output = v.Text
 74		} else {
 75			output = fmt.Sprintf("%v", v)
 76		}
 77	}
 78
 79	return tools.NewTextResponse(output), nil
 80}
 81
 82func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 83	sessionID, messageID := tools.GetContextValues(ctx)
 84	if sessionID == "" || messageID == "" {
 85		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 86	}
 87	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 88	p := b.permissions.Request(
 89		permission.CreatePermissionRequest{
 90			SessionID:   sessionID,
 91			Path:        config.WorkingDirectory(),
 92			ToolName:    b.Info().Name,
 93			Action:      "execute",
 94			Description: permissionDescription,
 95			Params:      params.Input,
 96		},
 97	)
 98	if !p {
 99		return tools.NewTextErrorResponse("permission denied"), nil
100	}
101
102	switch b.mcpConfig.Type {
103	case config.MCPStdio:
104		c, err := client.NewStdioMCPClient(
105			b.mcpConfig.Command,
106			b.mcpConfig.Env,
107			b.mcpConfig.Args...,
108		)
109		if err != nil {
110			return tools.NewTextErrorResponse(err.Error()), nil
111		}
112		return runTool(ctx, c, b.tool.Name, params.Input)
113	case config.MCPSse:
114		c, err := client.NewSSEMCPClient(
115			b.mcpConfig.URL,
116			client.WithHeaders(b.mcpConfig.Headers),
117		)
118		if err != nil {
119			return tools.NewTextErrorResponse(err.Error()), nil
120		}
121		return runTool(ctx, c, b.tool.Name, params.Input)
122	}
123
124	return tools.NewTextErrorResponse("invalid mcp type"), nil
125}
126
127func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPServer) tools.BaseTool {
128	return &mcpTool{
129		mcpName:     name,
130		tool:        tool,
131		mcpConfig:   mcpConfig,
132		permissions: permissions,
133	}
134}
135
136var mcpTools []tools.BaseTool
137
138func getTools(ctx context.Context, name string, m config.MCPServer, permissions permission.Service, c MCPClient) []tools.BaseTool {
139	var stdioTools []tools.BaseTool
140	initRequest := mcp.InitializeRequest{}
141	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
142	initRequest.Params.ClientInfo = mcp.Implementation{
143		Name:    "Crush",
144		Version: version.Version,
145	}
146
147	_, err := c.Initialize(ctx, initRequest)
148	if err != nil {
149		logging.Error("error initializing mcp client", "error", err)
150		return stdioTools
151	}
152	toolsRequest := mcp.ListToolsRequest{}
153	tools, err := c.ListTools(ctx, toolsRequest)
154	if err != nil {
155		logging.Error("error listing tools", "error", err)
156		return stdioTools
157	}
158	for _, t := range tools.Tools {
159		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
160	}
161	defer c.Close()
162	return stdioTools
163}
164
165func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
166	if len(mcpTools) > 0 {
167		return mcpTools
168	}
169	for name, m := range config.Get().MCPServers {
170		switch m.Type {
171		case config.MCPStdio:
172			c, err := client.NewStdioMCPClient(
173				m.Command,
174				m.Env,
175				m.Args...,
176			)
177			if err != nil {
178				logging.Error("error creating mcp client", "error", err)
179				continue
180			}
181
182			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
183		case config.MCPSse:
184			c, err := client.NewSSEMCPClient(
185				m.URL,
186				client.WithHeaders(m.Headers),
187			)
188			if err != nil {
189				logging.Error("error creating mcp client", "error", err)
190				continue
191			}
192			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
193		}
194	}
195
196	return mcpTools
197}