1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"sync"
  9
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/llm/tools"
 12
 13	"github.com/charmbracelet/crush/internal/permission"
 14	"github.com/charmbracelet/crush/internal/version"
 15
 16	"github.com/mark3labs/mcp-go/client"
 17	"github.com/mark3labs/mcp-go/client/transport"
 18	"github.com/mark3labs/mcp-go/mcp"
 19)
 20
 21type mcpTool struct {
 22	mcpName     string
 23	tool        mcp.Tool
 24	mcpConfig   config.MCPConfig
 25	permissions permission.Service
 26	workingDir  string
 27}
 28
 29type MCPClient interface {
 30	Initialize(
 31		ctx context.Context,
 32		request mcp.InitializeRequest,
 33	) (*mcp.InitializeResult, error)
 34	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 35	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 36	Close() error
 37}
 38
 39func (b *mcpTool) Name() string {
 40	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 41}
 42
 43func (b *mcpTool) Info() tools.ToolInfo {
 44	required := b.tool.InputSchema.Required
 45	if required == nil {
 46		required = make([]string, 0)
 47	}
 48	return tools.ToolInfo{
 49		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 50		Description: b.tool.Description,
 51		Parameters:  b.tool.InputSchema.Properties,
 52		Required:    required,
 53	}
 54}
 55
 56func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 57	defer c.Close()
 58	initRequest := mcp.InitializeRequest{}
 59	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 60	initRequest.Params.ClientInfo = mcp.Implementation{
 61		Name:    "Crush",
 62		Version: version.Version,
 63	}
 64
 65	_, err := c.Initialize(ctx, initRequest)
 66	if err != nil {
 67		return tools.NewTextErrorResponse(err.Error()), nil
 68	}
 69
 70	toolRequest := mcp.CallToolRequest{}
 71	toolRequest.Params.Name = toolName
 72	var args map[string]any
 73	if err = json.Unmarshal([]byte(input), &args); err != nil {
 74		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 75	}
 76	toolRequest.Params.Arguments = args
 77	result, err := c.CallTool(ctx, toolRequest)
 78	if err != nil {
 79		return tools.NewTextErrorResponse(err.Error()), nil
 80	}
 81
 82	output := ""
 83	for _, v := range result.Content {
 84		if v, ok := v.(mcp.TextContent); ok {
 85			output = v.Text
 86		} else {
 87			output = fmt.Sprintf("%v", v)
 88		}
 89	}
 90
 91	return tools.NewTextResponse(output), nil
 92}
 93
 94func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 95	sessionID, messageID := tools.GetContextValues(ctx)
 96	if sessionID == "" || messageID == "" {
 97		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 98	}
 99	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
100	p := b.permissions.Request(
101		permission.CreatePermissionRequest{
102			SessionID:   sessionID,
103			Path:        b.workingDir,
104			ToolName:    b.Info().Name,
105			Action:      "execute",
106			Description: permissionDescription,
107			Params:      params.Input,
108		},
109	)
110	if !p {
111		return tools.ToolResponse{}, permission.ErrorPermissionDenied
112	}
113
114	switch b.mcpConfig.Type {
115	case config.MCPStdio:
116		c, err := client.NewStdioMCPClient(
117			b.mcpConfig.Command,
118			b.mcpConfig.ResolvedEnv(),
119			b.mcpConfig.Args...,
120		)
121		if err != nil {
122			return tools.NewTextErrorResponse(err.Error()), nil
123		}
124		return runTool(ctx, c, b.tool.Name, params.Input)
125	case config.MCPHttp:
126		c, err := client.NewStreamableHttpClient(
127			b.mcpConfig.URL,
128			transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
129		)
130		if err != nil {
131			return tools.NewTextErrorResponse(err.Error()), nil
132		}
133		return runTool(ctx, c, b.tool.Name, params.Input)
134	case config.MCPSse:
135		c, err := client.NewSSEMCPClient(
136			b.mcpConfig.URL,
137			client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
138		)
139		if err != nil {
140			return tools.NewTextErrorResponse(err.Error()), nil
141		}
142		return runTool(ctx, c, b.tool.Name, params.Input)
143	}
144
145	return tools.NewTextErrorResponse("invalid mcp type"), nil
146}
147
148func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
149	return &mcpTool{
150		mcpName:     name,
151		tool:        tool,
152		mcpConfig:   mcpConfig,
153		permissions: permissions,
154		workingDir:  workingDir,
155	}
156}
157
158func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
159	var stdioTools []tools.BaseTool
160	initRequest := mcp.InitializeRequest{}
161	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
162	initRequest.Params.ClientInfo = mcp.Implementation{
163		Name:    "Crush",
164		Version: version.Version,
165	}
166
167	_, err := c.Initialize(ctx, initRequest)
168	if err != nil {
169		slog.Error("error initializing mcp client", "error", err)
170		return stdioTools
171	}
172	toolsRequest := mcp.ListToolsRequest{}
173	tools, err := c.ListTools(ctx, toolsRequest)
174	if err != nil {
175		slog.Error("error listing tools", "error", err)
176		return stdioTools
177	}
178	for _, t := range tools.Tools {
179		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m, workingDir))
180	}
181	defer c.Close()
182	return stdioTools
183}
184
185var (
186	mcpToolsOnce sync.Once
187	mcpTools     []tools.BaseTool
188)
189
190func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
191	mcpToolsOnce.Do(func() {
192		mcpTools = doGetMCPTools(ctx, permissions, cfg)
193	})
194	return mcpTools
195}
196
197func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
198	var mu sync.Mutex
199	var wg sync.WaitGroup
200	var result []tools.BaseTool
201	for name, m := range cfg.MCP {
202		if m.Disabled {
203			slog.Debug("skipping disabled mcp", "name", name)
204			continue
205		}
206		wg.Add(1)
207		go func(name string, m config.MCPConfig) {
208			defer wg.Done()
209			switch m.Type {
210			case config.MCPStdio:
211				c, err := client.NewStdioMCPClient(
212					m.Command,
213					m.ResolvedEnv(),
214					m.Args...,
215				)
216				if err != nil {
217					slog.Error("error creating mcp client", "error", err)
218					return
219				}
220
221				mu.Lock()
222				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
223				mu.Unlock()
224			case config.MCPHttp:
225				c, err := client.NewStreamableHttpClient(
226					m.URL,
227					transport.WithHTTPHeaders(m.ResolvedHeaders()),
228				)
229				if err != nil {
230					slog.Error("error creating mcp client", "error", err)
231					return
232				}
233				mu.Lock()
234				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
235				mu.Unlock()
236			case config.MCPSse:
237				c, err := client.NewSSEMCPClient(
238					m.URL,
239					client.WithHeaders(m.ResolvedHeaders()),
240				)
241				if err != nil {
242					slog.Error("error creating mcp client", "error", err)
243					return
244				}
245				mu.Lock()
246				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
247				mu.Unlock()
248			}
249		}(name, m)
250	}
251	wg.Wait()
252	return result
253}