@@ -11,34 +11,28 @@ import (
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/llm/tools"
+	"github.com/charmbracelet/crush/internal/version"
 
 	"github.com/charmbracelet/crush/internal/permission"
-	"github.com/charmbracelet/crush/internal/version"
 
 	"github.com/mark3labs/mcp-go/client"
 	"github.com/mark3labs/mcp-go/client/transport"
 	"github.com/mark3labs/mcp-go/mcp"
 )
 
+var (
+	mcpToolsOnce sync.Once
+	mcpTools     []tools.BaseTool
+	mcpClients   = csync.NewMap[string, *client.Client]()
+)
+
 type McpTool struct {
 	mcpName     string
 	tool        mcp.Tool
-	client      MCPClient
-	mcpConfig   config.MCPConfig
 	permissions permission.Service
 	workingDir  string
 }
 
-type MCPClient interface {
-	Initialize(
-		ctx context.Context,
-		request mcp.InitializeRequest,
-	) (*mcp.InitializeResult, error)
-	ListTools(ctx context.Context, request mcp.ListToolsRequest) (*mcp.ListToolsResult, error)
-	CallTool(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
-	Close() error
-}
-
 func (b *McpTool) Name() string {
 	return fmt.Sprintf("mcp_%s_%s", b.mcpName, b.tool.Name)
 }
@@ -56,27 +50,21 @@ func (b *McpTool) Info() tools.ToolInfo {
 	}
 }
 
-func runTool(ctx context.Context, c MCPClient, toolName string, input string) (tools.ToolResponse, error) {
-	initRequest := mcp.InitializeRequest{}
-	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
-	initRequest.Params.ClientInfo = mcp.Implementation{
-		Name:    "Crush",
-		Version: version.Version,
-	}
-
-	_, err := c.Initialize(ctx, initRequest)
-	if err != nil {
-		return tools.NewTextErrorResponse(err.Error()), nil
-	}
-
-	toolRequest := mcp.CallToolRequest{}
-	toolRequest.Params.Name = toolName
+func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
 	var args map[string]any
-	if err = json.Unmarshal([]byte(input), &args); err != nil {
+	if err := json.Unmarshal([]byte(input), &args); err != nil {
 		return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 	}
-	toolRequest.Params.Arguments = args
-	result, err := c.CallTool(ctx, toolRequest)
+	c, ok := mcpClients.Get(name)
+	if !ok {
+		return tools.NewTextErrorResponse("mcp '" + name + "' not available"), nil
+	}
+	result, err := c.CallTool(ctx, mcp.CallToolRequest{
+		Params: mcp.CallToolParams{
+			Name:      toolName,
+			Arguments: args,
+		},
+	})
 	if err != nil {
 		return tools.NewTextErrorResponse(err.Error()), nil
 	}
@@ -114,56 +102,34 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
 		return tools.ToolResponse{}, permission.ErrorPermissionDenied
 	}
 
-	return runTool(ctx, b.client, b.tool.Name, params.Input)
+	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
 }
 
-func NewMcpTool(name string, c MCPClient, tool mcp.Tool, permissions permission.Service, mcpConfig config.MCPConfig, workingDir string) tools.BaseTool {
-	return &McpTool{
-		mcpName:     name,
-		client:      c,
-		tool:        tool,
-		mcpConfig:   mcpConfig,
-		permissions: permissions,
-		workingDir:  workingDir,
-	}
-}
-
-func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
-	var stdioTools []tools.BaseTool
-	initRequest := mcp.InitializeRequest{}
-	initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
-	initRequest.Params.ClientInfo = mcp.Implementation{
-		Name:    "Crush",
-		Version: version.Version,
-	}
-
-	_, err := c.Initialize(ctx, initRequest)
-	if err != nil {
-		slog.Error("error initializing mcp client", "error", err)
-		return stdioTools
-	}
-	toolsRequest := mcp.ListToolsRequest{}
-	tools, err := c.ListTools(ctx, toolsRequest)
+func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
+	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
 	if err != nil {
 		slog.Error("error listing tools", "error", err)
-		return stdioTools
+		c.Close()
+		mcpClients.Del(name)
+		return nil
 	}
