1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
 22
// mcpTool adapts a single tool advertised by an MCP server to the
// tools.BaseTool interface used by the agent.
type mcpTool struct {
	// mcpName is the key of the MCP server in the config; it prefixes the
	// tool name exposed to the model.
	mcpName string
	// tool is the tool definition as returned by the server's ListTools.
	tool mcp.Tool
	// mcpConfig holds the connection settings used to reach the server each
	// time the tool is run.
	mcpConfig config.MCPConfig
	// permissions is consulted before every execution.
	permissions permission.Service
	// workingDir is reported as the Path of permission requests.
	workingDir string
}
 30
 31type MCPClient interface {
 32	Initialize(
 33		ctx context.Context,
 34		request mcp.InitializeRequest,
 35	) (*mcp.InitializeResult, error)
 36	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 37	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 38	Close() error
 39}
 40
 41func (b *mcpTool) Name() string {
 42	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 43}
 44
 45func (b *mcpTool) Info() tools.ToolInfo {
 46	required := b.tool.InputSchema.Required
 47	if required == nil {
 48		required = make([]string, 0)
 49	}
 50	return tools.ToolInfo{
 51		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 52		Description: b.tool.Description,
 53		Parameters:  b.tool.InputSchema.Properties,
 54		Required:    required,
 55	}
 56}
 57
 58func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 59	defer c.Close()
 60	initRequest := mcp.InitializeRequest{}
 61	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 62	initRequest.Params.ClientInfo = mcp.Implementation{
 63		Name:    "Crush",
 64		Version: version.Version,
 65	}
 66
 67	_, err := c.Initialize(ctx, initRequest)
 68	if err != nil {
 69		return tools.NewTextErrorResponse(err.Error()), nil
 70	}
 71
 72	toolRequest := mcp.CallToolRequest{}
 73	toolRequest.Params.Name = toolName
 74	var args map[string]any
 75	if err = json.Unmarshal([]byte(input), &args); err != nil {
 76		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 77	}
 78	toolRequest.Params.Arguments = args
 79	result, err := c.CallTool(ctx, toolRequest)
 80	if err != nil {
 81		return tools.NewTextErrorResponse(err.Error()), nil
 82	}
 83
 84	output := ""
 85	for _, v := range result.Content {
 86		if v, ok := v.(mcp.TextContent); ok {
 87			output = v.Text
 88		} else {
 89			output = fmt.Sprintf("%v", v)
 90		}
 91	}
 92
 93	return tools.NewTextResponse(output), nil
 94}
 95
 96func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 97	sessionID, messageID := tools.GetContextValues(ctx)
 98	if sessionID == "" || messageID == "" {
 99		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
100	}
101	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
102	p := b.permissions.Request(
103		permission.CreatePermissionRequest{
104			SessionID:   sessionID,
105			ToolCallID:  params.ID,
106			Path:        b.workingDir,
107			ToolName:    b.Info().Name,
108			Action:      "execute",
109			Description: permissionDescription,
110			Params:      params.Input,
111		},
112	)
113	if !p {
114		return tools.ToolResponse{}, permission.ErrorPermissionDenied
115	}
116
117	switch b.mcpConfig.Type {
118	case config.MCPStdio:
119		c, err := client.NewStdioMCPClient(
120			b.mcpConfig.Command,
121			b.mcpConfig.ResolvedEnv(),
122			b.mcpConfig.Args...,
123		)
124		if err != nil {
125			return tools.NewTextErrorResponse(err.Error()), nil
126		}
127		return runTool(ctx, c, b.tool.Name, params.Input)
128	case config.MCPHttp:
129		c, err := client.NewStreamableHttpClient(
130			b.mcpConfig.URL,
131			transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
132		)
133		if err != nil {
134			return tools.NewTextErrorResponse(err.Error()), nil
135		}
136		return runTool(ctx, c, b.tool.Name, params.Input)
137	case config.MCPSse:
138		c, err := client.NewSSEMCPClient(
139			b.mcpConfig.URL,
140			client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
141		)
142		if err != nil {
143			return tools.NewTextErrorResponse(err.Error()), nil
144		}
145		return runTool(ctx, c, b.tool.Name, params.Input)
146	}
147
148	return tools.NewTextErrorResponse("invalid mcp type"), nil
149}
150
151func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
152	return &mcpTool{
153		mcpName:     name,
154		tool:        tool,
155		mcpConfig:   mcpConfig,
156		permissions: permissions,
157		workingDir:  workingDir,
158	}
159}
160
161func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
162	var stdioTools []tools.BaseTool
163	initRequest := mcp.InitializeRequest{}
164	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
165	initRequest.Params.ClientInfo = mcp.Implementation{
166		Name:    "Crush",
167		Version: version.Version,
168	}
169
170	_, err := c.Initialize(ctx, initRequest)
171	if err != nil {
172		slog.Error("error initializing mcp client", "error", err)
173		return stdioTools
174	}
175	toolsRequest := mcp.ListToolsRequest{}
176	tools, err := c.ListTools(ctx, toolsRequest)
177	if err != nil {
178		slog.Error("error listing tools", "error", err)
179		return stdioTools
180	}
181	for _, t := range tools.Tools {
182		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
183	}
184	defer c.Close()
185	return stdioTools
186}
187
// mcpTools caches the tools discovered by the first GetMCPTools call;
// mcpToolsOnce guards that discovery so it runs at most once per process.
var (
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
)
192
193func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
194	mcpToolsOnce.Do(func() {
195		mcpTools = doGetMCPTools(ctx, permissions, cfg)
196	})
197	return mcpTools
198}
199
200func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
201	var wg sync.WaitGroup
202	result := csync.NewSlice[tools.BaseTool]()
203	for name, m := range cfg.MCP {
204		if m.Disabled {
205			slog.Debug("skipping disabled mcp", "name", name)
206			continue
207		}
208		wg.Add(1)
209		go func(name string, m config.MCPConfig) {
210			defer wg.Done()
211			switch m.Type {
212			case config.MCPStdio:
213				c, err := client.NewStdioMCPClient(
214					m.Command,
215					m.ResolvedEnv(),
216					m.Args...,
217				)
218				if err != nil {
219					slog.Error("error creating mcp client", "error", err)
220					return
221				}
222
223				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
224			case config.MCPHttp:
225				c, err := client.NewStreamableHttpClient(
226					m.URL,
227					transport.WithHTTPHeaders(m.ResolvedHeaders()),
228				)
229				if err != nil {
230					slog.Error("error creating mcp client", "error", err)
231					return
232				}
233				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
234			case config.MCPSse:
235				c, err := client.NewSSEMCPClient(
236					m.URL,
237					client.WithHeaders(m.ResolvedHeaders()),
238				)
239				if err != nil {
240					slog.Error("error creating mcp client", "error", err)
241					return
242				}
243				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
244			}
245		}(name, m)
246	}
247	wg.Wait()
248	return slices.Collect(result.Seq())
249}