1package agent
  2
  3import (
  4	"cmp"
  5	"context"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"slices"
 10	"strings"
 11	"sync"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/csync"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/permission"
 18	"github.com/charmbracelet/crush/internal/pubsub"
 19	"github.com/charmbracelet/crush/internal/version"
 20	"github.com/mark3labs/mcp-go/client"
 21	"github.com/mark3labs/mcp-go/client/transport"
 22	"github.com/mark3labs/mcp-go/mcp"
 23)
 24
 25// MCPState represents the current state of an MCP client
 26type MCPState int
 27
 28const (
 29	MCPStateDisabled MCPState = iota
 30	MCPStateStarting
 31	MCPStateConnected
 32	MCPStateError
 33)
 34
 35func (s MCPState) String() string {
 36	switch s {
 37	case MCPStateDisabled:
 38		return "disabled"
 39	case MCPStateStarting:
 40		return "starting"
 41	case MCPStateConnected:
 42		return "connected"
 43	case MCPStateError:
 44		return "error"
 45	default:
 46		return "unknown"
 47	}
 48}
 49
// MCPEventType identifies the kind of notification carried by an MCPEvent.
type MCPEventType string

const (
	// MCPEventStateChanged is the event type used by updateMCPState each
	// time it records a new state for an MCP client.
	MCPEventStateChanged MCPEventType = "state_changed"
)
 56
// MCPEvent is the payload published on the MCP event broker. Within this
// file it is only produced by updateMCPState, which copies its arguments
// into the fields below.
type MCPEvent struct {
	// Type is the kind of event; MCPEventStateChanged is the only value
	// declared here.
	Type MCPEventType
	// Name is the name of the MCP client the event refers to.
	Name string
	// State is the state the client has just been moved to.
	State MCPState
	// Error is the error passed along with the state change, or nil.
	Error error
	// ToolCount is the tool count passed along with the state change.
	ToolCount int
}
 65
// MCPClientInfo is the per-client record stored in mcpStates and returned by
// GetMCPStates and GetMCPState. A fresh record replaces the previous one on
// every call to updateMCPState.
type MCPClientInfo struct {
	// Name is the name the client was registered under.
	Name string
	// State is the most recently recorded state.
	State MCPState
	// Error is the error recorded with the state, or nil.
	Error error
	// Client is the client handle recorded with the state; it is nil
	// whenever updateMCPState was called without one.
	Client *client.Client
	// ToolCount is the tool count recorded with the state.
	ToolCount int
	// ConnectedAt is set to the current time when the recorded state is
	// MCPStateConnected and is the zero time otherwise.
	ConnectedAt time.Time
}
 75
// Package-level MCP bookkeeping shared by the functions in this file.
var (
	// NOTE(review): mcpToolsOnce and mcpTools are not referenced in this part
	// of the file; presumably they cache a one-time tool discovery result —
	// confirm against their users.
	mcpToolsOnce sync.Once
	mcpTools     []tools.BaseTool
	// mcpClients maps an MCP name to its initialized client. Entries are
	// added after a successful Initialize and removed when listing tools fails.
	mcpClients = csync.NewMap[string, *client.Client]()
	// mcpStates maps an MCP name to its latest MCPClientInfo; it is written
	// only through updateMCPState.
	mcpStates = csync.NewMap[string, MCPClientInfo]()
	// mcpBroker fans out MCPEvent notifications to SubscribeMCPEvents
	// subscribers; it is shut down by CloseMCPClients.
	mcpBroker = pubsub.NewBroker[MCPEvent]()
)
 83