-	for _, t := range tools.Tools {
-		stdioTools = append(stdioTools, NewMcpTool(name, c, t, permissions, m, workingDir))
+	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
+	for _, tool := range result.Tools {
+		mcpTools = append(mcpTools, &McpTool{
+			mcpName:     name,
+			tool:        tool,
+			permissions: permissions,
+			workingDir:  workingDir,
+		})
 	}
-	return stdioTools
+	return mcpTools
 }
 
-var (
-	mcpToolsOnce sync.Once
-	mcpTools     []tools.BaseTool
-)
-
-func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
-	mcpToolsOnce.Do(func() {
-		mcpTools = doGetMCPTools(ctx, permissions, cfg)
-	})
-	return mcpTools
+// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
+func CloseMCPClients() {
+	for c := range mcpClients.Seq() {
+		_ = c.Close()
+	}
 }
 
 func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
@@ -177,42 +143,59 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 		wg.Add(1)
 		go func(name string, m config.MCPConfig) {
 			defer wg.Done()
-			switch m.Type {
-			case config.MCPStdio:
-				c, err := client.NewStdioMCPClient(
-					m.Command,
-					m.ResolvedEnv(),
-					m.Args...,
-				)
-				if err != nil {
-					slog.Error("error creating mcp client", "error", err)
-					return
-				}
-
-				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
-			case config.MCPHttp:
-				c, err := client.NewStreamableHttpClient(
-					m.URL,
-					transport.WithHTTPHeaders(m.ResolvedHeaders()),
-				)
-				if err != nil {
-					slog.Error("error creating mcp client", "error", err)
-					return
-				}
-				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
-			case config.MCPSse:
-				c, err := client.NewSSEMCPClient(
-					m.URL,
-					client.WithHeaders(m.ResolvedHeaders()),
-				)
-				if err != nil {
-					slog.Error("error creating mcp client", "error", err)
-					return
-				}
-				result.Append(getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
+			c, err := doGetClient(m)
+			if err != nil {
+				slog.Error("error creating mcp client", "error", err)
+				return
+			}
+			if err := doInitClient(ctx, name, c); err != nil {
+				slog.Error("error initializing mcp client", "error", err)
+				return
 			}
+			result.Append(getTools(ctx, name, permissions, c, cfg.WorkingDir())...)
 		}(name, m)
 	}
 	wg.Wait()
 	return slices.Collect(result.Seq())
 }
+
+func doInitClient(ctx context.Context, name string, c *client.Client) error {
+	initRequest := mcp.InitializeRequest{
+		Params: mcp.InitializeParams{
+			ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
+			ClientInfo: mcp.Implementation{
+				Name:    "Crush",
+				Version: version.Version,
+			},
+		},
+	}
+	if _, err := c.Initialize(ctx, initRequest); err != nil {
+		c.Close()
+		return err
+	}
+	mcpClients.Set(name, c)
+	return nil
+}
+
+func doGetClient(m config.MCPConfig) (*client.Client, error) {
+	switch m.Type {
+	case config.MCPStdio:
+		return client.NewStdioMCPClient(
+			m.Command,
+			m.ResolvedEnv(),
+			m.Args...,
+		)
+	case config.MCPHttp:
+		return client.NewStreamableHttpClient(
+			m.URL,
+			transport.WithHTTPHeaders(m.ResolvedHeaders()),
+		)
+	case config.MCPSse:
+		return client.NewSSEMCPClient(
+			m.URL,
+			client.WithHeaders(m.ResolvedHeaders()),
+		)
+	default:
+		return nil, fmt.Errorf("unsupported mcp type: %s", m.Type)
+	}
+}