mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/llm/tools"
 11
 12	"github.com/charmbracelet/crush/internal/permission"
 13	"github.com/charmbracelet/crush/internal/version"
 14
 15	"github.com/mark3labs/mcp-go/client"
 16	"github.com/mark3labs/mcp-go/client/transport"
 17	"github.com/mark3labs/mcp-go/mcp"
 18)
 19
// mcpTool wraps one tool advertised by an MCP server so that it can be
// handed out as a tools.BaseTool (see NewMcpTool).
type mcpTool struct {
	// mcpName is the configured server name; Name and Info prefix the
	// server-side tool name with it.
	mcpName string
	// tool is the definition obtained from the server's tool listing.
	tool mcp.Tool
	// mcpConfig selects the transport (stdio, HTTP, SSE) and carries its
	// connection settings; Run creates a new client from it on each call.
	mcpConfig config.MCPConfig
	// permissions is consulted by Run before the tool is executed.
	permissions permission.Service
}
 26
 27type MCPClient interface {
 28	Initialize(
 29		ctx context.Context,
 30		request mcp.InitializeRequest,
 31	) (*mcp.InitializeResult, error)
 32	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 33	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 34	Close() error
 35}
 36
 37func (b *mcpTool) Name() string {
 38	return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
 39}
 40
 41func (b *mcpTool) Info() tools.ToolInfo {
 42	required := b.tool.InputSchema.Required
 43	if required == nil {
 44		required = make([]string, 0)
 45	}
 46	return tools.ToolInfo{
 47		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 48		Description: b.tool.Description,
 49		Parameters:  b.tool.InputSchema.Properties,
 50		Required:    required,
 51	}
 52}
 53
 54func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 55	defer c.Close()
 56	initRequest := mcp.InitializeRequest{}
 57	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 58	initRequest.Params.ClientInfo = mcp.Implementation{
 59		Name:    "Crush",
 60		Version: version.Version,
 61	}
 62
 63	_, err := c.Initialize(ctx, initRequest)
 64	if err != nil {
 65		return tools.NewTextErrorResponse(err.Error()), nil
 66	}
 67
 68	toolRequest := mcp.CallToolRequest{}
 69	toolRequest.Params.Name = toolName
 70	var args map[string]any
 71	if err = json.Unmarshal([]byte(input), &args); err != nil {
 72		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 73	}
 74	toolRequest.Params.Arguments = args
 75	result, err := c.CallTool(ctx, toolRequest)
 76	if err != nil {
 77		return tools.NewTextErrorResponse(err.Error()), nil
 78	}
 79
 80	output := ""
 81	for _, v := range result.Content {
 82		if v, ok := v.(mcp.TextContent); ok {
 83			output = v.Text
 84		} else {
 85			output = fmt.Sprintf("%v", v)
 86		}
 87	}
 88
 89	return tools.NewTextResponse(output), nil
 90}
 91
 92func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 93	sessionID, messageID := tools.GetContextValues(ctx)
 94	if sessionID == "" || messageID == "" {
 95		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 96	}
 97	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 98	p := b.permissions.Request(
 99		permission.CreatePermissionRequest{
100			SessionID:   sessionID,
101			Path:        config.Get().WorkingDir(),
102			ToolName:    b.Info().Name,
103			Action:      "execute",
104			Description: permissionDescription,
105			Params:      params.Input,
106		},
107	)
108	if !p {
109		return tools.NewTextErrorResponse("permission denied"), nil
110	}
111
112	switch b.mcpConfig.Type {
113	case config.MCPStdio:
114		c, err := client.NewStdioMCPClient(
115			b.mcpConfig.Command,
116			b.mcpConfig.Env,
117			b.mcpConfig.Args...,
118		)
119		if err != nil {
120			return tools.NewTextErrorResponse(err.Error()), nil
121		}
122		return runTool(ctx, c, b.tool.Name, params.Input)
123	case config.MCPHttp:
124		c, err := client.NewStreamableHttpClient(
125			b.mcpConfig.URL,
126			transport.WithHTTPHeaders(b.mcpConfig.Headers),
127		)
128		if err != nil {
129			return tools.NewTextErrorResponse(err.Error()), nil
130		}
131		return runTool(ctx, c, b.tool.Name, params.Input)
132	case config.MCPSse:
133		c, err := client.NewSSEMCPClient(
134			b.mcpConfig.URL,
135			client.WithHeaders(b.mcpConfig.Headers),
136		)
137		if err != nil {
138			return tools.NewTextErrorResponse(err.Error()), nil
139		}
140		return runTool(ctx, c, b.tool.Name, params.Input)
141	}
142
143	return tools.NewTextErrorResponse("invalid mcp type"), nil
144}
145
146func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig) tools.BaseTool {
147	return &mcpTool{
148		mcpName:     name,
149		tool:        tool,
150		mcpConfig:   mcpConfig,
151		permissions: permissions,
152	}
153}
154
// mcpTools holds the tools collected by GetMcpTools. Once it is non-empty,
// GetMcpTools returns it as-is instead of querying the configured servers
// again. Nothing in this file synchronizes access to it.
var mcpTools []tools.BaseTool
156
157func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
158	var stdioTools []tools.BaseTool
159	initRequest := mcp.InitializeRequest{}
160	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
161	initRequest.Params.ClientInfo = mcp.Implementation{
162		Name:    "Crush",
163		Version: version.Version,
164	}
165
166	_, err := c.Initialize(ctx, initRequest)
167	if err != nil {
168		slog.Error("error initializing mcp client", "error", err)
169		return stdioTools
170	}
171	toolsRequest := mcp.ListToolsRequest{}
172	tools, err := c.ListTools(ctx, toolsRequest)
173	if err != nil {
174		slog.Error("error listing tools", "error", err)
175		return stdioTools
176	}
177	for _, t := range tools.Tools {
178		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
179	}
180	defer c.Close()
181	return stdioTools
182}
183
184func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
185	if len(mcpTools) > 0 {
186		return mcpTools
187	}
188	for name, m := range config.Get().MCP {
189		switch m.Type {
190		case config.MCPStdio:
191			c, err := client.NewStdioMCPClient(
192				m.Command,
193				m.Env,
194				m.Args...,
195			)
196			if err != nil {
197				slog.Error("error creating mcp client", "error", err)
198				continue
199			}
200
201			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
202		case config.MCPHttp:
203			c, err := client.NewStreamableHttpClient(
204				m.URL,
205				transport.WithHTTPHeaders(m.Headers),
206			)
207			if err != nil {
208				slog.Error("error creating mcp client", "error", err)
209				continue
210			}
211			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
212		case config.MCPSse:
213			c, err := client.NewSSEMCPClient(
214				m.URL,
215				client.WithHeaders(m.Headers),
216			)
217			if err != nil {
218				slog.Error("error creating mcp client", "error", err)
219				continue
220			}
221			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
222		}
223	}
224
225	return mcpTools
226}