@@ -11,34 +11,28 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/llm/tools"
+ "github.com/charmbracelet/crush/internal/version"
"github.com/charmbracelet/crush/internal/permission"
- "github.com/charmbracelet/crush/internal/version"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)
+var (
+ mcpToolsOnce sync.Once
+ mcpTools []tools.BaseTool
+ mcpClients = csync.NewMap[string, *client.Client]()
+)
+
type McpTool struct {
mcpName string
tool mcp.Tool
- client MCPClient
- mcpConfig config.MCPConfig
permissions permission.Service
workingDir string
}
-type MCPClient interface {
- Initialize(
- ctx context.Context,
- request mcp.InitializeRequest,
- ) (*mcp.InitializeResult, error)
- ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
- CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
- Close() error
-}
-
func (b *McpTool) Name() string {
return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
}
@@ -56,27 +50,21 @@ func (b *McpTool) Info() tools.ToolInfo {
}
}
-func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
- initRequest := mcp.InitializeRequest{}
- initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
- initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "Crush",
- Version: version.Version,
- }
-
- _, err := c.Initialize(ctx, initRequest)
- if err != nil {
- return tools.NewTextErrorResponse(err.Error()), nil
- }
-
- toolRequest := mcp.CallToolRequest{}
- toolRequest.Params.Name = toolName
+func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
var args map[string]any
- if err = json.Unmarshal([]byte(input), &args); err != nil {
+ if err := json.Unmarshal([]byte(input), &args); err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
- toolRequest.Params.Arguments = args
- result, err := c.CallTool(ctx, toolRequest)
+ c, ok := mcpClients.Get(name)
+ if !ok {
+ return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
+ }
+ result, err := c.CallTool(ctx, mcp.CallToolRequest{
+ Params: mcp.CallToolParams{
+ Name: toolName,
+ Arguments: args,
+ },
+ })
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
@@ -114,56 +102,34 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, permission.ErrorPermissionDenied
}
- return runTool(ctx, b.client, b.tool.Name, params.Input)
+ return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
-func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
- return &McpTool{
- mcpName: name,
- client: c,
- tool: tool,
- mcpConfig: mcpConfig,
- permissions: permissions,
- workingDir: workingDir,
- }
-}
-
-func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
- var stdioTools []tools.BaseTool
- initRequest := mcp.InitializeRequest{}
- initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
- initRequest.Params.ClientInfo = mcp.Implementation{
- Name: "Crush",
- Version: version.Version,
- }
-
- _, err := c.Initialize(ctx, initRequest)
- if err != nil {
- slog.Error("error initializing mcp client", "error", err)
- return stdioTools
- }
- toolsRequest := mcp.ListToolsRequest{}
- tools, err := c.ListTools(ctx, toolsRequest)
+func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
+ result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
slog.Error("error listing tools", "error", err)
- return stdioTools
+ c.Close()
+ mcpClients.Del(name)
+ return nil
}
- for _, t := range tools.Tools {
- stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir))
+ mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
+ for _, tool := range result.Tools {
+ mcpTools = append(mcpTools, &McpTool{
+ mcpName: name,
+ tool: tool,
+ permissions: permissions,
+ workingDir: workingDir,
+ })
}
- return stdioTools
+ return mcpTools
}
-var (
- mcpToolsOnce sync.Once
- mcpTools []tools.BaseTool
-)
-
-func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
- mcpToolsOnce.Do(func() {
- mcpTools = doGetMCPTools(ctx, permissions, cfg)
- })
- return mcpTools
+// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
+func CloseMCPClients() {
+ for c := range mcpClients.Seq() {
+ _ = c.Close()
+ }
}
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
@@ -177,42 +143,59 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
wg.Add(1)
go func(name string, m config.MCPConfig) {
defer wg.Done()
- switch m.Type {
- case config.MCPStdio:
- c, err := client.NewStdioMCPClient(
- m.Command,
- m.ResolvedEnv(),
- m.Args...,
- )
- if err != nil {
- slog.Error("error creating mcp client", "error", err)
- return
- }
-
- result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
- case config.MCPHttp:
- c, err := client.NewStreamableHttpClient(
- m.URL,
- transport.WithHTTPHeaders(m.ResolvedHeaders()),
- )
- if err != nil {
- slog.Error("error creating mcp client", "error", err)
- return
- }
- result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
- case config.MCPSse:
- c, err := client.NewSSEMCPClient(
- m.URL,
- client.WithHeaders(m.ResolvedHeaders()),
- )
- if err != nil {
- slog.Error("error creating mcp client", "error", err)
- return
- }
- result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
+ c, err := doGetClient(m)
+ if err != nil {
+ slog.Error("error creating mcp client", "error", err)
+ return
+ }
+ if err := doInitClient(ctx, name, c); err != nil {
+ slog.Error("error initializing mcp client", "error", err)
+ return
}
+ result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
}(name, m)
}
wg.Wait()
return slices.Collect(result.Seq())
}
+
+func doInitClient(ctx context.Context, name string, c *client.Client) error {
+ initRequest := mcp.InitializeRequest{
+ Params: mcp.InitializeParams{
+ ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
+ ClientInfo: mcp.Implementation{
+ Name: "Crush",
+ Version: version.Version,
+ },
+ },
+ }
+ if _, err := c.Initialize(ctx, initRequest); err != nil {
+ c.Close()
+ return err
+ }
+ mcpClients.Set(name, c)
+ return nil
+}
+
+func doGetClient(m config.MCPConfig) (*client.Client, error) {
+ switch m.Type {
+ case config.MCPStdio:
+ return client.NewStdioMCPClient(
+ m.Command,
+ m.ResolvedEnv(),
+ m.Args...,
+ )
+ case config.MCPHttp:
+ return client.NewStreamableHttpClient(
+ m.URL,
+ transport.WithHTTPHeaders(m.ResolvedHeaders()),
+ )
+ case config.MCPSse:
+ return client.NewSSEMCPClient(
+ m.URL,
+ client.WithHeaders(m.ResolvedHeaders()),
+ )
+ default:
+ return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
+ }
+}