mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11
 12	"github.com/charmbracelet/crush/internal/permission"
 13	"github.com/charmbracelet/crush/internal/version"
 14
 15	"github.com/mark3labs/mcp-go/client"
 16	"github.com/mark3labs/mcp-go/client/transport"
 17	"github.com/mark3labs/mcp-go/mcp"
 18)
 19
// mcpTool wraps a single tool advertised by an MCP server so it can be used
// as a tools.BaseTool (see NewMcpTool).
type mcpTool struct {
	// mcpName is the configured name of the MCP server; it prefixes the
	// tool name exposed by Name and Info.
	mcpName string
	// tool is the tool definition as reported by the server's tool listing.
	tool mcp.Tool
	// mcpConfig holds the connection settings used to build a client each
	// time the tool is run.
	mcpConfig config.MCPConfig
	// permissions is asked for approval before every execution.
	permissions permission.Service
	// workingDir is sent as the Path of the permission request.
	workingDir string
}
 27
// MCPClient is the set of MCP client operations used by this file:
// performing the initialize handshake, listing tools, calling a tool, and
// closing the connection. The stdio, streamable HTTP, and SSE clients
// created below are all passed where an MCPClient is expected.
type MCPClient interface {
	// Initialize performs the protocol handshake with the server.
	Initialize(
		ctx context.Context,
		request mcp.InitializeRequest,
	) (*mcp.InitializeResult, error)
	// ListTools returns the tools the server advertises.
	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
	// CallTool invokes a tool on the server with the given arguments.
	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
	// Close releases the client.
	Close() error
}
 37
 38func (b *mcpTool) Name() string {
 39	return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
 40}
 41
 42func (b *mcpTool) Info() tools.ToolInfo {
 43	required := b.tool.InputSchema.Required
 44	if required == nil {
 45		required = make([]string, 0)
 46	}
 47	return tools.ToolInfo{
 48		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 49		Description: b.tool.Description,
 50		Parameters:  b.tool.InputSchema.Properties,
 51		Required:    required,
 52	}
 53}
 54
 55func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 56	defer c.Close()
 57	initRequest := mcp.InitializeRequest{}
 58	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 59	initRequest.Params.ClientInfo = mcp.Implementation{
 60		Name:    "Crush",
 61		Version: version.Version,
 62	}
 63
 64	_, err := c.Initialize(ctx, initRequest)
 65	if err != nil {
 66		return tools.NewTextErrorResponse(err.Error()), nil
 67	}
 68
 69	toolRequest := mcp.CallToolRequest{}
 70	toolRequest.Params.Name = toolName
 71	var args map[string]any
 72	if err = json.Unmarshal([]byte(input), &args); err != nil {
 73		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 74	}
 75	toolRequest.Params.Arguments = args
 76	result, err := c.CallTool(ctx, toolRequest)
 77	if err != nil {
 78		return tools.NewTextErrorResponse(err.Error()), nil
 79	}
 80
 81	output := ""
 82	for _, v := range result.Content {
 83		if v, ok := v.(mcp.TextContent); ok {
 84			output = v.Text
 85		} else {
 86			output = fmt.Sprintf("%v", v)
 87		}
 88	}
 89
 90	return tools.NewTextResponse(output), nil
 91}
 92
 93func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 94	sessionID, messageID := tools.GetContextValues(ctx)
 95	if sessionID == "" || messageID == "" {
 96		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 97	}
 98	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 99	p := b.permissions.Request(
100		permission.CreatePermissionRequest{
101			SessionID:   sessionID,
102			Path:        b.workingDir,
103			ToolName:    b.Info().Name,
104			Action:      "execute",
105			Description: permissionDescription,
106			Params:      params.Input,
107		},
108	)
109	if !p {
110		return tools.NewTextErrorResponse("permission denied"), nil
111	}
112
113	switch b.mcpConfig.Type {
114	case config.MCPStdio:
115		c, err := client.NewStdioMCPClient(
116			b.mcpConfig.Command,
117			b.mcpConfig.Env,
118			b.mcpConfig.Args...,
119		)
120		if err != nil {
121			return tools.NewTextErrorResponse(err.Error()), nil
122		}
123		return runTool(ctx, c, b.tool.Name, params.Input)
124	case config.MCPHttp:
125		c, err := client.NewStreamableHttpClient(
126			b.mcpConfig.URL,
127			transport.WithHTTPHeaders(b.mcpConfig.Headers),
128		)
129		if err != nil {
130			return tools.NewTextErrorResponse(err.Error()), nil
131		}
132		return runTool(ctx, c, b.tool.Name, params.Input)
133	case config.MCPSse:
134		c, err := client.NewSSEMCPClient(
135			b.mcpConfig.URL,
136			client.WithHeaders(b.mcpConfig.Headers),
137		)
138		if err != nil {
139			return tools.NewTextErrorResponse(err.Error()), nil
140		}
141		return runTool(ctx, c, b.tool.Name, params.Input)
142	}
143
144	return tools.NewTextErrorResponse("invalid mcp type"), nil
145}
146
147func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
148	return &mcpTool{
149		mcpName:     name,
150		tool:        tool,
151		mcpConfig:   mcpConfig,
152		permissions: permissions,
153		workingDir:  workingDir,
154	}
155}
156
// mcpTools caches the tools collected by GetMcpTools. Once it is non-empty,
// GetMcpTools returns it without contacting the servers again. Nothing in
// this file synchronizes access to it.
var mcpTools []tools.BaseTool
158
159func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
160	var stdioTools []tools.BaseTool
161	initRequest := mcp.InitializeRequest{}
162	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
163	initRequest.Params.ClientInfo = mcp.Implementation{
164		Name:    "Crush",
165		Version: version.Version,
166	}
167
168	_, err := c.Initialize(ctx, initRequest)
169	if err != nil {
170		slog.Error("error initializing mcp client", "error", err)
171		return stdioTools
172	}
173	toolsRequest := mcp.ListToolsRequest{}
174	tools, err := c.ListTools(ctx, toolsRequest)
175	if err != nil {
176		slog.Error("error listing tools", "error", err)
177		return stdioTools
178	}
179	for _, t := range tools.Tools {
180		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
181	}
182	defer c.Close()
183	return stdioTools
184}
185
186func GetMcpTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
187	if len(mcpTools) > 0 {
188		return mcpTools
189	}
190	for name, m := range cfg.MCP {
191		switch m.Type {
192		case config.MCPStdio:
193			c, err := client.NewStdioMCPClient(
194				m.Command,
195				m.Env,
196				m.Args...,
197			)
198			if err != nil {
199				slog.Error("error creating mcp client", "error", err)
200				continue
201			}
202
203			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
204		case config.MCPHttp:
205			c, err := client.NewStreamableHttpClient(
206				m.URL,
207				transport.WithHTTPHeaders(m.Headers),
208			)
209			if err != nil {
210				slog.Error("error creating mcp client", "error", err)
211				continue
212			}
213			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
214		case config.MCPSse:
215			c, err := client.NewSSEMCPClient(
216				m.URL,
217				client.WithHeaders(m.Headers),
218			)
219			if err != nil {
220				slog.Error("error creating mcp client", "error", err)
221				continue
222			}
223			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
224		}
225	}
226
227	return mcpTools
228}