// McpTool wraps a single tool advertised by an MCP server so it can be used
// as a tools.BaseTool. Instances are created by getTools.
type McpTool struct {
	// mcpName is the name of the MCP server that owns the tool; it is part of
	// the exposed tool name and the key used to look up the client at run time.
	mcpName string
	// tool is the tool definition as returned by the server's tool listing.
	tool mcp.Tool
	// permissions is asked for approval before every execution.
	permissions permission.Service
	// workingDir is reported as the path of each permission request.
	workingDir string
}
 90
 91func (b *McpTool) Name() string {
 92	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 93}
 94
 95func (b *McpTool) Info() tools.ToolInfo {
 96	required := b.tool.InputSchema.Required
 97	if required == nil {
 98		required = make([]string, 0)
 99	}
100	return tools.ToolInfo{
101		Name:        fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
102		Description: b.tool.Description,
103		Parameters:  b.tool.InputSchema.Properties,
104		Required:    required,
105	}
106}
107
108func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
109	var args map[string]any
110	if err := json.Unmarshal([]byte(input), &args); err != nil {
111		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
112	}
113	c, ok := mcpClients.Get(name)
114	if !ok {
115		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
116	}
117	result, err := c.CallTool(ctx, mcp.CallToolRequest{
118		Params: mcp.CallToolParams{
119			Name:      toolName,
120			Arguments: args,
121		},
122	})
123	if err != nil {
124		return tools.NewTextErrorResponse(err.Error()), nil
125	}
126
127	output := make([]string, 0, len(result.Content))
128	for _, v := range result.Content {
129		if v, ok := v.(mcp.TextContent); ok {
130			output = append(output, v.Text)
131		} else {
132			output = append(output, fmt.Sprintf("%v", v))
133		}
134	}
135	return tools.NewTextResponse(strings.Join(output, "\n")), nil
136}
137
138func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
139	sessionID, messageID := tools.GetContextValues(ctx)
140	if sessionID == "" || messageID == "" {
141		return tools.ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
142	}
143	permissionDescription := fmt.Sprintf("execute %s with the following parameters: %s", b.Info().Name, params.Input)
144	p := b.permissions.Request(
145		permission.CreatePermissionRequest{
146			SessionID:   sessionID,
147			ToolCallID:  params.ID,
148			Path:        b.workingDir,
149			ToolName:    b.Info().Name,
150			Action:      "execute",
151			Description: permissionDescription,
152			Params:      params.Input,
153		},
154	)
155	if !p {
156		return tools.ToolResponse{}, permission.ErrorPermissionDenied
157	}
158
159	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
160}
161
162func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
163	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
164	if err != nil {
165		slog.Error("error listing tools", "error", err)
166		updateMCPState(name, MCPStateError, err, nil, 0)
167		c.Close()
168		mcpClients.Del(name)
169		return nil
170	}
171	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
172	for _, tool := range result.Tools {
173		mcpTools = append(mcpTools, &McpTool{
174			mcpName:     name,
175			tool:        tool,
176			permissions: permissions,
177			workingDir:  workingDir,
178		})
179	}
180	return mcpTools
181}
182
// SubscribeMCPEvents subscribes to the MCP event broker with ctx and returns
// the resulting receive-only channel. The events delivered on it are the
// state changes published by updateMCPState.
func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] {
	return mcpBroker.Subscribe(ctx)
}
187
188// GetMCPStates returns the current state of all MCP clients
189func GetMCPStates() map[string]MCPClientInfo {
190	states := make(map[string]MCPClientInfo)
191	for name, info := range mcpStates.Seq2() {
192		states[name] = info
193	}
194	return states
195}
196
// GetMCPState looks up the latest recorded state of the MCP client called
// name. The boolean result is the lookup's found flag from mcpStates.
func GetMCPState(name string) (MCPClientInfo, bool) {
	return mcpStates.Get(name)
}
201
202// updateMCPState updates the state of an MCP client and publishes an event
203func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
204	info := MCPClientInfo{
205		Name:      name,
206		State:     state,
207		Error:     err,
208		Client:    client,
209		ToolCount: toolCount,
210	}
211	if state == MCPStateConnected {
212		info.ConnectedAt = time.Now()
213	}
214	mcpStates.Set(name, info)
215
216	// Publish state change event
217	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
218		Type:      MCPEventStateChanged,
219		Name:      name,
220		State:     state,
221		Error:     err,
222		ToolCount: toolCount,
223	})
224}
225
226// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
227func CloseMCPClients() {
228	for c := range mcpClients.Seq() {
229		_ = c.Close()
230	}
231	mcpBroker.Shutdown()
232}
233
// mcpInitRequest is the initialize request sent to every MCP server by
// doGetMCPTools. It asks for the latest protocol version known to the mcp
// package and identifies this client as "Crush" at the build's version.
var mcpInitRequest = mcp.InitializeRequest{
	Params: mcp.InitializeParams{
		ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
		ClientInfo: mcp.Implementation{
			Name:    "Crush",
			Version: version.Version,
		},
	},
}
243
244func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
245	var wg sync.WaitGroup
246	result := csync.NewSlice[tools.BaseTool]()
247
248	// Initialize states for all configured MCPs
249	for name, m := range cfg.MCP {
250		if m.Disabled {
251			updateMCPState(name, MCPStateDisabled, nil, nil, 0)
252			slog.Debug("skipping disabled mcp", "name", name)
253			continue
254		}
255
256		// Set initial starting state
257		updateMCPState(name, MCPStateStarting, nil, nil, 0)
258
259		wg.Add(1)
260		go func(name string, m config.MCPConfig) {
261			defer func() {
262				wg.Done()
263				if r := recover(); r != nil {
264					var err error
265					switch v := r.(type) {
266					case error:
267						err = v
268					case string:
269						err = fmt.Errorf("panic: %s", v)
270					default:
271						err = fmt.Errorf("panic: %v", v)
272					}
273					updateMCPState(name, MCPStateError, err, nil, 0)
274					slog.Error("panic in mcp client initialization", "error", err, "name", name)
275				}
276			}()
277
278			timeout := time.Duration(cmp.Or(m.Timeout, 15)) * time.Second
279			ctx, cancel := context.WithTimeout(ctx, timeout)
280			defer cancel()
281			c, err := createMcpClient(m)
282			if err != nil {
283				updateMCPState(name, MCPStateError, err, nil, 0)
284				slog.Error("error creating mcp client", "error", err, "name", name)
285				return
286			}
287			// Only call Start() for non-stdio clients, as stdio clients auto-start
288			if m.Type != config.MCPStdio {
289				if err := c.Start(ctx); err != nil {
290					updateMCPState(name, MCPStateError, err, nil, 0)
291					slog.Error("error starting mcp client", "error", err, "name", name)
292					_ = c.Close()
293					return
294				}
295			}
296			if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
297				updateMCPState(name, MCPStateError, err, nil, 0)
298				slog.Error("error initializing mcp client", "error", err, "name", name)
299				_ = c.Close()
300				return
301			}
302
303			slog.Info("Initialized mcp client", "name", name)
304			mcpClients.Set(name, c)
305
306			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
307			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
308			result.Append(tools...)
309		}(name, m)
310	}
311	wg.Wait()
312	return slices.Collect(result.Seq())
313}
314
315func createMcpClient(m config.MCPConfig) (*client.Client, error) {
316	switch m.Type {
317	case config.MCPStdio:
318		return client.NewStdioMCPClientWithOptions(
319			m.Command,
320			m.ResolvedEnv(),
321			m.Args,
322			transport.WithCommandLogger(mcpLogger{}),
323		)
324	case config.MCPHttp:
325		return client.NewStreamableHttpClient(
326			m.URL,
327			transport.WithHTTPHeaders(m.ResolvedHeaders()),
328			transport.WithHTTPLogger(mcpLogger{}),
329		)
330	case config.MCPSse:
331		return client.NewSSEMCPClient(
332			m.URL,
333			client.WithHeaders(m.ResolvedHeaders()),
334			transport.WithSSELogger(mcpLogger{}),
335		)
336	default:
337		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
338	}
339}
340
// mcpLogger is the logger handed to the MCP client transports; it forwards
// their printf-style messages to the default slog logger.
type mcpLogger struct{}

// Errorf formats the message and logs it at error level.
func (mcpLogger) Errorf(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Error(msg)
}

// Infof formats the message and logs it at info level.
func (mcpLogger) Infof(format string, v ...any) {
	msg := fmt.Sprintf(format, v...)
	slog.Info(msg)
}