@@ -6,8 +6,11 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"log/slog"
"maps"
+ "net/http"
+ "os/exec"
"strings"
"sync"
"time"
@@ -19,9 +22,7 @@ import (
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/version"
- "github.com/mark3labs/mcp-go/client"
- "github.com/mark3labs/mcp-go/client/transport"
- "github.com/mark3labs/mcp-go/mcp"
+ "github.com/modelcontextprotocol/go-sdk/mcp"
)
// MCPState represents the current state of an MCP client
@@ -71,7 +72,7 @@ type MCPClientInfo struct {
Name string
State MCPState
Error error
- Client *client.Client
+ Client *mcp.ClientSession
ToolCount int
ConnectedAt time.Time
}
@@ -80,14 +81,14 @@ var (
mcpToolsOnce sync.Once
mcpTools = csync.NewMap[string, tools.BaseTool]()
mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
- mcpClients = csync.NewMap[string, *client.Client]()
+ mcpClients = csync.NewMap[string, *mcp.ClientSession]()
mcpStates = csync.NewMap[string, MCPClientInfo]()
mcpBroker = pubsub.NewBroker[MCPEvent]()
)
type McpTool struct {
mcpName string
- tool mcp.Tool
+ tool *mcp.Tool
permissions permission.Service
workingDir string
}
@@ -97,14 +98,9 @@ func (b *McpTool) Name() string {
}
func (b *McpTool) Info() tools.ToolInfo {
- required := b.tool.InputSchema.Required
- if required == nil {
- required = make([]string, 0)
- }
- parameters := b.tool.InputSchema.Properties
- if parameters == nil {
- parameters = make(map[string]any)
- }
+ input := b.tool.InputSchema.(map[string]any)
+ required, _ := input["required"].([]string)
+ parameters, _ := input["properties"].(map[string]any)
return tools.ToolInfo{
Name: fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name),
Description: b.tool.Description,
@@ -123,11 +119,9 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
- result, err := c.CallTool(ctx, mcp.CallToolRequest{
- Params: mcp.CallToolParams{
- Name: toolName,
- Arguments: args,
- },
+ result, err := c.CallTool(ctx, &mcp.CallToolParams{
+ Name: toolName,
+ Arguments: args,
})
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
@@ -135,8 +129,8 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
output := make([]string, 0, len(result.Content))
for _, v := range result.Content {
- if v, ok := v.(mcp.TextContent); ok {
- output = append(output, v.Text)
+ if vv, ok := v.(*mcp.TextContent); ok {
+ output = append(output, vv.Text)
} else {
output = append(output, fmt.Sprintf("%v", v))
}
@@ -144,8 +138,8 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
return tools.NewTextResponse(strings.Join(output, "\n")), nil
}
-func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
- c, ok := mcpClients.Get(name)
+func getOrRenewClient(ctx context.Context, name string) (*mcp.ClientSession, error) {
+ sess, ok := mcpClients.Get(name)
if !ok {
return nil, fmt.Errorf("mcp '%s' not available", name)
}
@@ -157,20 +151,20 @@ func getOrRenewClient(ctx context.Context, name string) (*client.Client, error)
timeout := mcpTimeout(m)
pingCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
- err := c.Ping(pingCtx)
+ err := sess.Ping(pingCtx, nil)
if err == nil {
- return c, nil
+ return sess, nil
}
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, state.ToolCount)
- c, err = createAndInitializeClient(ctx, name, m, cfg.Resolver())
+ sess, err = createMCPSession(ctx, name, m, cfg.Resolver())
if err != nil {
return nil, err
}
- updateMCPState(name, MCPStateConnected, nil, c, state.ToolCount)
- mcpClients.Set(name, c)
- return c, nil
+ updateMCPState(name, MCPStateConnected, nil, sess, state.ToolCount)
+ mcpClients.Set(name, sess)
+ return sess, nil
}
func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolResponse, error) {
@@ -197,8 +191,8 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
-func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) ([]tools.BaseTool, error) {
- result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
+func getTools(ctx context.Context, name string, permissions permission.Service, c *mcp.ClientSession, workingDir string) ([]tools.BaseTool, error) {
+ result, err := c.ListTools(ctx, &mcp.ListToolsParams{})
if err != nil {
return nil, err
}
@@ -230,7 +224,7 @@ func GetMCPState(name string) (MCPClientInfo, bool) {
}
// updateMCPState updates the state of an MCP client and publishes an event
-func updateMCPState(name string, state MCPState, err error, client *client.Client, toolCount int) {
+func updateMCPState(name string, state MCPState, err error, client *mcp.ClientSession, toolCount int) {
info := MCPClientInfo{
Name: name,
State: state,
@@ -257,19 +251,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
})
}
-// publishMCPEventToolsListChanged publishes a tool list changed event
-func publishMCPEventToolsListChanged(name string) {
- mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
- Type: MCPEventToolsListChanged,
- Name: name,
- })
-}
-
// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
func CloseMCPClients() error {
var errs []error
for name, c := range mcpClients.Seq2() {
- if err := c.Close(); err != nil {
+ if err := c.Close(); err != nil &&
+ !errors.Is(err, io.EOF) &&
+ !errors.Is(err, context.Canceled) &&
+ err.Error() != "signal: killed" {
errs = append(errs, fmt.Errorf("close mcp: %s: %w", name, err))
}
}
@@ -277,16 +266,6 @@ func CloseMCPClients() error {
return errors.Join(errs...)
}
-var mcpInitRequest = mcp.InitializeRequest{
- Params: mcp.InitializeParams{
- ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
- ClientInfo: mcp.Implementation{
- Name: "Crush",
- Version: version.Version,
- },
- },
-}
-
func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
var wg sync.WaitGroup
// Initialize states for all configured MCPs
@@ -322,7 +301,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
defer cancel()
- c, err := createAndInitializeClient(ctx, name, m, cfg.Resolver())
+ c, err := createMCPSession(ctx, name, m, cfg.Resolver())
if err != nil {
return
}
@@ -359,49 +338,46 @@ func updateMcpTools(mcpName string, tools []tools.BaseTool) {
}
}
-func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
- c, err := createMcpClient(name, m, resolver)
+func createMCPSession(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*mcp.ClientSession, error) {
+ timeout := mcpTimeout(m)
+ mcpCtx, cancel := context.WithCancel(ctx)
+ cancelTimer := time.AfterFunc(timeout, cancel)
+
+ transport, err := createMCPTransport(mcpCtx, m, resolver)
if err != nil {
updateMCPState(name, MCPStateError, err, nil, 0)
slog.Error("error creating mcp client", "error", err, "name", name)
return nil, err
}
- c.OnNotification(func(n mcp.JSONRPCNotification) {
- slog.Debug("Received MCP notification", "name", name, "notification", n)
- switch n.Method {
- case "notifications/tools/list_changed":
- publishMCPEventToolsListChanged(name)
- default:
- slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
- }
- })
-
- // XXX: ideally we should be able to use context.WithTimeout here, but,
- // the SSE MCP client will start failing once that context is canceled.
- timeout := mcpTimeout(m)
- mcpCtx, cancel := context.WithCancel(ctx)
- cancelTimer := time.AfterFunc(timeout, cancel)
+ client := mcp.NewClient(
+ &mcp.Implementation{
+ Name: "crush",
+ Version: version.Version,
+ Title: "Crush",
+ },
+ &mcp.ClientOptions{
+ ToolListChangedHandler: func(context.Context, *mcp.ToolListChangedRequest) {
+ mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
+ Type: MCPEventToolsListChanged,
+ Name: name,
+ })
+ },
+ KeepAlive: time.Minute * 10,
+ },
+ )
- if err := c.Start(mcpCtx); err != nil {
+ session, err := client.Connect(mcpCtx, transport, nil)
+ if err != nil {
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
slog.Error("error starting mcp client", "error", err, "name", name)
- _ = c.Close()
- cancel()
- return nil, err
- }
-
- if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
- updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
- slog.Error("error initializing mcp client", "error", err, "name", name)
- _ = c.Close()
cancel()
return nil, err
}
cancelTimer.Stop()
slog.Info("Initialized mcp client", "name", name)
- return c, nil
+ return session, nil
}
func maybeTimeoutErr(err error, timeout time.Duration) error {
@@ -411,7 +387,7 @@ func maybeTimeoutErr(err error, timeout time.Duration) error {
return err
}
-func createMcpClient(name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
+func createMCPTransport(ctx context.Context, m config.MCPConfig, resolver config.VariableResolver) (mcp.Transport, error) {
switch m.Type {
case config.MCPStdio:
command, err := resolver.ResolveValue(m.Command)
@@ -421,44 +397,51 @@ func createMcpClient(name string, m config.MCPConfig, resolver config.VariableRe
if strings.TrimSpace(command) == "" {
return nil, fmt.Errorf("mcp stdio config requires a non-empty 'command' field")
}
- return client.NewStdioMCPClientWithOptions(
- home.Long(command),
- m.ResolvedEnv(),
- m.Args,
- transport.WithCommandLogger(mcpLogger{name: name}),
- )
+ cmd := exec.CommandContext(ctx, home.Long(command), m.Args...)
+ cmd.Env = m.ResolvedEnv()
+ return &mcp.CommandTransport{
+ Command: cmd,
+ }, nil
case config.MCPHttp:
if strings.TrimSpace(m.URL) == "" {
return nil, fmt.Errorf("mcp http config requires a non-empty 'url' field")
}
- return client.NewStreamableHttpClient(
- m.URL,
- transport.WithHTTPHeaders(m.ResolvedHeaders()),
- transport.WithHTTPLogger(mcpLogger{name: name}),
- )
- case config.MCPSse:
+ client := &http.Client{
+ Transport: &headerRoundTripper{
+ headers: m.ResolvedHeaders(),
+ },
+ }
+ return &mcp.StreamableClientTransport{
+ Endpoint: m.URL,
+ HTTPClient: client,
+ }, nil
+ case config.MCPSSE:
if strings.TrimSpace(m.URL) == "" {
return nil, fmt.Errorf("mcp sse config requires a non-empty 'url' field")
}
- return client.NewSSEMCPClient(
- m.URL,
- client.WithHeaders(m.ResolvedHeaders()),
- transport.WithSSELogger(mcpLogger{name: name}),
- )
+ client := &http.Client{
+ Transport: &headerRoundTripper{
+ headers: m.ResolvedHeaders(),
+ },
+ }
+ return &mcp.SSEClientTransport{
+ Endpoint: m.URL,
+ HTTPClient: client,
+ }, nil
default:
return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
}
}
-// for MCP's clients.
-type mcpLogger struct{ name string }
-
-func (l mcpLogger) Errorf(format string, v ...any) {
- slog.Error(fmt.Sprintf(format, v...), "name", l.name)
+type headerRoundTripper struct {
+ headers map[string]string
}
-func (l mcpLogger) Infof(format string, v ...any) {
- slog.Info(fmt.Sprintf(format, v...), "name", l.name)
+func (rt headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ for k, v := range rt.headers {
+ req.Header.Set(k, v)
+ }
+ return http.DefaultTransport.RoundTrip(req)
}
func mcpTimeout(m config.MCPConfig) time.Duration {