mcp-tools.go

  1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
 22
 23type McpTool struct {
 24	mcpName     string
 25	tool        mcp.Tool
 26	client      MCPClient
 27	mcpConfig   config.MCPConfig
 28	permissions permission.Service
 29	workingDir  string
 30}
 31
 32type MCPClient interface {
 33	Initialize(
 34		ctx context.Context,
 35		request mcp.InitializeRequest,
 36	) (*mcp.InitializeResult, error)
 37	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 38	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 39	Close() error
 40}
 41
 42func (b *McpTool) Name() string {
 43	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 44}
 45
 46func (b *McpTool) Info() tools.ToolInfo {
 47	required := b.tool.InputSchema.Required
 48	if required == nil {
 49		required = make([]string, 0)
 50	}
 51	return tools.ToolInfo{
 52		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 53		Description: b.tool.Description,
 54		Parameters:  b.tool.InputSchema.Properties,
 55		Required:    required,
 56	}
 57}
 58
 59func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 60	initRequest := mcp.InitializeRequest{}
 61	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 62	initRequest.Params.ClientInfo = mcp.Implementation{
 63		Name:    "Crush",
 64		Version: version.Version,
 65	}
 66
 67	_, err := c.Initialize(ctx, initRequest)
 68	if err != nil {
 69		return tools.NewTextErrorResponse(err.Error()), nil
 70	}
 71
 72	toolRequest := mcp.CallToolRequest{}
 73	toolRequest.Params.Name = toolName
 74	var args map[string]any
 75	if err = json.Unmarshal([]byte(input), &args); err != nil {
 76		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 77	}
 78	toolRequest.Params.Arguments = args
 79	result, err := c.CallTool(ctx, toolRequest)
 80	if err != nil {
 81		return tools.NewTextErrorResponse(err.Error()), nil
 82	}
 83
 84	output := ""
 85	for _, v := range result.Content {
 86		if v, ok := v.(mcp.TextContent); ok {
 87			output = v.Text
 88		} else {
 89			output = fmt.Sprintf("%v", v)
 90		}
 91	}
 92
 93	return tools.NewTextResponse(output), nil
 94}
 95
 96func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 97	sessionID, messageID := tools.GetContextValues(ctx)
 98	if sessionID == "" || messageID == "" {
 99		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
100	}
101	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
102	p := b.permissions.Request(
103		permission.CreatePermissionRequest{
104			SessionID:   sessionID,
105			ToolCallID:  params.ID,
106			Path:        b.workingDir,
107			ToolName:    b.Info().Name,
108			Action:      "execute",
109			Description: permissionDescription,
110			Params:      params.Input,
111		},
112	)
113	if !p {
114		return tools.ToolResponse{}, permission.ErrorPermissionDenied
115	}
116
117	return runTool(ctx, b.client, b.tool.Name, params.Input)
118}
119
120func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
121	return &McpTool{
122		mcpName:     name,
123		client:      c,
124		tool:        tool,
125		mcpConfig:   mcpConfig,
126		permissions: permissions,
127		workingDir:  workingDir,
128	}
129}
130
131func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
132	var stdioTools []tools.BaseTool
133	initRequest := mcp.InitializeRequest{}
134	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
135	initRequest.Params.ClientInfo = mcp.Implementation{
136		Name:    "Crush",
137		Version: version.Version,
138	}
139
140	_, err := c.Initialize(ctx, initRequest)
141	if err != nil {
142		slog.Error("error initializing mcp client", "error", err)
143		return stdioTools
144	}
145	toolsRequest := mcp.ListToolsRequest{}
146	tools, err := c.ListTools(ctx, toolsRequest)
147	if err != nil {
148		slog.Error("error listing tools", "error", err)
149		return stdioTools
150	}
151	for _, t := range tools.Tools {
152		stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir))
153	}
154	return stdioTools
155}
156
157var (
158	mcpToolsOnce sync.Once
159	mcpTools     []tools.BaseTool
160)
161
162func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
163	mcpToolsOnce.Do(func() {
164		mcpTools = doGetMCPTools(ctx, permissions, cfg)
165	})
166	return mcpTools
167}
168
169func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
170	var wg sync.WaitGroup
171	result := csync.NewSlice[tools.BaseTool]()
172	for name, m := range cfg.MCP {
173		if m.Disabled {
174			slog.Debug("skipping disabled mcp", "name", name)
175			continue
176		}
177		wg.Add(1)
178		go func(name string, m config.MCPConfig) {
179			defer wg.Done()
180			switch m.Type {
181			case config.MCPStdio:
182				c, err := client.NewStdioMCPClient(
183					m.Command,
184					m.ResolvedEnv(),
185					m.Args...,
186				)
187				if err != nil {
188					slog.Error("error creating mcp client", "error", err)
189					return
190				}
191
192				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
193			case config.MCPHttp:
194				c, err := client.NewStreamableHttpClient(
195					m.URL,
196					transport.WithHTTPHeaders(m.ResolvedHeaders()),
197				)
198				if err != nil {
199					slog.Error("error creating mcp client", "error", err)
200					return
201				}
202				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
203			case config.MCPSse:
204				c, err := client.NewSSEMCPClient(
205					m.URL,
206					client.WithHeaders(m.ResolvedHeaders()),
207				)
208				if err != nil {
209					slog.Error("error creating mcp client", "error", err)
210					return
211				}
212				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
213			}
214		}(name, m)
215	}
216	wg.Wait()
217	return slices.Collect(result.Seq())
218}