refactor(mcp): improve MCP event handling and tool initialization

林玮 (Jade Lin) created

- Refactor MCP event subscription to handle context cancellation and channel closure
- Remove global toolsMaker function in favor of direct tool creation in getTools
- Add permissions and workingDir parameters to getTools function
- Update tool creation to include permissions and working directory directly

Change summary

internal/llm/agent/agent.go     | 39 +++++++++++++++++++---------
internal/llm/agent/mcp-tools.go | 46 +++++++++++++---------------------
2 files changed, 44 insertions(+), 41 deletions(-)

Detailed changes

internal/llm/agent/agent.go 🔗

@@ -1114,22 +1114,35 @@ func (a *agent) setupEvents(ctx context.Context) {
 	ctx, cancel := context.WithCancel(ctx)
 
 	go func() {
-		for event := range SubscribeMCPEvents(ctx) {
-			switch event.Payload.Type {
-			case MCPEventToolsListChanged:
-				name := event.Payload.Name
-				c, ok := mcpClients.Get(name)
+		subCh := SubscribeMCPEvents(ctx)
+
+		for {
+			select {
+			case event, ok := <-subCh:
 				if !ok {
-					slog.Warn("MCP client not found for tools update", "name", name)
+					slog.Debug("MCPEvents subscription channel closed")
+					return
+				}
+				switch event.Payload.Type {
+				case MCPEventToolsListChanged:
+					name := event.Payload.Name
+					c, ok := mcpClients.Get(name)
+					if !ok {
+						slog.Warn("MCP client not found for tools update", "name", name)
+						continue
+					}
+					cfg := config.Get()
+					tools := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
+					updateMcpTools(name, tools)
+					// Update the lazy map with the new tools
+					a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
+					updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
+				default:
 					continue
 				}
-				tools := getTools(ctx, name, c)
-				updateMcpTools(name, tools)
-				// Update the lazy map with the new tools
-				a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
-				updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
-			default:
-				continue
+			case <-ctx.Done():
+				slog.Debug("MCPEvents subscription cancelled")
+				return
 			}
 		}
 	}()

internal/llm/agent/mcp-tools.go 🔗

@@ -78,12 +78,11 @@ type MCPClientInfo struct {
 
 var (
 	mcpToolsOnce    sync.Once
-	mcpTools                                                  = csync.NewMap[string, tools.BaseTool]()
-	mcpClient2Tools                                           = csync.NewMap[string, []tools.BaseTool]()
-	mcpClients                                                = csync.NewMap[string, *client.Client]()
-	mcpStates                                                 = csync.NewMap[string, MCPClientInfo]()
-	mcpBroker                                                 = pubsub.NewBroker[MCPEvent]()
-	toolsMaker      func(string, []mcp.Tool) []tools.BaseTool = nil
+	mcpTools        = csync.NewMap[string, tools.BaseTool]()
+	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
+	mcpClients      = csync.NewMap[string, *client.Client]()
+	mcpStates       = csync.NewMap[string, MCPClientInfo]()
+	mcpBroker       = pubsub.NewBroker[MCPEvent]()
 )
 
 type McpTool struct {
@@ -198,22 +197,7 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
 	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
 }
 
-func createToolsMaker(permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
-	return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
-		mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
-		for _, tool := range mcpToolsList {
-			mcpTools = append(mcpTools, &McpTool{
-				mcpName:     name,
-				tool:        tool,
-				permissions: permissions,
-				workingDir:  workingDir,
-			})
-		}
-		return mcpTools
-	}
-}
-
-func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
+func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
 	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
 	if err != nil {
 		slog.Error("error listing tools", "error", err)
@@ -221,7 +205,16 @@ func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTo
 		c.Close()
 		return nil
 	}
-	return toolsMaker(name, result.Tools)
+	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
+	for _, tool := range result.Tools {
+		mcpTools = append(mcpTools, &McpTool{
+			mcpName:     name,
+			tool:        tool,
+			permissions: permissions,
+			workingDir:  workingDir,
+		})
+	}
+	return mcpTools
 }
 
 // SubscribeMCPEvents returns a channel for MCP events
@@ -299,9 +292,6 @@ var mcpInitRequest = mcp.InitializeRequest{
 
 func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
 	var wg sync.WaitGroup
-
-	toolsMaker = createToolsMaker(permissions, cfg.WorkingDir())
-
 	// Initialize states for all configured MCPs
 	for name, m := range cfg.MCP {
 		if m.Disabled {
@@ -341,7 +331,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 
 			mcpClients.Set(name, c)
 
-			tools := getTools(ctx, name, c)
+			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
 			updateMcpTools(name, tools)
 			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
 		}(name, m)
@@ -349,7 +339,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 	wg.Wait()
 }
 
-// updateMcpTools updates the global mcpTools and mcpClientTools maps
+// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
 func updateMcpTools(mcpName string, tools []tools.BaseTool) {
 	if len(tools) == 0 {
 		mcpClient2Tools.Del(mcpName)