1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/logging"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/mcp"
)
 17
// mcpTool exposes one tool of a configured MCP server as a tools.BaseTool.
type mcpTool struct {
	mcpName     string             // configured server name; prefixes the tool name (see Name)
	tool        mcp.Tool           // tool definition as returned by the server's ListTools call
	mcpConfig   config.MCP         // transport settings (type, command/env/args or URL/headers) used by Run
	permissions permission.Service // consulted for approval before each Run
}
 24
 25type MCPClient interface {
 26	Initialize(
 27		ctx context.Context,
 28		request mcp.InitializeRequest,
 29	) (*mcp.InitializeResult, error)
 30	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 31	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 32	Close() error
 33}
 34
 35func (b *mcpTool) Name() string {
 36	return fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name)
 37}
 38
 39func (b *mcpTool) Info() tools.ToolInfo {
 40	return tools.ToolInfo{
 41		Name:        fmt.Sprintf("%s_%s", b.mcpName, b.tool.Name),
 42		Description: b.tool.Description,
 43		Parameters:  b.tool.InputSchema.Properties,
 44		Required:    b.tool.InputSchema.Required,
 45	}
 46}
 47
 48func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 49	defer c.Close()
 50	initRequest := mcp.InitializeRequest{}
 51	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 52	initRequest.Params.ClientInfo = mcp.Implementation{
 53		Name:    "Crush",
 54		Version: version.Version,
 55	}
 56
 57	_, err := c.Initialize(ctx, initRequest)
 58	if err != nil {
 59		return tools.NewTextErrorResponse(err.Error()), nil
 60	}
 61
 62	toolRequest := mcp.CallToolRequest{}
 63	toolRequest.Params.Name = toolName
 64	var args map[string]any
 65	if err = json.Unmarshal([]byte(input), &args); err != nil {
 66		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 67	}
 68	toolRequest.Params.Arguments = args
 69	result, err := c.CallTool(ctx, toolRequest)
 70	if err != nil {
 71		return tools.NewTextErrorResponse(err.Error()), nil
 72	}
 73
 74	output := ""
 75	for _, v := range result.Content {
 76		if v, ok := v.(mcp.TextContent); ok {
 77			output = v.Text
 78		} else {
 79			output = fmt.Sprintf("%v", v)
 80		}
 81	}
 82
 83	return tools.NewTextResponse(output), nil
 84}
 85
 86func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
 87	sessionID, messageID := tools.GetContextValues(ctx)
 88	if sessionID == "" || messageID == "" {
 89		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
 90	}
 91	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
 92	p := b.permissions.Request(
 93		permission.CreatePermissionRequest{
 94			SessionID:   sessionID,
 95			Path:        config.WorkingDirectory(),
 96			ToolName:    b.Info().Name,
 97			Action:      "execute",
 98			Description: permissionDescription,
 99			Params:      params.Input,
100		},
101	)
102	if !p {
103		return tools.NewTextErrorResponse("permission denied"), nil
104	}
105
106	switch b.mcpConfig.Type {
107	case config.MCPStdio:
108		c, err := client.NewStdioMCPClient(
109			b.mcpConfig.Command,
110			b.mcpConfig.Env,
111			b.mcpConfig.Args...,
112		)
113		if err != nil {
114			return tools.NewTextErrorResponse(err.Error()), nil
115		}
116		return runTool(ctx, c, b.tool.Name, params.Input)
117	case config.MCPSse:
118		c, err := client.NewSSEMCPClient(
119			b.mcpConfig.URL,
120			client.WithHeaders(b.mcpConfig.Headers),
121		)
122		if err != nil {
123			return tools.NewTextErrorResponse(err.Error()), nil
124		}
125		return runTool(ctx, c, b.tool.Name, params.Input)
126	}
127
128	return tools.NewTextErrorResponse("invalid mcp type"), nil
129}
130
131func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCP) tools.BaseTool {
132	return &mcpTool{
133		mcpName:     name,
134		tool:        tool,
135		mcpConfig:   mcpConfig,
136		permissions: permissions,
137	}
138}
139
// mcpTools caches the tools collected by GetMcpTools. Once it is non-empty,
// GetMcpTools returns it without contacting any server. If discovery yields
// no tools, it stays empty and discovery is retried on the next call.
// NOTE(review): no lock guards this variable; callers presumably serialize
// GetMcpTools — confirm.
var mcpTools []tools.BaseTool
141
142func getTools(ctx context.Context, name string, m config.MCP, permissions permission.Service, c MCPClient) []tools.BaseTool {
143	var stdioTools []tools.BaseTool
144	initRequest := mcp.InitializeRequest{}
145	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
146	initRequest.Params.ClientInfo = mcp.Implementation{
147		Name:    "Crush",
148		Version: version.Version,
149	}
150
151	_, err := c.Initialize(ctx, initRequest)
152	if err != nil {
153		logging.Error("error initializing mcp client", "error", err)
154		return stdioTools
155	}
156	toolsRequest := mcp.ListToolsRequest{}
157	tools, err := c.ListTools(ctx, toolsRequest)
158	if err != nil {
159		logging.Error("error listing tools", "error", err)
160		return stdioTools
161	}
162	for _, t := range tools.Tools {
163		stdioTools = append(stdioTools, NewMcpTool(name, t, permissions, m))
164	}
165	defer c.Close()
166	return stdioTools
167}
168
169func GetMcpTools(ctx context.Context, permissions permission.Service) []tools.BaseTool {
170	if len(mcpTools) > 0 {
171		return mcpTools
172	}
173	for name, m := range config.Get().MCP {
174		switch m.Type {
175		case config.MCPStdio:
176			c, err := client.NewStdioMCPClient(
177				m.Command,
178				m.Env,
179				m.Args...,
180			)
181			if err != nil {
182				logging.Error("error creating mcp client", "error", err)
183				continue
184			}
185
186			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
187		case config.MCPSse:
188			c, err := client.NewSSEMCPClient(
189				m.URL,
190				client.WithHeaders(m.Headers),
191			)
192			if err != nil {
193				logging.Error("error creating mcp client", "error", err)
194				continue
195			}
196			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c)...)
197		}
198	}
199
200	return mcpTools
201}