From 4ca34bbcdd91a906069999a83f6c3ee0fcff2fa6 Mon Sep 17 00:00:00 2001 From: Carlos Alexandro Becker Date: Mon, 3 Nov 2025 11:02:48 -0300 Subject: [PATCH] feat(mcp): refactor, support prompts Signed-off-by: Carlos Alexandro Becker --- internal/agent/coordinator.go | 15 +- internal/agent/tools/mcp-tools.go | 525 ------------------ internal/agent/tools/mcp/init.go | 400 +++++++++++++ internal/agent/tools/mcp/prompts.go | 83 +++ internal/agent/tools/mcp/tools.go | 169 ++++++ internal/app/app.go | 11 +- .../components/dialogs/commands/arguments.go | 162 +++--- .../components/dialogs/commands/commands.go | 104 +++- .../tui/components/dialogs/commands/loader.go | 101 +++- internal/tui/components/mcp/mcp.go | 24 +- internal/tui/tui.go | 35 +- 11 files changed, 972 insertions(+), 657 deletions(-) delete mode 100644 internal/agent/tools/mcp-tools.go create mode 100644 internal/agent/tools/mcp/init.go create mode 100644 internal/agent/tools/mcp/prompts.go create mode 100644 internal/agent/tools/mcp/tools.go diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 3fd57c69d2b90ecf1477eccda399cbfceb3c0ef6..64d2c83a1bb68b32cd695a6ed411d25b746c12e6 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -18,6 +18,7 @@ import ( "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/agent/prompt" "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/history" @@ -344,25 +345,23 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan } } - mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg) - - for _, mcpTool := range mcpTools { + for tool := range mcp.GetMCPTools() { if agent.AllowedMCP == nil { // No MCP restrictions - filteredTools = append(filteredTools, mcpTool) + filteredTools = append(filteredTools, tool) } else if len(agent.AllowedMCP) == 0 { // no mcps allowed break } for mcp, tools := range agent.AllowedMCP { - if mcp == mcpTool.MCP() { + if mcp == tool.MCP() { if len(tools) == 0 { - filteredTools = append(filteredTools, mcpTool) + filteredTools = append(filteredTools, tool) } for _, t := range tools { - if t == mcpTool.MCPToolName() { - filteredTools = append(filteredTools, mcpTool) + if t == tool.MCPToolName() { + filteredTools = append(filteredTools, tool) } } break diff --git a/internal/agent/tools/mcp-tools.go b/internal/agent/tools/mcp-tools.go deleted file mode 100644 index 79006de46a316a4aa7293dd6122afeb567c2d524..0000000000000000000000000000000000000000 --- a/internal/agent/tools/mcp-tools.go +++ /dev/null @@ -1,525 +0,0 @@ -package tools - -import ( - "cmp" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "maps" - "net/http" - "os" - "os/exec" - "slices" - "strings" - "sync" - "time" - - "charm.land/fantasy" - "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/home" - "github.com/charmbracelet/crush/internal/permission" - "github.com/charmbracelet/crush/internal/pubsub" - "github.com/charmbracelet/crush/internal/version" - "github.com/modelcontextprotocol/go-sdk/mcp" -) - -// MCPState represents the current state of an MCP client -type MCPState int - -const ( - MCPStateDisabled MCPState = iota - MCPStateStarting - MCPStateConnected - MCPStateError -) - -func (s MCPState) String() string { - switch s { - case MCPStateDisabled: - return "disabled" - case MCPStateStarting: - return "starting" - case MCPStateConnected: - return "connected" - case MCPStateError: - return "error" - default: - return "unknown" - } -} - -// MCPEventType represents the type of MCP event -type MCPEventType string - -const ( - MCPEventStateChanged MCPEventType = "state_changed" - MCPEventToolsListChanged MCPEventType = "tools_list_changed" -) - -// MCPEvent represents an event in the MCP system -type MCPEvent struct { - Type MCPEventType - Name string - State MCPState - Error error - ToolCount int -} - -// MCPClientInfo holds information about an MCP client's state -type MCPClientInfo struct { - Name string - State MCPState - Error error - Client *mcp.ClientSession - ToolCount int - ConnectedAt time.Time -} - -var ( - mcpToolsOnce sync.Once - mcpTools = csync.NewMap[string, *McpTool]() - mcpClient2Tools = csync.NewMap[string, []*McpTool]() - mcpClients = csync.NewMap[string, *mcp.ClientSession]() - mcpStates = csync.NewMap[string, MCPClientInfo]() - mcpBroker = pubsub.NewBroker[MCPEvent]() -) - -type McpTool struct { - mcpName string - tool *mcp.Tool - permissions permission.Service - workingDir string - providerOptions fantasy.ProviderOptions -} - -func (m *McpTool) SetProviderOptions(opts fantasy.ProviderOptions) { - m.providerOptions = opts -} - -func (m *McpTool) ProviderOptions() fantasy.ProviderOptions { - return m.providerOptions -} - -func (m *McpTool) Name() string { - return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name) -} - -func (m *McpTool) MCP() string { - return m.mcpName -} - -func (m *McpTool) MCPToolName() string { - return m.tool.Name -} - -func (b *McpTool) Info() fantasy.ToolInfo { - parameters := make(map[string]any) - required := make([]string, 0) - - if input, ok := b.tool.InputSchema.(map[string]any); ok { - if props, ok := input["properties"].(map[string]any); ok { - parameters = props - } - if req, ok := input["required"].([]any); ok { - // Convert []any -> []string when elements are strings - for _, v := range req { - if s, ok := v.(string); ok { - required = append(required, s) - } - } - } else if reqStr, ok := input["required"].([]string); ok { - // Handle case where it's already []string - required = reqStr - } - } - - return fantasy.ToolInfo{ - Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name), - Description: b.tool.Description, - Parameters: parameters, - Required: required, - } -} - -func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) { - var args map[string]any - if err := json.Unmarshal([]byte(input), &args); err != nil { - return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil - } - - c, err := getOrRenewClient(ctx, name) - if err != nil { - return fantasy.NewTextErrorResponse(err.Error()), nil - } - result, err := c.CallTool(ctx, &mcp.CallToolParams{ - Name: toolName, - Arguments: args, - }) - if err != nil { - return fantasy.NewTextErrorResponse(err.Error()), nil - } - - output := make([]string, 0, len(result.Content)) - for _, v := range result.Content { - if vv, ok := v.(*mcp.TextContent); ok { - output = append(output, vv.Text) - } else { - output = append(output, fmt.Sprintf("%v", v)) - } - } - return fantasy.NewTextResponse(strings.Join(output, "\n")), nil -} - -func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) { - sess, ok := mcpClients.Get(name) - if !ok { - return nil, fmt.Errorf("mcp '%s' not available", name) - } - - cfg := config.Get() - m := cfg.MCP[name] - state, _ := mcpStates.Get(name) - - timeout := mcpTimeout(m) - pingCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - err := sess.Ping(pingCtx, nil) - if err == nil { - return sess, nil - } - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount) - - sess, err = createMCPSession(ctx, name, m, cfg.Resolver()) - if err != nil { - return nil, err - } - - updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount) - mcpClients.Set(name, sess) - return sess, nil -} - -func (m *McpTool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) { - sessionID := GetSessionFromContext(ctx) - if sessionID == "" { - return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") - } - permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name) - p := m.permissions.Request( - permission.CreatePermissionRequest{ - SessionID: sessionID, - ToolCallID: params.ID, - Path: m.workingDir, - ToolName: m.Info().Name, - Action: "execute", - Description: permissionDescription, - Params: params.Input, - }, - ) - if !p { - return fantasy.ToolResponse{}, permission.ErrorPermissionDenied - } - - return runTool(ctx, m.mcpName, m.tool.Name, params.Input) -} - -func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*McpTool, error) { - result, err := c.ListTools(ctx, &mcp.ListToolsParams{}) - if err != nil { - return nil, err - } - mcpTools := make([]*McpTool, 0, len(result.Tools)) - for _, tool := range result.Tools { - mcpTools = append(mcpTools, &McpTool{ - mcpName: name, - tool: tool, - permissions: permissions, - workingDir: workingDir, - }) - } - return mcpTools, nil -} - -// SubscribeMCPEvents returns a channel for MCP events -func SubscribeMCPEvents(ctx context.Context) <-chan pubsub.Event[MCPEvent] { - return mcpBroker.Subscribe(ctx) -} - -// GetMCPStates returns the current state of all MCP clients -func GetMCPStates() map[string]MCPClientInfo { - return maps.Collect(mcpStates.Seq2()) -} - -// GetMCPState returns the state of a specific MCP client -func GetMCPState(name string) (MCPClientInfo, bool) { - return mcpStates.Get(name) -} - -// updateMCPState updates the state of an MCP client and publishes an event -func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) { - info := MCPClientInfo{ - Name: name, - State: state, - Error: err, - Client: client, - ToolCount: toolCount, - } - switch state { - case MCPStateConnected: - info.ConnectedAt = time.Now() - case MCPStateError: - updateMcpTools(name, nil) - mcpClients.Del(name) - } - mcpStates.Set(name, info) - - // Publish state change event - mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ - Type: MCPEventStateChanged, - Name: name, - State: state, - Error: err, - ToolCount: toolCount, - }) -} - -// CloseMCPClients closes all MCP clients. This should be called during application shutdown. -func CloseMCPClients() error { - var errs []error - for name, c := range mcpClients.Seq2() { - if err := c.Close(); err != nil && - !errors.Is(err, io.EOF) && - !errors.Is(err, context.Canceled) && - err.Error() != "signal: killed" { - errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err)) - } - } - mcpBroker.Shutdown() - return errors.Join(errs...) -} - -func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []*McpTool { - mcpToolsOnce.Do(func() { - var wg sync.WaitGroup - // Initialize states for all configured MCPs - for name, m := range cfg.MCP { - if m.Disabled { - updateMCPState(name, MCPStateDisabled, nil, nil, 0) - slog.Debug("skipping disabled mcp", "name", name) - continue - } - - // Set initial starting state - updateMCPState(name, MCPStateStarting, nil, nil, 0) - - wg.Add(1) - go func(name string, m config.MCPConfig) { - defer func() { - wg.Done() - if r := recover(); r != nil { - var err error - switch v := r.(type) { - case error: - err = v - case string: - err = fmt.Errorf("panic: %s", v) - default: - err = fmt.Errorf("panic: %v", v) - } - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("panic in mcp client initialization", "error", err, "name", name) - } - }() - - ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) - defer cancel() - - c, err := createMCPSession(ctx, name, m, cfg.Resolver()) - if err != nil { - return - } - - mcpClients.Set(name, c) - - tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir()) - if err != nil { - slog.Error("error listing tools", "error", err) - updateMCPState(name, MCPStateError, err, nil, 0) - c.Close() - return - } - - updateMcpTools(name, tools) - mcpClients.Set(name, c) - updateMCPState(name, MCPStateConnected, nil, c, len(tools)) - }(name, m) - } - wg.Wait() - }) - return slices.Collect(mcpTools.Seq()) -} - -// updateMcpTools updates the global mcpTools and mcpClient2Tools maps -func updateMcpTools(mcpName string, tools []*McpTool) { - if len(tools) == 0 { - mcpClient2Tools.Del(mcpName) - } else { - mcpClient2Tools.Set(mcpName, tools) - } - for _, tools := range mcpClient2Tools.Seq2() { - for _, t := range tools { - mcpTools.Set(t.Name(), t) - } - } -} - -func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) { - timeout := mcpTimeout(m) - mcpCtx, cancel := context.WithCancel(ctx) - cancelTimer := time.AfterFunc(timeout, cancel) - - transport, err := createMCPTransport(mcpCtx, m, resolver) - if err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("error creating mcp client", "error", err, "name", name) - cancel() - cancelTimer.Stop() - return nil, err - } - - client := mcp.NewClient( - &mcp.Implementation{ - Name: "crush", - Version: version.Version, - Title: "Crush", - }, - &mcp.ClientOptions{ - ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) { - mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ - Type: MCPEventToolsListChanged, - Name: name, - }) - }, - KeepAlive: time.Minute * 10, - }, - ) - - session, err := client.Connect(mcpCtx, transport, nil) - if err != nil { - err = maybeStdioErr(err, transport) - updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) - slog.Error("error starting mcp client", "error", err, "name", name) - cancel() - cancelTimer.Stop() - return nil, err - } - - cancelTimer.Stop() - slog.Info("Initialized mcp client", "name", name) - return session, nil -} - -// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail -// to parse, and the cli will then close it, causing the EOF error. -// so, if we got an EOF err, and the transport is STDIO, we try to exec it -// again with a timeout and collect the output so we can add details to the -// error. -// this happens particularly when starting things with npx, e.g. if node can't -// be found or some other error like that. -func maybeStdioErr(err error, transport mcp.Transport) error { - if !errors.Is(err, io.EOF) { - return err - } - ct, ok := transport.(*mcp.CommandTransport) - if !ok { - return err - } - if err2 := stdioMCPCheck(ct.Command); err2 != nil { - err = errors.Join(err, err2) - } - return err -} - -func maybeTimeoutErr(err error, timeout time.Duration) error { - if errors.Is(err, context.Canceled) { - return fmt.Errorf("timed out after %s", timeout) - } - return err -} - -func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) { - switch m.Type { - case config.MCPStdio: - command, err := resolver.ResolveValue(m.Command) - if err != nil { - return nil, fmt.Errorf("invalid mcp command: %w", err) - } - if strings.TrimSpace(command) == "" { - return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field") - } - cmd := exec.CommandContext(ctx, home.Long(command), m.Args...) - cmd.Env = append(os.Environ(), m.ResolvedEnv()...) - return &mcp.CommandTransport{ - Command: cmd, - }, nil - case config.MCPHttp: - if strings.TrimSpace(m.URL) == "" { - return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field") - } - client := &http.Client{ - Transport: &headerRoundTripper{ - headers: m.ResolvedHeaders(), - }, - } - return &mcp.StreamableClientTransport{ - Endpoint: m.URL, - HTTPClient: client, - }, nil - case config.MCPSSE: - if strings.TrimSpace(m.URL) == "" { - return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field") - } - client := &http.Client{ - Transport: &headerRoundTripper{ - headers: m.ResolvedHeaders(), - }, - } - return &mcp.SSEClientTransport{ - Endpoint: m.URL, - HTTPClient: client, - }, nil - default: - return nil, fmt.Errorf("unsupported mcp type: %s", m.Type) - } -} - -type headerRoundTripper struct { - headers map[string]string -} - -func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - for k, v := range rt.headers { - req.Header.Set(k, v) - } - return http.DefaultTransport.RoundTrip(req) -} - -func mcpTimeout(m config.MCPConfig) time.Duration { - return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second -} - -func stdioMCPCheck(old *exec.Cmd) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - cmd := exec.CommandContext(ctx, old.Path, old.Args...) - cmd.Env = old.Env - out, err := cmd.CombinedOutput() - if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) { - return nil - } - return fmt.Errorf("%w: %s", err, string(out)) -} diff --git a/internal/agent/tools/mcp/init.go b/internal/agent/tools/mcp/init.go new file mode 100644 index 0000000000000000000000000000000000000000..600c47517fc66d22b416c5362cb6019a81e59247 --- /dev/null +++ b/internal/agent/tools/mcp/init.go @@ -0,0 +1,400 @@ +package mcp + +import ( + "cmp" + "context" + "errors" + "fmt" + "io" + "log/slog" + "maps" + "net/http" + "os" + "os/exec" + "strings" + "sync" + "time" + + "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/pubsub" + "github.com/charmbracelet/crush/internal/version" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + sessions = csync.NewMap[string, *mcp.ClientSession]() + states = csync.NewMap[string, ClientInfo]() + broker = pubsub.NewBroker[Event]() +) + +// State represents the current state of an MCP client +type State int + +const ( + StateDisabled State = iota + StateStarting + StateConnected + StateError +) + +func (s State) String() string { + switch s { + case StateDisabled: + return "disabled" + case StateStarting: + return "starting" + case StateConnected: + return "connected" + case StateError: + return "error" + default: + return "unknown" + } +} + +// EventType represents the type of MCP event +type EventType string + +const ( + EventStateChanged EventType = "state_changed" + EventToolsListChanged EventType = "tools_list_changed" + EventPromptsListChanged EventType = "prompts_list_changed" +) + +// Event represents an event in the MCP system +type Event struct { + Type EventType + Name string + State State + Error error + Counts Counts +} + +// Counts number of available tools, prompts, etc. +type Counts struct { + Tools int + Prompts int +} + +// ClientInfo holds information about an MCP client's state +type ClientInfo struct { + Name string + State State + Error error + Client *mcp.ClientSession + Counts Counts + ConnectedAt time.Time +} + +// SubscribeEvents returns a channel for MCP events +func SubscribeEvents(ctx context.Context) <-chan pubsub.Event[Event] { + return broker.Subscribe(ctx) +} + +// GetStates returns the current state of all MCP clients +func GetStates() map[string]ClientInfo { + return maps.Collect(states.Seq2()) +} + +// GetState returns the state of a specific MCP client +func GetState(name string) (ClientInfo, bool) { + return states.Get(name) +} + +// Close closes all MCP clients. This should be called during application shutdown. +func Close() error { + var errs []error + for name, c := range sessions.Seq2() { + if err := c.Close(); err != nil && + !errors.Is(err, io.EOF) && + !errors.Is(err, context.Canceled) && + err.Error() != "signal: killed" { + errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err)) + } + } + broker.Shutdown() + return errors.Join(errs...) +} + +// Initialize initializes MCP clients based on the provided configuration. +func Initialize(ctx context.Context, permissions permission.Service, cfg *config.Config) { + var wg sync.WaitGroup + // Initialize states for all configured MCPs + for name, m := range cfg.MCP { + if m.Disabled { + updateState(name, StateDisabled, nil, nil, Counts{}) + slog.Debug("skipping disabled mcp", "name", name) + continue + } + + // Set initial starting state + updateState(name, StateStarting, nil, nil, Counts{}) + + wg.Add(1) + go func(name string, m config.MCPConfig) { + defer func() { + wg.Done() + if r := recover(); r != nil { + var err error + switch v := r.(type) { + case error: + err = v + case string: + err = fmt.Errorf("panic: %s", v) + default: + err = fmt.Errorf("panic: %v", v) + } + updateState(name, StateError, err, nil, Counts{}) + slog.Error("panic in mcp client initialization", "error", err, "name", name) + } + }() + + ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) + defer cancel() + + session, err := createSession(ctx, name, m, cfg.Resolver()) + if err != nil { + return + } + + tools, err := getTools(ctx, name, permissions, session, cfg.WorkingDir()) + if err != nil { + slog.Error("error listing tools", "error", err) + updateState(name, StateError, err, nil, Counts{}) + session.Close() + return + } + + prompts, err := getPrompts(ctx, session) + if err != nil { + slog.Error("error listing prompts", "error", err) + updateState(name, StateError, err, nil, Counts{}) + session.Close() + return + } + + updateTools(name, tools) + updatePrompts(name, prompts) + sessions.Set(name, session) + + updateState(name, StateConnected, nil, session, Counts{ + Tools: len(tools), + Prompts: len(prompts), + }) + }(name, m) + } + wg.Wait() +} + +func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) { + sess, ok := sessions.Get(name) + if !ok { + return nil, fmt.Errorf("mcp '%s' not available", name) + } + + cfg := config.Get() + m := cfg.MCP[name] + state, _ := states.Get(name) + + timeout := mcpTimeout(m) + pingCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + err := sess.Ping(pingCtx, nil) + if err == nil { + return sess, nil + } + updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, state.Counts) + + sess, err = createSession(ctx, name, m, cfg.Resolver()) + if err != nil { + return nil, err + } + + updateState(name, StateConnected, nil, sess, state.Counts) + sessions.Set(name, sess) + return sess, nil +} + +// updateState updates the state of an MCP client and publishes an event +func updateState(name string, state State, err error, client *mcp.ClientSession, counts Counts) { + info := ClientInfo{ + Name: name, + State: state, + Error: err, + Client: client, + Counts: counts, + } + switch state { + case StateConnected: + info.ConnectedAt = time.Now() + case StateError: + updateTools(name, nil) + sessions.Del(name) + } + states.Set(name, info) + + // Publish state change event + broker.Publish(pubsub.UpdatedEvent, Event{ + Type: EventStateChanged, + Name: name, + State: state, + Error: err, + Counts: counts, + }) +} + +func createSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) { + timeout := mcpTimeout(m) + mcpCtx, cancel := context.WithCancel(ctx) + cancelTimer := time.AfterFunc(timeout, cancel) + + transport, err := createTransport(mcpCtx, m, resolver) + if err != nil { + updateState(name, StateError, err, nil, Counts{}) + slog.Error("error creating mcp client", "error", err, "name", name) + cancel() + cancelTimer.Stop() + return nil, err + } + + client := mcp.NewClient( + &mcp.Implementation{ + Name: "crush", + Version: version.Version, + Title: "Crush", + }, + &mcp.ClientOptions{ + ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) { + broker.Publish(pubsub.UpdatedEvent, Event{ + Type: EventToolsListChanged, + Name: name, + }) + }, + PromptListChangedHandler: func(context.Context, *mcp.PromptListChangedRequest) { + broker.Publish(pubsub.UpdatedEvent, Event{ + Type: EventPromptsListChanged, + Name: name, + }) + }, + KeepAlive: time.Minute * 10, + }, + ) + + session, err := client.Connect(mcpCtx, transport, nil) + if err != nil { + err = maybeStdioErr(err, transport) + updateState(name, StateError, maybeTimeoutErr(err, timeout), nil, Counts{}) + slog.Error("error starting mcp client", "error", err, "name", name) + cancel() + cancelTimer.Stop() + return nil, err + } + + cancelTimer.Stop() + slog.Info("Initialized mcp client", "name", name) + return session, nil +} + +// maybeStdioErr if a stdio mcp prints an error in non-json format, it'll fail +// to parse, and the cli will then close it, causing the EOF error. +// so, if we got an EOF err, and the transport is STDIO, we try to exec it +// again with a timeout and collect the output so we can add details to the +// error. +// this happens particularly when starting things with npx, e.g. if node can't +// be found or some other error like that. +func maybeStdioErr(err error, transport mcp.Transport) error { + if !errors.Is(err, io.EOF) { + return err + } + ct, ok := transport.(*mcp.CommandTransport) + if !ok { + return err + } + if err2 := stdioCheck(ct.Command); err2 != nil { + err = errors.Join(err, err2) + } + return err +} + +func maybeTimeoutErr(err error, timeout time.Duration) error { + if errors.Is(err, context.Canceled) { + return fmt.Errorf("timed out after %s", timeout) + } + return err +} + +func createTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) { + switch m.Type { + case config.MCPStdio: + command, err := resolver.ResolveValue(m.Command) + if err != nil { + return nil, fmt.Errorf("invalid mcp command: %w", err) + } + if strings.TrimSpace(command) == "" { + return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field") + } + cmd := exec.CommandContext(ctx, home.Long(command), m.Args...) + cmd.Env = append(os.Environ(), m.ResolvedEnv()...) + return &mcp.CommandTransport{ + Command: cmd, + }, nil + case config.MCPHttp: + if strings.TrimSpace(m.URL) == "" { + return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field") + } + client := &http.Client{ + Transport: &headerRoundTripper{ + headers: m.ResolvedHeaders(), + }, + } + return &mcp.StreamableClientTransport{ + Endpoint: m.URL, + HTTPClient: client, + }, nil + case config.MCPSSE: + if strings.TrimSpace(m.URL) == "" { + return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field") + } + client := &http.Client{ + Transport: &headerRoundTripper{ + headers: m.ResolvedHeaders(), + }, + } + return &mcp.SSEClientTransport{ + Endpoint: m.URL, + HTTPClient: client, + }, nil + default: + return nil, fmt.Errorf("unsupported mcp type: %s", m.Type) + } +} + +type headerRoundTripper struct { + headers map[string]string +} + +func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + for k, v := range rt.headers { + req.Header.Set(k, v) + } + return http.DefaultTransport.RoundTrip(req) +} + +func mcpTimeout(m config.MCPConfig) time.Duration { + return time.Duration(cmp.Or(m.Timeout, 15)) * time.Second +} + +func stdioCheck(old *exec.Cmd) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + cmd := exec.CommandContext(ctx, old.Path, old.Args...) + cmd.Env = old.Env + out, err := cmd.CombinedOutput() + if err == nil || errors.Is(ctx.Err(), context.DeadlineExceeded) { + return nil + } + return fmt.Errorf("%w: %s", err, string(out)) +} diff --git a/internal/agent/tools/mcp/prompts.go b/internal/agent/tools/mcp/prompts.go new file mode 100644 index 0000000000000000000000000000000000000000..4439a6d215efb8d8645dcc0e197a05c822672829 --- /dev/null +++ b/internal/agent/tools/mcp/prompts.go @@ -0,0 +1,83 @@ +package mcp + +import ( + "context" + "iter" + + "github.com/charmbracelet/crush/internal/csync" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type Prompt = mcp.Prompt + +var ( + allPrompts = csync.NewMap[string, *Prompt]() + client2Prompts = csync.NewMap[string, []*Prompt]() +) + +// GetPrompts returns all available MCP prompts. +func GetPrompts() iter.Seq2[string, *Prompt] { + return allPrompts.Seq2() +} + +// GetPrompt returns a specific MCP prompt by name. +func GetPrompt(name string) (*Prompt, bool) { + return allPrompts.Get(name) +} + +// GetPromptsByClient returns all prompts for a specific MCP client. +func GetPromptsByClient(clientName string) ([]*Prompt, bool) { + return client2Prompts.Get(clientName) +} + +// GetPromptMessages retrieves the content of an MCP prompt with the given arguments. +func GetPromptMessages(ctx context.Context, clientName, promptName string, args map[string]string) ([]string, error) { + c, err := getOrRenewClient(ctx, clientName) + if err != nil { + return nil, err + } + result, err := c.GetPrompt(ctx, &mcp.GetPromptParams{ + Name: promptName, + Arguments: args, + }) + if err != nil { + return nil, err + } + + var messages []string + for _, msg := range result.Messages { + if msg.Role != "user" { + continue + } + if textContent, ok := msg.Content.(*mcp.TextContent); ok { + messages = append(messages, textContent.Text) + } + } + return messages, nil +} + +func getPrompts(ctx context.Context, c *mcp.ClientSession) ([]*Prompt, error) { + if c.InitializeResult().Capabilities.Prompts == nil { + return nil, nil + } + result, err := c.ListPrompts(ctx, &mcp.ListPromptsParams{}) + if err != nil { + return nil, err + } + return result.Prompts, nil +} + +// updatePrompts updates the global mcpPrompts and mcpClient2Prompts maps +func updatePrompts(mcpName string, prompts []*Prompt) { + if len(prompts) == 0 { + client2Prompts.Del(mcpName) + } else { + client2Prompts.Set(mcpName, prompts) + } + for mcpName, prompts := range client2Prompts.Seq2() { + for _, p := range prompts { + key := mcpName + ":" + p.Name + allPrompts.Set(key, p) + } + } +} diff --git a/internal/agent/tools/mcp/tools.go b/internal/agent/tools/mcp/tools.go new file mode 100644 index 0000000000000000000000000000000000000000..5b30514dcc31c69bbce4fee41637fa5505c8e684 --- /dev/null +++ b/internal/agent/tools/mcp/tools.go @@ -0,0 +1,169 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "iter" + "strings" + + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/permission" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +var ( + allTools = csync.NewMap[string, *Tool]() + client2Tools = csync.NewMap[string, []*Tool]() +) + +// GetMCPTools returns all available MCP tools. +func GetMCPTools() iter.Seq[*Tool] { + return allTools.Seq() +} + +type Tool struct { + mcpName string + tool *mcp.Tool + permissions permission.Service + workingDir string + providerOptions fantasy.ProviderOptions +} + +func (m *Tool) SetProviderOptions(opts fantasy.ProviderOptions) { + m.providerOptions = opts +} + +func (m *Tool) ProviderOptions() fantasy.ProviderOptions { + return m.providerOptions +} + +func (m *Tool) Name() string { + return fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name) +} + +func (m *Tool) MCP() string { + return m.mcpName +} + +func (m *Tool) MCPToolName() string { + return m.tool.Name +} + +func (m *Tool) Info() fantasy.ToolInfo { + parameters := make(map[string]any) + required := make([]string, 0) + + if input, ok := m.tool.InputSchema.(map[string]any); ok { + if props, ok := input["properties"].(map[string]any); ok { + parameters = props + } + if req, ok := input["required"].([]any); ok { + // Convert []any -> []string when elements are strings + for _, v := range req { + if s, ok := v.(string); ok { + required = append(required, s) + } + } + } else if reqStr, ok := input["required"].([]string); ok { + // Handle case where it's already []string + required = reqStr + } + } + + return fantasy.ToolInfo{ + Name: fmt.Sprintf("mcp_%s_%s", m.mcpName, m.tool.Name), + Description: m.tool.Description, + Parameters: parameters, + Required: required, + } +} + +func (m *Tool) Run(ctx context.Context, params fantasy.ToolCall) (fantasy.ToolResponse, error) { + sessionID := tools.GetSessionFromContext(ctx) + if sessionID == "" { + return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") + } + permissionDescription := fmt.Sprintf("execute %s with the following parameters:", m.Info().Name) + p := m.permissions.Request( + permission.CreatePermissionRequest{ + SessionID: sessionID, + ToolCallID: params.ID, + Path: m.workingDir, + ToolName: m.Info().Name, + Action: "execute", + Description: permissionDescription, + Params: params.Input, + }, + ) + if !p { + return fantasy.ToolResponse{}, permission.ErrorPermissionDenied + } + + return runTool(ctx, m.mcpName, m.tool.Name, params.Input) +} + +func runTool(ctx context.Context, name, toolName string, input string) (fantasy.ToolResponse, error) { + var args map[string]any + if err := json.Unmarshal([]byte(input), &args); err != nil { + return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil + } + + c, err := getOrRenewClient(ctx, name) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + result, err := c.CallTool(ctx, &mcp.CallToolParams{ + Name: toolName, + Arguments: args, + }) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + + output := make([]string, 0, len(result.Content)) + for _, v := range result.Content { + if vv, ok := v.(*mcp.TextContent); ok { + output = append(output, vv.Text) + } else { + output = append(output, fmt.Sprintf("%v", v)) + } + } + return fantasy.NewTextResponse(strings.Join(output, "\n")), nil +} + +func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]*Tool, error) { + if c.InitializeResult().Capabilities.Tools == nil { + return nil, nil + } + result, err := c.ListTools(ctx, &mcp.ListToolsParams{}) + if err != nil { + return nil, err + } + mcpTools := make([]*Tool, 0, len(result.Tools)) + for _, tool := range result.Tools { + mcpTools = append(mcpTools, &Tool{ + mcpName: name, + tool: tool, + permissions: permissions, + workingDir: workingDir, + }) + } + return mcpTools, nil +} + +// updateTools updates the global mcpTools and mcpClient2Tools maps +func updateTools(mcpName string, tools []*Tool) { + if len(tools) == 0 { + client2Tools.Del(mcpName) + } else { + client2Tools.Set(mcpName, tools) + } + for _, tools := range client2Tools.Seq2() { + for _, t := range tools { + allTools.Set(t.Name(), t) + } + } +} diff --git a/internal/app/app.go b/internal/app/app.go index 87e0a80186a8f837f47ad3a3cffe14e3112a825f..fe0e2957dede4b410ab1db76e85c3bbc4bc2a49b 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -16,7 +16,7 @@ import ( "charm.land/fantasy" tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/agent" - "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" @@ -91,6 +91,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { // Initialize LSP clients in the background. app.initLSPClients(ctx) + go func() { + slog.Info("Initializing MCP clients") + mcp.Initialize(ctx, app.Permissions, cfg) + }() + // cleanup database upon app shutdown app.cleanupFuncs = append(app.cleanupFuncs, conn.Close) @@ -260,7 +265,7 @@ func (app *App) setupEvents() { setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events) setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events) - setupSubscriber(ctx, app.serviceEventsWG, "mcp", tools.SubscribeMCPEvents, app.events) + setupSubscriber(ctx, app.serviceEventsWG, "mcp", mcp.SubscribeEvents, app.events) setupSubscriber(ctx, app.serviceEventsWG, "lsp", SubscribeLSPEvents, app.events) cleanupFunc := func() error { cancel() @@ -324,7 +329,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error { } // Add MCP client cleanup to shutdown process - app.cleanupFuncs = append(app.cleanupFuncs, tools.CloseMCPClients) + app.cleanupFuncs = append(app.cleanupFuncs, mcp.Close) return nil } diff --git a/internal/tui/components/dialogs/commands/arguments.go b/internal/tui/components/dialogs/commands/arguments.go index 66ad3f7ba06ae41fa2a4d0e033906ceda5298c22..997c1c3056b68fed5451298e30b33cd610980a9d 100644 --- a/internal/tui/components/dialogs/commands/arguments.go +++ b/internal/tui/components/dialogs/commands/arguments.go @@ -1,8 +1,7 @@ package commands import ( - "fmt" - "strings" + "cmp" "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" @@ -20,9 +19,10 @@ const ( // ShowArgumentsDialogMsg is a message that is sent to show the arguments dialog. type ShowArgumentsDialogMsg struct { - CommandID string - Content string - ArgNames []string + CommandID string + Description string + ArgNames []string + OnSubmit func(args map[string]string) tea.Cmd } // CloseArgumentsDialogMsg is a message that is sent when the arguments dialog is closed. @@ -39,26 +39,39 @@ type CommandArgumentsDialog interface { } type commandArgumentsDialogCmp struct { - width int - wWidth int // Width of the terminal window - wHeight int // Height of the terminal window - - inputs []textinput.Model - focusIndex int - keys ArgumentsDialogKeyMap - commandID string - content string - argNames []string - help help.Model + wWidth, wHeight int + width, height int + + inputs []textinput.Model + focused int + keys ArgumentsDialogKeyMap + arguments []Argument + help help.Model + + id string + title string + name string + description string + + onSubmit func(args map[string]string) tea.Cmd +} + +type Argument struct { + Name, Title, Description string + Required bool } -func NewCommandArgumentsDialog(commandID, content string, argNames []string) CommandArgumentsDialog { +func NewCommandArgumentsDialog( + id, title, name, description string, + arguments []Argument, + onSubmit func(args map[string]string) tea.Cmd, +) CommandArgumentsDialog { t := styles.CurrentTheme() - inputs := make([]textinput.Model, len(argNames)) + inputs := make([]textinput.Model, len(arguments)) - for i, name := range argNames { + for i, arg := range arguments { ti := textinput.New() - ti.Placeholder = fmt.Sprintf("Enter value for %s...", name) + ti.Placeholder = cmp.Or(arg.Description, "Enter value for "+arg.Title) ti.SetWidth(40) ti.SetVirtualCursor(false) ti.Prompt = "" @@ -75,14 +88,16 @@ func NewCommandArgumentsDialog(commandID, content string, argNames []string) Com } return &commandArgumentsDialogCmp{ - inputs: inputs, - keys: DefaultArgumentsDialogKeyMap(), - commandID: commandID, - content: content, - argNames: argNames, - focusIndex: 0, - width: 60, - help: help.New(), + inputs: inputs, + keys: DefaultArgumentsDialogKeyMap(), + id: id, + name: name, + title: title, + description: description, + arguments: arguments, + width: 60, + help: help.New(), + onSubmit: onSubmit, } } @@ -97,47 +112,51 @@ func (c *commandArgumentsDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { case tea.WindowSizeMsg: c.wWidth = msg.Width c.wHeight = msg.Height + c.width = min(90, c.wWidth) + c.height = min(15, c.wHeight) + for i := range c.inputs { + c.inputs[i].SetWidth(c.width - (paddingHorizontal * 2)) + } case tea.KeyPressMsg: switch { + case key.Matches(msg, c.keys.Close): + return c, util.CmdHandler(dialogs.CloseDialogMsg{}) case key.Matches(msg, c.keys.Confirm): - if c.focusIndex == len(c.inputs)-1 { - content := c.content - for i, name := range c.argNames { + if c.focused == len(c.inputs)-1 { + args := make(map[string]string) + for i, arg := range c.arguments { value := c.inputs[i].Value() - placeholder := "$" + name - content = strings.ReplaceAll(content, placeholder, value) + args[arg.Name] = value } return c, tea.Sequence( util.CmdHandler(dialogs.CloseDialogMsg{}), - util.CmdHandler(CommandRunCustomMsg{ - Content: content, - }), + c.onSubmit(args), ) } // Otherwise, move to the next input - c.inputs[c.focusIndex].Blur() - c.focusIndex++ - c.inputs[c.focusIndex].Focus() + c.inputs[c.focused].Blur() + c.focused++ + c.inputs[c.focused].Focus() case key.Matches(msg, c.keys.Next): // Move to the next input - c.inputs[c.focusIndex].Blur() - c.focusIndex = (c.focusIndex + 1) % len(c.inputs) - c.inputs[c.focusIndex].Focus() + c.inputs[c.focused].Blur() + c.focused = (c.focused + 1) % len(c.inputs) + c.inputs[c.focused].Focus() case key.Matches(msg, c.keys.Previous): // Move to the previous input - c.inputs[c.focusIndex].Blur() - c.focusIndex = (c.focusIndex - 1 + len(c.inputs)) % len(c.inputs) - c.inputs[c.focusIndex].Focus() + c.inputs[c.focused].Blur() + c.focused = (c.focused - 1 + len(c.inputs)) % len(c.inputs) + c.inputs[c.focused].Focus() case key.Matches(msg, c.keys.Close): return c, util.CmdHandler(dialogs.CloseDialogMsg{}) default: var cmd tea.Cmd - c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg) + c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg) return c, cmd } case tea.PasteMsg: var cmd tea.Cmd - c.inputs[c.focusIndex], cmd = c.inputs[c.focusIndex].Update(msg) + c.inputs[c.focused], cmd = c.inputs[c.focused].Update(msg) return c, cmd } return c, nil @@ -152,26 +171,28 @@ func (c *commandArgumentsDialogCmp) View() string { Foreground(t.Primary). Bold(true). Padding(0, 1). - Render("Command Arguments") + Render(cmp.Or(c.title, c.name)) - explanation := t.S().Text. + promptName := t.S().Text. Padding(0, 1). - Render("This command requires arguments.") + Render(c.description) - // Create input fields for each argument inputFields := make([]string, len(c.inputs)) for i, input := range c.inputs { - // Highlight the label of the focused input - labelStyle := baseStyle. - Padding(1, 1, 0, 1) + labelStyle := baseStyle.Padding(1, 1, 0, 1) - if i == c.focusIndex { + if i == c.focused { labelStyle = labelStyle.Foreground(t.FgBase).Bold(true) } else { labelStyle = labelStyle.Foreground(t.FgMuted) } - label := labelStyle.Render(c.argNames[i] + ":") + arg := c.arguments[i] + argName := cmp.Or(arg.Title, arg.Name) + if arg.Required { + argName += "*" + } + label := labelStyle.Render(argName + ":") field := t.S().Text. Padding(0, 1). @@ -180,18 +201,14 @@ func (c *commandArgumentsDialogCmp) View() string { inputFields[i] = lipgloss.JoinVertical(lipgloss.Left, label, field) } - // Join all elements vertically - elements := []string{title, explanation} + elements := []string{title, promptName} elements = append(elements, inputFields...) c.help.ShowAll = false helpText := baseStyle.Padding(0, 1).Render(c.help.View(c.keys)) elements = append(elements, "", helpText) - content := lipgloss.JoinVertical( - lipgloss.Left, - elements..., - ) + content := lipgloss.JoinVertical(lipgloss.Left, elements...) return baseStyle.Padding(1, 1, 0, 1). Border(lipgloss.RoundedBorder()). @@ -201,26 +218,33 @@ func (c *commandArgumentsDialogCmp) View() string { } func (c *commandArgumentsDialogCmp) Cursor() *tea.Cursor { - cursor := c.inputs[c.focusIndex].Cursor() + if len(c.inputs) == 0 { + return nil + } + cursor := c.inputs[c.focused].Cursor() if cursor != nil { cursor = c.moveCursor(cursor) } return cursor } +const ( + headerHeight = 3 + itemHeight = 3 + paddingHorizontal = 3 +) + func (c *commandArgumentsDialogCmp) moveCursor(cursor *tea.Cursor) *tea.Cursor { row, col := c.Position() - offset := row + 3 + (1+c.focusIndex)*3 + offset := row + headerHeight + (1+c.focused)*itemHeight cursor.Y += offset - cursor.X = cursor.X + col + 3 + cursor.X = cursor.X + col + paddingHorizontal return cursor } func (c *commandArgumentsDialogCmp) Position() (int, int) { - row := c.wHeight / 2 - row -= c.wHeight / 2 - col := c.wWidth / 2 - col -= c.width / 2 + row := (c.wHeight / 2) - (c.height / 2) + col := (c.wWidth / 2) - (c.width / 2) return row, col } diff --git a/internal/tui/components/dialogs/commands/commands.go b/internal/tui/components/dialogs/commands/commands.go index 72ee6e353932cbb0714042dc325189f564066aaa..9d98005798aabd4879a8cb9a843b776869b3f326 100644 --- a/internal/tui/components/dialogs/commands/commands.go +++ b/internal/tui/components/dialogs/commands/commands.go @@ -2,6 +2,8 @@ package commands import ( "os" + "slices" + "strings" "github.com/charmbracelet/bubbles/v2/help" "github.com/charmbracelet/bubbles/v2/key" @@ -10,7 +12,10 @@ import ( "github.com/charmbracelet/lipgloss/v2" "github.com/charmbracelet/crush/internal/agent" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" + "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/components/dialogs" @@ -25,9 +30,14 @@ const ( defaultWidth int = 70 ) +type commandType uint + +func (c commandType) String() string { return []string{"System", "User", "MCP"}[c] } + const ( - SystemCommands int = iota + SystemCommands commandType = iota UserCommands + MCPPrompts ) type listModel = list.FilterableList[list.CompletionItem[Command]] @@ -54,9 +64,10 @@ type commandDialogCmp struct { commandList listModel keyMap CommandsDialogKeyMap help help.Model - commandType int // SystemCommands or UserCommands - userCommands []Command // User-defined commands - sessionID string // Current session ID + selected commandType // Selected SystemCommands, UserCommands, or MCPPrompts + userCommands []Command // User-defined commands + mcpPrompts *csync.Slice[Command] // MCP prompts + sessionID string // Current session ID } type ( @@ -102,8 +113,9 @@ func NewCommandDialog(sessionID string) CommandsDialog { width: defaultWidth, keyMap: DefaultCommandsDialogKeyMap(), help: help, - commandType: SystemCommands, + selected: SystemCommands, sessionID: sessionID, + mcpPrompts: csync.NewSlice[Command](), } } @@ -113,7 +125,8 @@ func (c *commandDialogCmp) Init() tea.Cmd { return util.ReportError(err) } c.userCommands = commands - return c.SetCommandType(c.commandType) + c.mcpPrompts.SetSlice(loadMCPPrompts()) + return c.setCommandType(c.selected) } func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { @@ -122,9 +135,19 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { c.wWidth = msg.Width c.wHeight = msg.Height return c, tea.Batch( - c.SetCommandType(c.commandType), + c.setCommandType(c.selected), c.commandList.SetSize(c.listWidth(), c.listHeight()), ) + case pubsub.Event[mcp.Event]: + // Reload MCP prompts when MCP state changes + if msg.Type == pubsub.UpdatedEvent { + c.mcpPrompts.SetSlice(loadMCPPrompts()) + // If we're currently viewing MCP prompts, refresh the list + if c.selected == MCPPrompts { + return c, c.setCommandType(MCPPrompts) + } + return c, nil + } case tea.KeyPressMsg: switch { case key.Matches(msg, c.keyMap.Select): @@ -138,15 +161,10 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { command.Handler(command), ) case key.Matches(msg, c.keyMap.Tab): - if len(c.userCommands) == 0 { + if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 { return c, nil } - // Toggle command type between System and User commands - if c.commandType == SystemCommands { - return c, c.SetCommandType(UserCommands) - } else { - return c, c.SetCommandType(SystemCommands) - } + return c, c.setCommandType(c.next()) case key.Matches(msg, c.keyMap.Close): return c, util.CmdHandler(dialogs.CloseDialogMsg{}) default: @@ -158,13 +176,35 @@ func (c *commandDialogCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { return c, nil } +func (c *commandDialogCmp) next() commandType { + switch c.selected { + case SystemCommands: + if len(c.userCommands) > 0 { + return UserCommands + } + if c.mcpPrompts.Len() > 0 { + return MCPPrompts + } + fallthrough + case UserCommands: + if c.mcpPrompts.Len() > 0 { + return MCPPrompts + } + fallthrough + case MCPPrompts: + return SystemCommands + default: + return SystemCommands + } +} + func (c *commandDialogCmp) View() string { t := styles.CurrentTheme() listView := c.commandList radio := c.commandTypeRadio() header := t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-lipgloss.Width(radio)-5) + " " + radio) - if len(c.userCommands) == 0 { + if len(c.userCommands) == 0 && c.mcpPrompts.Len() == 0 { header = t.S().Base.Padding(0, 1, 1, 1).Render(core.Title("Commands", c.width-4)) } content := lipgloss.JoinVertical( @@ -190,27 +230,41 @@ func (c *commandDialogCmp) Cursor() *tea.Cursor { func (c *commandDialogCmp) commandTypeRadio() string { t := styles.CurrentTheme() - choices := []string{"System", "User"} - iconSelected := "◉" - iconUnselected := "○" - if c.commandType == SystemCommands { - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconSelected + " " + choices[0] + " " + iconUnselected + " " + choices[1]) + + fn := func(i commandType) string { + if i == c.selected { + return "◉ " + i.String() + } + return "○ " + i.String() + } + + parts := []string{ + fn(SystemCommands), + } + if len(c.userCommands) > 0 { + parts = append(parts, fn(UserCommands)) + } + if c.mcpPrompts.Len() > 0 { + parts = append(parts, fn(MCPPrompts)) } - return t.S().Base.Foreground(t.FgHalfMuted).Render(iconUnselected + " " + choices[0] + " " + iconSelected + " " + choices[1]) + return t.S().Base.Foreground(t.FgHalfMuted).Render(strings.Join(parts, " ")) } func (c *commandDialogCmp) listWidth() int { return defaultWidth - 2 // 4 for padding } -func (c *commandDialogCmp) SetCommandType(commandType int) tea.Cmd { - c.commandType = commandType +func (c *commandDialogCmp) setCommandType(commandType commandType) tea.Cmd { + c.selected = commandType var commands []Command - if c.commandType == SystemCommands { + switch c.selected { + case SystemCommands: commands = c.defaultCommands() - } else { + case UserCommands: commands = c.userCommands + case MCPPrompts: + commands = slices.Collect(c.mcpPrompts.Seq()) } commandItems := []list.CompletionItem[Command]{} diff --git a/internal/tui/components/dialogs/commands/loader.go b/internal/tui/components/dialogs/commands/loader.go index 74d9c7e4baee2e2d19f8baca914942f0c0d34cd3..ae52104b35e7614bb46bf6e30986fc90b43bb716 100644 --- a/internal/tui/components/dialogs/commands/loader.go +++ b/internal/tui/components/dialogs/commands/loader.go @@ -1,22 +1,26 @@ package commands import ( + "context" "fmt" "io/fs" + "log/slog" "os" "path/filepath" "regexp" "strings" tea "github.com/charmbracelet/bubbletea/v2" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/home" + "github.com/charmbracelet/crush/internal/tui/components/chat" "github.com/charmbracelet/crush/internal/tui/util" ) const ( - UserCommandPrefix = "user:" - ProjectCommandPrefix = "project:" + userCommandPrefix = "user:" + projectCommandPrefix = "project:" ) var namedArgPattern = regexp.MustCompile(`\$([A-Z][A-Z0-9_]*)`) @@ -50,7 +54,7 @@ func buildCommandSources(cfg *config.Config) []commandSource { if dir := getXDGCommandsDir(); dir != "" { sources = append(sources, commandSource{ path: dir, - prefix: UserCommandPrefix, + prefix: userCommandPrefix, }) } @@ -58,14 +62,14 @@ func buildCommandSources(cfg *config.Config) []commandSource { if home := home.Dir(); home != "" { sources = append(sources, commandSource{ path: filepath.Join(home, ".crush", "commands"), - prefix: UserCommandPrefix, + prefix: userCommandPrefix, }) } // Project directory sources = append(sources, commandSource{ path: filepath.Join(cfg.Options.DataDirectory, "commands"), - prefix: ProjectCommandPrefix, + prefix: projectCommandPrefix, }) return sources @@ -127,12 +131,13 @@ func (l *commandLoader) loadCommand(path, baseDir, prefix string) (Command, erro } id := buildCommandID(path, baseDir, prefix) + desc := fmt.Sprintf("Custom command from %s", filepath.Base(path)) return Command{ ID: id, Title: id, - Description: fmt.Sprintf("Custom command from %s", filepath.Base(path)), - Handler: createCommandHandler(id, string(content)), + Description: desc, + Handler: createCommandHandler(id, desc, string(content)), }, nil } @@ -149,21 +154,35 @@ func buildCommandID(path, baseDir, prefix string) string { return prefix + strings.Join(parts, ":") } -func createCommandHandler(id string, content string) func(Command) tea.Cmd { +func createCommandHandler(id, desc, content string) func(Command) tea.Cmd { return func(cmd Command) tea.Cmd { args := extractArgNames(content) - if len(args) > 0 { - return util.CmdHandler(ShowArgumentsDialogMsg{ - CommandID: id, - Content: content, - ArgNames: args, + if len(args) == 0 { + return util.CmdHandler(CommandRunCustomMsg{ + Content: content, }) } + return util.CmdHandler(ShowArgumentsDialogMsg{ + CommandID: id, + Description: desc, + ArgNames: args, + OnSubmit: func(args map[string]string) tea.Cmd { + return execUserPrompt(content, args) + }, + }) + } +} - return util.CmdHandler(CommandRunCustomMsg{ +func execUserPrompt(content string, args map[string]string) tea.Cmd { + return func() tea.Msg { + for name, value := range args { + placeholder := "$" + name + content = strings.ReplaceAll(content, placeholder, value) + } + return CommandRunCustomMsg{ Content: content, - }) + } } } @@ -201,3 +220,55 @@ func isMarkdownFile(name string) bool { type CommandRunCustomMsg struct { Content string } + +func loadMCPPrompts() []Command { + var commands []Command + for key, prompt := range mcp.GetPrompts() { + clientName, promptName, ok := strings.Cut(key, ":") + if !ok { + slog.Warn("prompt not found", "key", key) + continue + } + commands = append(commands, Command{ + ID: key, + Title: clientName + ":" + promptName, + Description: prompt.Description, + Handler: createMCPPromptHandler(clientName, promptName, prompt), + }) + } + + return commands +} + +func createMCPPromptHandler(clientName, promptName string, prompt *mcp.Prompt) func(Command) tea.Cmd { + return func(cmd Command) tea.Cmd { + if len(prompt.Arguments) == 0 { + return execMCPPrompt(clientName, promptName, nil) + } + return util.CmdHandler(ShowMCPPromptArgumentsDialogMsg{ + Prompt: prompt, + OnSubmit: func(args map[string]string) tea.Cmd { + return execMCPPrompt(clientName, promptName, args) + }, + }) + } +} + +func execMCPPrompt(clientName, promptName string, args map[string]string) tea.Cmd { + return func() tea.Msg { + ctx := context.Background() + result, err := mcp.GetPromptMessages(ctx, clientName, promptName, args) + if err != nil { + return util.ReportError(err) + } + + return chat.SendMsg{ + Text: strings.Join(result, " "), + } + } +} + +type ShowMCPPromptArgumentsDialogMsg struct { + Prompt *mcp.Prompt + OnSubmit func(arg map[string]string) tea.Cmd +} diff --git a/internal/tui/components/mcp/mcp.go b/internal/tui/components/mcp/mcp.go index 06662105851173cea2a63b03f92d0e9451d66016..355341c9d9a08e70422c1c6c464c5fddb7cb1213 100644 --- a/internal/tui/components/mcp/mcp.go +++ b/internal/tui/components/mcp/mcp.go @@ -2,10 +2,11 @@ package mcp import ( "fmt" + "strings" "github.com/charmbracelet/lipgloss/v2" - "github.com/charmbracelet/crush/internal/agent/tools" + "github.com/charmbracelet/crush/internal/agent/tools/mcp" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/tui/components/core" "github.com/charmbracelet/crush/internal/tui/styles" @@ -40,7 +41,7 @@ func RenderMCPList(opts RenderOptions) []string { } // Get MCP states - mcpStates := tools.GetMCPStates() + mcpStates := mcp.GetStates() // Determine how many items to show maxItems := len(mcps) @@ -56,21 +57,24 @@ func RenderMCPList(opts RenderOptions) []string { // Determine icon and color based on state icon := t.ItemOfflineIcon description := "" - extraContent := "" + extraContent := []string{} if state, exists := mcpStates[l.Name]; exists { switch state.State { - case tools.MCPStateDisabled: + case mcp.StateDisabled: description = t.S().Subtle.Render("disabled") - case tools.MCPStateStarting: + case mcp.StateStarting: icon = t.ItemBusyIcon description = t.S().Subtle.Render("starting...") - case tools.MCPStateConnected: + case mcp.StateConnected: icon = t.ItemOnlineIcon - if state.ToolCount > 0 { - extraContent = t.S().Subtle.Render(fmt.Sprintf("%d tools", state.ToolCount)) + if count := state.Counts.Tools; count > 0 { + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d tools", count))) } - case tools.MCPStateError: + if count := state.Counts.Prompts; count > 0 { + extraContent = append(extraContent, t.S().Subtle.Render(fmt.Sprintf("%d prompts", count))) + } + case mcp.StateError: icon = t.ItemErrorIcon if state.Error != nil { description = t.S().Subtle.Render(fmt.Sprintf("error: %s", state.Error.Error())) @@ -88,7 +92,7 @@ func RenderMCPList(opts RenderOptions) []string { Icon: icon.String(), Title: l.Name, Description: description, - ExtraContent: extraContent, + ExtraContent: strings.Join(extraContent, " "), }, opts.MaxWidth, ), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index d2e36764c573ecc6a783e1148e07cb28934b7019..ea0ed6649e5cc6de57597c589dc9785fb441bb72 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -33,6 +33,8 @@ import ( "github.com/charmbracelet/crush/internal/tui/styles" "github.com/charmbracelet/crush/internal/tui/util" "github.com/charmbracelet/lipgloss/v2" + "golang.org/x/text/cases" + "golang.org/x/text/language" ) var lastMouseEvent time.Time @@ -156,15 +158,44 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.dialog = u.(dialogs.DialogCmp) return a, tea.Batch(completionCmd, dialogCmd) case commands.ShowArgumentsDialogMsg: + var args []commands.Argument + for _, arg := range msg.ArgNames { + args = append(args, commands.Argument{ + Name: arg, + Title: cases.Title(language.English).String(arg), + Required: true, + }) + } return a, util.CmdHandler( dialogs.OpenDialogMsg{ Model: commands.NewCommandArgumentsDialog( msg.CommandID, - msg.Content, - msg.ArgNames, + msg.CommandID, + msg.CommandID, + msg.Description, + args, + msg.OnSubmit, ), }, ) + case commands.ShowMCPPromptArgumentsDialogMsg: + args := make([]commands.Argument, 0, len(msg.Prompt.Arguments)) + for _, arg := range msg.Prompt.Arguments { + args = append(args, commands.Argument(*arg)) + } + dialog := commands.NewCommandArgumentsDialog( + msg.Prompt.Name, + msg.Prompt.Title, + msg.Prompt.Name, + msg.Prompt.Description, + args, + msg.OnSubmit, + ) + return a, util.CmdHandler( + dialogs.OpenDialogMsg{ + Model: dialog, + }, + ) // Page change messages case page.PageChangeMsg: return a, a.moveToPage(msg.ID)