1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"sync"
 10
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/charmbracelet/crush/internal/csync"
 13	"github.com/charmbracelet/crush/internal/llm/tools"
 14	"github.com/charmbracelet/crush/internal/version"
 15
 16	"github.com/charmbracelet/crush/internal/permission"
 17
 18	"github.com/mark3labs/mcp-go/client"
 19	"github.com/mark3labs/mcp-go/client/transport"
 20	"github.com/mark3labs/mcp-go/mcp"
 21)
 22
var (
	// mcpToolsOnce and mcpTools are not referenced in this part of the file;
	// presumably they memoize the result of doGetMCPTools — TODO confirm
	// against the callers.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP server name (a key of cfg.MCP) to its initialized
	// client. doInitClient adds entries, getTools removes an entry when listing
	// that server's tools fails, and runTool looks clients up here by name.
	mcpClients   = csync.NewMap[string, *client.Client]()
)
 28
// McpTool exposes a single tool of an MCP server as a tools.BaseTool, so it
// can be offered to the agent alongside built-in tools.
type McpTool struct {
	mcpName     string             // server name; the key used to look up the client in mcpClients
	tool        mcp.Tool           // tool definition as returned by the server's ListTools
	permissions permission.Service // asked to approve every Run before the tool is called
	workingDir  string             // sent as the Path of each permission request
}
 35
 36func (b *McpTool) Name() string {
 37	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 38}
 39
 40func (b *McpTool) Info() tools.ToolInfo {
 41	required := b.tool.InputSchema.Required
 42	if required == nil {
 43		required = make([]string, 0)
 44	}
 45	return tools.ToolInfo{
 46		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 47		Description: b.tool.Description,
 48		Parameters:  b.tool.InputSchema.Properties,
 49		Required:    required,
 50	}
 51}
 52
 53func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
 54	var args map[string]any
 55	if err := json.Unmarshal([]byte(input), &args); err != nil {
 56		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 57	}
 58	c, ok := mcpClients.Get(name)
 59	if !ok {
 60		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
 61	}
 62	result, err := c.CallTool(ctx, mcp.CallToolRequest{
 63		Params: mcp.CallToolParams{
 64			Name:      toolName,
 65			Arguments: args,
 66		},
 67	})
 68	if err != nil {
 69		return tools.NewTextErrorResponse(err.Error()), nil
 70	}
 71
 72	output := ""
 73	for _, v := range result.Content {
 74		if v, ok := v.(mcp.TextContent); ok {
 75			output = v.Text
 76		} else {
 77			output = fmt.Sprintf("%v", v)
 78		}
 79	}
 80
 81	return tools.NewTextResponse(output), nil
 82}
 83
 84func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 85	sessionID, messageID := tools.GetContextValues(ctx)
 86	if sessionID == "" || messageID == "" {
 87		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 88	}
 89	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 90	p := b.permissions.Request(
 91		permission.CreatePermissionRequest{
 92			SessionID:   sessionID,
 93			ToolCallID:  params.ID,
 94			Path:        b.workingDir,
 95			ToolName:    b.Info().Name,
 96			Action:      "execute",
 97			Description: permissionDescription,
 98			Params:      params.Input,
 99		},
100	)
101	if !p {
102		return tools.ToolResponse{}, permission.ErrorPermissionDenied
103	}
104
105	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
106}
107
108func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
109	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
110	if err != nil {
111		slog.Error("error listing tools", "error", err)
112		c.Close()
113		mcpClients.Del(name)
114		return nil
115	}
116	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
117	for _, tool := range result.Tools {
118		mcpTools = append(mcpTools, &McpTool{
119			mcpName:     name,
120			tool:        tool,
121			permissions: permissions,
122			workingDir:  workingDir,
123		})
124	}
125	return mcpTools
126}
127
128// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
129func CloseMCPClients() {
130	for c := range mcpClients.Seq() {
131		_ = c.Close()
132	}
133}
134
135func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
136	var wg sync.WaitGroup
137	result := csync.NewSlice[tools.BaseTool]()
138	for name, m := range cfg.MCP {
139		if m.Disabled {
140			slog.Debug("skipping disabled mcp", "name", name)
141			continue
142		}
143		wg.Add(1)
144		go func(name string, m config.MCPConfig) {
145			defer wg.Done()
146			c, err := doGetClient(m)
147			if err != nil {
148				slog.Error("error creating mcp client", "error", err)
149				return
150			}
151			if err := doInitClient(ctx, name, c); err != nil {
152				slog.Error("error initializing mcp client", "error", err)
153				return
154			}
155			result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
156		}(name, m)
157	}
158	wg.Wait()
159	return slices.Collect(result.Seq())
160}
161
162func doInitClient(ctx context.Context, name string, c *client.Client) error {
163	initRequest := mcp.InitializeRequest{
164		Params: mcp.InitializeParams{
165			ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
166			ClientInfo: mcp.Implementation{
167				Name:    "Crush",
168				Version: version.Version,
169			},
170		},
171	}
172	if _, err := c.Initialize(ctx, initRequest); err != nil {
173		c.Close()
174		return err
175	}
176	mcpClients.Set(name, c)
177	return nil
178}
179
180func doGetClient(m config.MCPConfig) (*client.Client, error) {
181	switch m.Type {
182	case config.MCPStdio:
183		return client.NewStdioMCPClient(
184			m.Command,
185			m.ResolvedEnv(),
186			m.Args...,
187		)
188	case config.MCPHttp:
189		return client.NewStreamableHttpClient(
190			m.URL,
191			transport.WithHTTPHeaders(m.ResolvedHeaders()),
192		)
193	case config.MCPSse:
194		return client.NewSSEMCPClient(
195			m.URL,
196			client.WithHeaders(m.ResolvedHeaders()),
197		)
198	default:
199		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
200	}
201}