1package agent
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"sync"
 10
 11	"github.com/charmbracelet/crush/internal/csync"
 12	"github.com/charmbracelet/crush/internal/llm/tools"
 13	"github.com/charmbracelet/crush/internal/resolver"
 14	"github.com/charmbracelet/crush/internal/version"
 15
 16	"github.com/charmbracelet/crush/internal/permission"
 17
 18	"github.com/mark3labs/mcp-go/client"
 19	"github.com/mark3labs/mcp-go/client/transport"
 20	"github.com/mark3labs/mcp-go/mcp"
 21)
 22
 23type MCPType string
 24
 25const (
 26	MCPStdio MCPType = "stdio"
 27	MCPSse   MCPType = "sse"
 28	MCPHttp  MCPType = "http"
 29)
 30
 31type MCPConfig struct {
 32	Command  string            `json:"command,omitempty" `
 33	Env      map[string]string `json:"env,omitempty"`
 34	Args     []string          `json:"args,omitempty"`
 35	Type     MCPType           `json:"type"`
 36	URL      string            `json:"url,omitempty"`
 37	Disabled bool              `json:"disabled,omitempty"`
 38
 39	Headers map[string]string `json:"headers,omitempty"`
 40}
 41
 42type mcpTool struct {
 43	mcpName     string
 44	tool        mcp.Tool
 45	mcpConfig   MCPConfig
 46	permissions permission.Service
 47	workingDir  string
 48}
 49
 50type MCPClient interface {
 51	Initialize(
 52		ctx context.Context,
 53		request mcp.InitializeRequest,
 54	) (*mcp.InitializeResult, error)
 55	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 56	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 57	Close() error
 58}
 59
 60func (b *mcpTool) Name() string {
 61	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 62}
 63
 64func (b *mcpTool) Info() tools.ToolInfo {
 65	required := b.tool.InputSchema.Required
 66	if required == nil {
 67		required = make([]string, 0)
 68	}
 69	return tools.ToolInfo{
 70		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 71		Description: b.tool.Description,
 72		Parameters:  b.tool.InputSchema.Properties,
 73		Required:    required,
 74	}
 75}
 76
 77func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 78	defer c.Close()
 79	initRequest := mcp.InitializeRequest{}
 80	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 81	initRequest.Params.ClientInfo = mcp.Implementation{
 82		Name:    "crush",
 83		Version: version.Version,
 84	}
 85
 86	_, err := c.Initialize(ctx, initRequest)
 87	if err != nil {
 88		return tools.NewTextErrorResponse(err.Error()), nil
 89	}
 90
 91	toolRequest := mcp.CallToolRequest{}
 92	toolRequest.Params.Name = toolName
 93	var args map[string]any
 94	if err = json.Unmarshal([]byte(input), &args); err != nil {
 95		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 96	}
 97	toolRequest.Params.Arguments = args
 98	result, err := c.CallTool(ctx, toolRequest)
 99	if err != nil {
100		return tools.NewTextErrorResponse(err.Error()), nil
101	}
102
103	output := ""
104	for _, v := range result.Content {
105		if v, ok := v.(mcp.TextContent); ok {
106			output = v.Text
107		} else {
108			output = fmt.Sprintf("%v", v)
109		}
110	}
111
112	return tools.NewTextResponse(output), nil
113}
114
115func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
116	sessionID, messageID := tools.GetContextValues(ctx)
117	if sessionID == "" || messageID == "" {
118		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
119	}
120	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
121	p := b.permissions.Request(
122		permission.CreatePermissionRequest{
123			SessionID:   sessionID,
124			ToolCallID:  params.ID,
125			Path:        b.workingDir,
126			ToolName:    b.Info().Name,
127			Action:      "execute",
128			Description: permissionDescription,
129			Params:      params.Input,
130		},
131	)
132	if !p {
133		return tools.ToolResponse{}, permission.ErrorPermissionDenied
134	}
135
136	switch b.mcpConfig.Type {
137	case MCPStdio:
138		c, err := client.NewStdioMCPClient(
139			b.mcpConfig.Command,
140			b.mcpConfig.ResolvedEnv(),
141			b.mcpConfig.Args...,
142		)
143		if err != nil {
144			return tools.NewTextErrorResponse(err.Error()), nil
145		}
146		return runTool(ctx, c, b.tool.Name, params.Input)
147	case MCPHttp:
148		c, err := client.NewStreamableHttpClient(
149			b.mcpConfig.URL,
150			transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
151		)
152		if err != nil {
153			return tools.NewTextErrorResponse(err.Error()), nil
154		}
155		return runTool(ctx, c, b.tool.Name, params.Input)
156	case MCPSse:
157		c, err := client.NewSSEMCPClient(
158			b.mcpConfig.URL,
159			client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
160		)
161		if err != nil {
162			return tools.NewTextErrorResponse(err.Error()), nil
163		}
164		return runTool(ctx, c, b.tool.Name, params.Input)
165	}
166
167	return tools.NewTextErrorResponse("invalid mcp type"), nil
168}
169
170func NewMcpTool(name, cwd string, tool mcp.Tool, permissions permission.Service, mcpConfig MCPConfig) tools.BaseTool {
171	return &mcpTool{
172		mcpName:     name,
173		tool:        tool,
174		mcpConfig:   mcpConfig,
175		permissions: permissions,
176		workingDir:  cwd,
177	}
178}
179
180func getTools(ctx context.Context, cwd string, name string, m MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
181	var stdioTools []tools.BaseTool
182	initRequest := mcp.InitializeRequest{}
183	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
184	initRequest.Params.ClientInfo = mcp.Implementation{
185		Name:    "crush",
186		Version: version.Version,
187	}
188
189	_, err := c.Initialize(ctx, initRequest)
190	if err != nil {
191		slog.Error("error initializing mcp client", "error", err)
192		return stdioTools
193	}
194	toolsRequest := mcp.ListToolsRequest{}
195	tools, err := c.ListTools(ctx, toolsRequest)
196	if err != nil {
197		slog.Error("error listing tools", "error", err)
198		return stdioTools
199	}
200	for _, t := range tools.Tools {
201		stdioTools = append(stdioTools, NewMcpTool(name, cwd, t, permissions, m))
202	}
203	defer c.Close()
204	return stdioTools
205}
206
207var (
208	mcpToolsOnce sync.Once
209	mcpTools     []tools.BaseTool
210)
211
212func GetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
213	mcpToolsOnce.Do(func() {
214		mcpTools = doGetMCPTools(ctx, cwd, mcps, permissions)
215	})
216	return mcpTools
217}
218
219func doGetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
220	var wg sync.WaitGroup
221	result := csync.NewSlice[tools.BaseTool]()
222	for name, m := range mcps {
223		if m.Disabled {
224			slog.Debug("skipping disabled mcp", "name", name)
225			continue
226		}
227		wg.Add(1)
228		go func(name string, m MCPConfig) {
229			defer wg.Done()
230			switch m.Type {
231			case MCPStdio:
232				c, err := client.NewStdioMCPClient(
233					m.Command,
234					m.ResolvedEnv(),
235					m.Args...,
236				)
237				if err != nil {
238					slog.Error("error creating mcp client", "error", err)
239					return
240				}
241
242				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
243			case MCPHttp:
244				c, err := client.NewStreamableHttpClient(
245					m.URL,
246					transport.WithHTTPHeaders(m.ResolvedHeaders()),
247				)
248				if err != nil {
249					slog.Error("error creating mcp client", "error", err)
250					return
251				}
252				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
253			case MCPSse:
254				c, err := client.NewSSEMCPClient(
255					m.URL,
256					client.WithHeaders(m.ResolvedHeaders()),
257				)
258				if err != nil {
259					slog.Error("error creating mcp client", "error", err)
260					return
261				}
262				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
263			}
264		}(name, m)
265	}
266	wg.Wait()
267	return slices.Collect(result.Seq())
268}
269
270func (m MCPConfig) ResolvedEnv() []string {
271	resolver := resolver.New()
272	for e, v := range m.Env {
273		var err error
274		m.Env[e], err = resolver.ResolveValue(v)
275		if err != nil {
276			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
277			continue
278		}
279	}
280
281	env := make([]string, 0, len(m.Env))
282	for k, v := range m.Env {
283		env = append(env, fmt.Sprintf("%s=%s", k, v))
284	}
285	return env
286}
287
288func (m MCPConfig) ResolvedHeaders() map[string]string {
289	resolver := resolver.New()
290	for e, v := range m.Headers {
291		var err error
292		m.Headers[e], err = resolver.ResolveValue(v)
293		if err != nil {
294			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
295			continue
296		}
297	}
298	return m.Headers
299}