mcp-tools.go

  1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"sync"
  9
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/llm/tools"
 12
 13	"github.com/charmbracelet/crush/internal/permission"
 14	"github.com/charmbracelet/crush/internal/version"
 15
 16	"github.com/mark3labs/mcp-go/client"
 17	"github.com/mark3labs/mcp-go/client/transport"
 18	"github.com/mark3labs/mcp-go/mcp"
 19)
 20
 21type mcpTool struct {
 22	mcpName     string
 23	tool        mcp.Tool
 24	mcpConfig   config.MCPConfig
 25	permissions permission.Service
 26	workingDir  string
 27}
 28
 29type MCPClient interface {
 30	Initialize(
 31		ctx context.Context,
 32		request mcp.InitializeRequest,
 33	) (*mcp.InitializeResult, error)
 34	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 35	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 36	Close() error
 37}
 38
 39func (b *mcpTool) Name() string {
 40	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 41}
 42
 43func (b *mcpTool) Info() tools.ToolInfo {
 44	required := b.tool.InputSchema.Required
 45	if required == nil {
 46		required = make([]string, 0)
 47	}
 48	return tools.ToolInfo{
 49		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 50		Description: b.tool.Description,
 51		Parameters:  b.tool.InputSchema.Properties,
 52		Required:    required,
 53	}
 54}
 55
 56func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 57	defer c.Close()
 58	initRequest := mcp.InitializeRequest{}
 59	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 60	initRequest.Params.ClientInfo = mcp.Implementation{
 61		Name:    "Crush",
 62		Version: version.Version,
 63	}
 64
 65	_, err := c.Initialize(ctx, initRequest)
 66	if err != nil {
 67		return tools.NewTextErrorResponse(err.Error()), nil
 68	}
 69
 70	toolRequest := mcp.CallToolRequest{}
 71	toolRequest.Params.Name = toolName
 72	var args map[string]any
 73	if err = json.Unmarshal([]byte(input), &args); err != nil {
 74		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 75	}
 76	toolRequest.Params.Arguments = args
 77	result, err := c.CallTool(ctx, toolRequest)
 78	if err != nil {
 79		return tools.NewTextErrorResponse(err.Error()), nil
 80	}
 81
 82	output := ""
 83	for _, v := range result.Content {
 84		if v, ok := v.(mcp.TextContent); ok {
 85			output = v.Text
 86		} else {
 87			output = fmt.Sprintf("%v", v)
 88		}
 89	}
 90
 91	return tools.NewTextResponse(output), nil
 92}
 93
 94func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 95	sessionID, messageID := tools.GetContextValues(ctx)
 96	if sessionID == "" || messageID == "" {
 97		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 98	}
 99	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
100	p := b.permissions.Request(
101		permission.CreatePermissionRequest{
102			SessionID:   sessionID,
103			ToolCallID:  params.ID,
104			Path:        b.workingDir,
105			ToolName:    b.Info().Name,
106			Action:      "execute",
107			Description: permissionDescription,
108			Params:      params.Input,
109		},
110	)
111	if !p {
112		return tools.ToolResponse{}, permission.ErrorPermissionDenied
113	}
114
115	switch b.mcpConfig.Type {
116	case config.MCPStdio:
117		c, err := client.NewStdioMCPClient(
118			b.mcpConfig.Command,
119			b.mcpConfig.ResolvedEnv(),
120			b.mcpConfig.Args...,
121		)
122		if err != nil {
123			return tools.NewTextErrorResponse(err.Error()), nil
124		}
125		return runTool(ctx, c, b.tool.Name, params.Input)
126	case config.MCPHttp:
127		c, err := client.NewStreamableHttpClient(
128			b.mcpConfig.URL,
129			transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
130		)
131		if err != nil {
132			return tools.NewTextErrorResponse(err.Error()), nil
133		}
134		return runTool(ctx, c, b.tool.Name, params.Input)
135	case config.MCPSse:
136		c, err := client.NewSSEMCPClient(
137			b.mcpConfig.URL,
138			client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
139		)
140		if err != nil {
141			return tools.NewTextErrorResponse(err.Error()), nil
142		}
143		return runTool(ctx, c, b.tool.Name, params.Input)
144	}
145
146	return tools.NewTextErrorResponse("invalid mcp type"), nil
147}
148
149func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
150	return &mcpTool{
151		mcpName:     name,
152		tool:        tool,
153		mcpConfig:   mcpConfig,
154		permissions: permissions,
155		workingDir:  workingDir,
156	}
157}
158
159func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
160	var stdioTools []tools.BaseTool
161	initRequest := mcp.InitializeRequest{}
162	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
163	initRequest.Params.ClientInfo = mcp.Implementation{
164		Name:    "Crush",
165		Version: version.Version,
166	}
167
168	_, err := c.Initialize(ctx, initRequest)
169	if err != nil {
170		slog.Error("error initializing mcp client", "error", err)
171		return stdioTools
172	}
173	toolsRequest := mcp.ListToolsRequest{}
174	tools, err := c.ListTools(ctx, toolsRequest)
175	if err != nil {
176		slog.Error("error listing tools", "error", err)
177		return stdioTools
178	}
179	for _, t := range tools.Tools {
180		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
181	}
182	defer c.Close()
183	return stdioTools
184}
185
186var (
187	mcpToolsOnce sync.Once
188	mcpTools     []tools.BaseTool
189)
190
191func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
192	mcpToolsOnce.Do(func() {
193		mcpTools = doGetMCPTools(ctx, permissions, cfg)
194	})
195	return mcpTools
196}
197
198func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
199	var mu sync.Mutex
200	var wg sync.WaitGroup
201	var result []tools.BaseTool
202	for name, m := range cfg.MCP {
203		if m.Disabled {
204			slog.Debug("skipping disabled mcp", "name", name)
205			continue
206		}
207		wg.Add(1)
208		go func(name string, m config.MCPConfig) {
209			defer wg.Done()
210			switch m.Type {
211			case config.MCPStdio:
212				c, err := client.NewStdioMCPClient(
213					m.Command,
214					m.ResolvedEnv(),
215					m.Args...,
216				)
217				if err != nil {
218					slog.Error("error creating mcp client", "error", err)
219					return
220				}
221
222				mu.Lock()
223				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
224				mu.Unlock()
225			case config.MCPHttp:
226				c, err := client.NewStreamableHttpClient(
227					m.URL,
228					transport.WithHTTPHeaders(m.ResolvedHeaders()),
229				)
230				if err != nil {
231					slog.Error("error creating mcp client", "error", err)
232					return
233				}
234				mu.Lock()
235				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
236				mu.Unlock()
237			case config.MCPSse:
238				c, err := client.NewSSEMCPClient(
239					m.URL,
240					client.WithHeaders(m.ResolvedHeaders()),
241				)
242				if err != nil {
243					slog.Error("error creating mcp client", "error", err)
244					return
245				}
246				mu.Lock()
247				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
248				mu.Unlock()
249			}
250		}(name, m)
251	}
252	wg.Wait()
253	return result
254}