mcp.go

  1package agent
  2
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"slices"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/resolver"
	"github.com/charmbracelet/crush/internal/version"

	"github.com/mark3labs/mcp-go/client"
	"github.com/mark3labs/mcp-go/client/transport"
	"github.com/mark3labs/mcp-go/mcp"
)
 22
 23type MCPType string
 24
 25const (
 26	MCPStdio MCPType = "stdio"
 27	MCPSse   MCPType = "sse"
 28	MCPHttp  MCPType = "http"
 29)
 30
 31type MCPConfig struct {
 32	Command  string            `json:"command,omitempty" `
 33	Env      map[string]string `json:"env,omitempty"`
 34	Args     []string          `json:"args,omitempty"`
 35	Type     MCPType           `json:"type"`
 36	URL      string            `json:"url,omitempty"`
 37	Disabled bool              `json:"disabled,omitempty"`
 38
 39	Headers map[string]string `json:"headers,omitempty"`
 40}
 41
 42type mcpTool struct {
 43	mcpName     string
 44	tool        mcp.Tool
 45	mcpConfig   MCPConfig
 46	permissions permission.Service
 47	workingDir  string
 48}
 49
 50type MCPClient interface {
 51	Initialize(
 52		ctx context.Context,
 53		request mcp.InitializeRequest,
 54	) (*mcp.InitializeResult, error)
 55	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
 56	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
 57	Close() error
 58}
 59
 60func (b *mcpTool) Name() string {
 61	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 62}
 63
 64func (b *mcpTool) Info() tools.ToolInfo {
 65	required := b.tool.InputSchema.Required
 66	if required == nil {
 67		required = make([]string, 0)
 68	}
 69	return tools.ToolInfo{
 70		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
 71		Description: b.tool.Description,
 72		Parameters:  b.tool.InputSchema.Properties,
 73		Required:    required,
 74	}
 75}
 76
 77func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
 78	defer c.Close()
 79	initRequest := mcp.InitializeRequest{}
 80	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
 81	initRequest.Params.ClientInfo = mcp.Implementation{
 82		Name:    "crush",
 83		Version: version.Version,
 84	}
 85
 86	_, err := c.Initialize(ctx, initRequest)
 87	if err != nil {
 88		return tools.NewTextErrorResponse(err.Error()), nil
 89	}
 90
 91	toolRequest := mcp.CallToolRequest{}
 92	toolRequest.Params.Name = toolName
 93	var args map[string]any
 94	if err = json.Unmarshal([]byte(input), &args); err != nil {
 95		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 96	}
 97	toolRequest.Params.Arguments = args
 98	result, err := c.CallTool(ctx, toolRequest)
 99	if err != nil {
100		return tools.NewTextErrorResponse(err.Error()), nil
101	}
102
103	output := ""
104	for _, v := range result.Content {
105		if v, ok := v.(mcp.TextContent); ok {
106			output = v.Text
107		} else {
108			output = fmt.Sprintf("%v", v)
109		}
110	}
111
112	return tools.NewTextResponse(output), nil
113}
114
115func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
116	sessionID, messageID := tools.GetContextValues(ctx)
117	if sessionID == "" || messageID == "" {
118		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
119	}
120	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
121	p := b.permissions.Request(
122		permission.CreatePermissionRequest{
123			SessionID:   sessionID,
124			ToolCallID:  params.ID,
125			Path:        b.workingDir,
126			ToolName:    b.Info().Name,
127			Action:      "execute",
128			Description: permissionDescription,
129			Params:      params.Input,
130		},
131	)
132	if !p {
133		return tools.ToolResponse{}, permission.ErrorPermissionDenied
134	}
135
136	switch b.mcpConfig.Type {
137	case MCPStdio:
138		c, err := client.NewStdioMCPClient(
139			b.mcpConfig.Command,
140			b.mcpConfig.ResolvedEnv(),
141			b.mcpConfig.Args...,
142		)
143		if err != nil {
144			return tools.NewTextErrorResponse(err.Error()), nil
145		}
146		return runTool(ctx, c, b.tool.Name, params.Input)
147	case MCPHttp:
148		c, err := client.NewStreamableHttpClient(
149			b.mcpConfig.URL,
150			transport.WithHTTPHeaders(b.mcpConfig.ResolvedHeaders()),
151		)
152		if err != nil {
153			return tools.NewTextErrorResponse(err.Error()), nil
154		}
155		return runTool(ctx, c, b.tool.Name, params.Input)
156	case MCPSse:
157		c, err := client.NewSSEMCPClient(
158			b.mcpConfig.URL,
159			client.WithHeaders(b.mcpConfig.ResolvedHeaders()),
160		)
161		if err != nil {
162			return tools.NewTextErrorResponse(err.Error()), nil
163		}
164		return runTool(ctx, c, b.tool.Name, params.Input)
165	}
166
167	return tools.NewTextErrorResponse("invalid mcp type"), nil
168}
169
170func NewMcpTool(name, cwd string, tool mcp.Tool, permissions permission.Service, mcpConfig MCPConfig) tools.BaseTool {
171	return &mcpTool{
172		mcpName:     name,
173		tool:        tool,
174		mcpConfig:   mcpConfig,
175		permissions: permissions,
176		workingDir:  cwd,
177	}
178}
179
180func getTools(ctx context.Context, cwd string, name string, m MCPConfig, permissions permission.Service, c MCPClient) []tools.BaseTool {
181	var stdioTools []tools.BaseTool
182	initRequest := mcp.InitializeRequest{}
183	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
184	initRequest.Params.ClientInfo = mcp.Implementation{
185		Name: "dreamlover",
186	}
187
188	_, err := c.Initialize(ctx, initRequest)
189	if err != nil {
190		slog.Error("error initializing mcp client", "error", err)
191		return stdioTools
192	}
193	toolsRequest := mcp.ListToolsRequest{}
194	tools, err := c.ListTools(ctx, toolsRequest)
195	if err != nil {
196		slog.Error("error listing tools", "error", err)
197		return stdioTools
198	}
199	for _, t := range tools.Tools {
200		stdioTools = append(stdioTools, NewMcpTool(name, cwd, t, permissions, m))
201	}
202	defer c.Close()
203	return stdioTools
204}
205
206var (
207	mcpToolsOnce sync.Once
208	mcpTools     []tools.BaseTool
209)
210
211func GetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
212	mcpToolsOnce.Do(func() {
213		mcpTools = doGetMCPTools(ctx, cwd, mcps, permissions)
214	})
215	return mcpTools
216}
217
218func doGetMCPTools(ctx context.Context, cwd string, mcps map[string]MCPConfig, permissions permission.Service) []tools.BaseTool {
219	var wg sync.WaitGroup
220	result := csync.NewSlice[tools.BaseTool]()
221	for name, m := range mcps {
222		if m.Disabled {
223			slog.Debug("skipping disabled mcp", "name", name)
224			continue
225		}
226		wg.Add(1)
227		go func(name string, m MCPConfig) {
228			defer wg.Done()
229			switch m.Type {
230			case MCPStdio:
231				c, err := client.NewStdioMCPClient(
232					m.Command,
233					m.ResolvedEnv(),
234					m.Args...,
235				)
236				if err != nil {
237					slog.Error("error creating mcp client", "error", err)
238					return
239				}
240
241				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
242			case MCPHttp:
243				c, err := client.NewStreamableHttpClient(
244					m.URL,
245					transport.WithHTTPHeaders(m.ResolvedHeaders()),
246				)
247				if err != nil {
248					slog.Error("error creating mcp client", "error", err)
249					return
250				}
251				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
252			case MCPSse:
253				c, err := client.NewSSEMCPClient(
254					m.URL,
255					client.WithHeaders(m.ResolvedHeaders()),
256				)
257				if err != nil {
258					slog.Error("error creating mcp client", "error", err)
259					return
260				}
261				result.Append(getTools(ctx, cwd, name, m, permissions, c)...)
262			}
263		}(name, m)
264	}
265	wg.Wait()
266	return slices.Collect(result.Seq())
267}
268
269func (m MCPConfig) ResolvedEnv() []string {
270	resolver := resolver.New()
271	for e, v := range m.Env {
272		var err error
273		m.Env[e], err = resolver.ResolveValue(v)
274		if err != nil {
275			slog.Error("error resolving environment variable", "error", err, "variable", e, "value", v)
276			continue
277		}
278	}
279
280	env := make([]string, 0, len(m.Env))
281	for k, v := range m.Env {
282		env = append(env, fmt.Sprintf("%s=%s", k, v))
283	}
284	return env
285}
286
287func (m MCPConfig) ResolvedHeaders() map[string]string {
288	resolver := resolver.New()
289	for e, v := range m.Headers {
290		var err error
291		m.Headers[e], err = resolver.ResolveValue(v)
292		if err != nil {
293			slog.Error("error resolving header variable", "error", err, "variable", e, "value", v)
294			continue
295		}
296	}
297	return m.Headers
298}