Merge pull request #208 from charmbracelet/get-tools-mu

Kujtim Hoxha created

refactor: use sync primitives in GetMcpTools

Change summary

cmd/root.go                     |  20 ------
internal/llm/agent/agent.go     |   2 
internal/llm/agent/mcp-tools.go | 103 +++++++++++++++++++++-------------
3 files changed, 63 insertions(+), 62 deletions(-)

Detailed changes

cmd/root.go 🔗

@@ -6,14 +6,11 @@ import (
 	"io"
 	"log/slog"
 	"os"
-	"time"
 
 	tea "github.com/charmbracelet/bubbletea/v2"
 	"github.com/charmbracelet/crush/internal/app"
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/db"
-	"github.com/charmbracelet/crush/internal/llm/agent"
-	"github.com/charmbracelet/crush/internal/log"
 	"github.com/charmbracelet/crush/internal/tui"
 	"github.com/charmbracelet/crush/internal/version"
 	"github.com/charmbracelet/fang"
@@ -92,9 +89,6 @@ to assist developers in writing, debugging, and understanding code directly from
 		}
 		defer app.Shutdown()
 
-		// Initialize MCP tools early for both modes
-		initMCPTools(ctx, app, cfg)
-
 		prompt, err = maybePrependStdin(prompt)
 		if err != nil {
 			slog.Error(fmt.Sprintf("Failed to read from stdin: %v", err))
@@ -126,20 +120,6 @@ to assist developers in writing, debugging, and understanding code directly from
 	},
 }
 
-func initMCPTools(ctx context.Context, app *app.App, cfg *config.Config) {
-	go func() {
-		defer log.RecoverPanic("MCP-goroutine", nil)
-
-		// Create a context with timeout for the initial MCP tools fetch
-		ctxWithTimeout, cancel := context.WithTimeout(ctx, 30*time.Second)
-		defer cancel()
-
-		// Set this up once with proper error handling
-		agent.GetMcpTools(ctxWithTimeout, app.Permissions, cfg)
-		slog.Info("MCP message handling goroutine exiting")
-	}()
-}
-
 func Execute() {
 	if err := fang.Execute(
 		context.Background(),

internal/llm/agent/agent.go 🔗

@@ -94,7 +94,7 @@ func NewAgent(
 ) (Service, error) {
 	ctx := context.Background()
 	cfg := config.Get()
-	otherTools := GetMcpTools(ctx, permissions, cfg)
+	otherTools := GetMCPTools(ctx, permissions, cfg)
 	if len(lspClients) > 0 {
 		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 	}

internal/llm/agent/mcp-tools.go 🔗

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"log/slog"
+	"sync"
 
 	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/llm/tools"
@@ -154,8 +155,6 @@ func NewMcpTool(name string, tool mcp.Tool, permissions permission.Service, mcpC
 	}
 }
 
-var mcpTools []tools.BaseTool
-
 func getTools(ctx context.Context, name string, m config.MCPConfig, permissions permission.Service, c MCPClient, workingDir string) []tools.BaseTool {
 	var stdioTools []tools.BaseTool
 	initRequest := mcp.InitializeRequest{}
@@ -183,50 +182,72 @@ func getTools(ctx context.Context, name string, m config.MCPConfig, permissions
 	return stdioTools
 }
 
-func GetMcpTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
-	if len(mcpTools) > 0 {
-		return mcpTools
-	}
+var (
+	mcpToolsOnce sync.Once
+	mcpTools     []tools.BaseTool
+)
+
+func GetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+	mcpToolsOnce.Do(func() {
+		mcpTools = doGetMCPTools(ctx, permissions, cfg)
+	})
+	return mcpTools
+}
+
+func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+	var mu sync.Mutex
+	var wg sync.WaitGroup
+	var result []tools.BaseTool
 	for name, m := range cfg.MCP {
 		if m.Disabled {
 			slog.Debug("skipping disabled mcp", "name", name)
 			continue
 		}
-		switch m.Type {
-		case config.MCPStdio:
-			c, err := client.NewStdioMCPClient(
-				m.Command,
-				m.Env,
-				m.Args...,
-			)
-			if err != nil {
-				slog.Error("error creating mcp client", "error", err)
-				continue
+		wg.Add(1)
+		go func(name string, m config.MCPConfig) {
+			defer wg.Done()
+			switch m.Type {
+			case config.MCPStdio:
+				c, err := client.NewStdioMCPClient(
+					m.Command,
+					m.Env,
+					m.Args...,
+				)
+				if err != nil {
+					slog.Error("error creating mcp client", "error", err)
+					return
+				}
+
+				mu.Lock()
+				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
+				mu.Unlock()
+			case config.MCPHttp:
+				c, err := client.NewStreamableHttpClient(
+					m.URL,
+					transport.WithHTTPHeaders(m.Headers),
+				)
+				if err != nil {
+					slog.Error("error creating mcp client", "error", err)
+					return
+				}
+				mu.Lock()
+				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
+				mu.Unlock()
+			case config.MCPSse:
+				c, err := client.NewSSEMCPClient(
+					m.URL,
+					client.WithHeaders(m.Headers),
+				)
+				if err != nil {
+					slog.Error("error creating mcp client", "error", err)
+					return
+				}
+				mu.Lock()
+				result = append(result, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
+				mu.Unlock()
 			}
-
-			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
-		case config.MCPHttp:
-			c, err := client.NewStreamableHttpClient(
-				m.URL,
-				transport.WithHTTPHeaders(m.Headers),
-			)
-			if err != nil {
-				slog.Error("error creating mcp client", "error", err)
-				continue
-			}
-			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
-		case config.MCPSse:
-			c, err := client.NewSSEMCPClient(
-				m.URL,
-				client.WithHeaders(m.Headers),
-			)
-			if err != nil {
-				slog.Error("error creating mcp client", "error", err)
-				continue
-			}
-			mcpTools = append(mcpTools, getTools(ctx, name, m, permissions, c, cfg.WorkingDir())...)
-		}
+		}(name, m)
 	}
-
-	return mcpTools
+	wg.Wait()
+	return result
 }