refactor(mcp): simplify tool management and improve error handling

林玮 (Jade Lin) created

- Replace mcpClientTools with mcpClient2Tools for direct tool storage
- Consolidate tool updates in updateMcpTools function
- Improve error handling by removing client on connection errors
- Remove redundant client deletion in getTools function

Change summary

internal/llm/agent/mcp-tools.go | 48 +++++++++++++---------------------
1 file changed, 19 insertions(+), 29 deletions(-)

Detailed changes

internal/llm/agent/mcp-tools.go 🔗

@@ -76,14 +76,13 @@ type MCPClientInfo struct {
 }
 
 var (
-	mcpToolsOnce sync.Once
-	mcpTools     = csync.NewMap[string, tools.BaseTool]()
-	// mcpClientTools maps MCP name to tool names
-	mcpClientTools                                           = csync.NewMap[string, []string]()
-	mcpClients                                               = csync.NewMap[string, *client.Client]()
-	mcpStates                                                = csync.NewMap[string, MCPClientInfo]()
-	mcpBroker                                                = pubsub.NewBroker[MCPEvent]()
-	toolsMaker     func(string, []mcp.Tool) []tools.BaseTool = nil
+	mcpToolsOnce    sync.Once
+	mcpTools                                                  = csync.NewMap[string, tools.BaseTool]()
+	mcpClient2Tools                                           = csync.NewMap[string, []tools.BaseTool]()
+	mcpClients                                                = csync.NewMap[string, *client.Client]()
+	mcpStates                                                 = csync.NewMap[string, MCPClientInfo]()
+	mcpBroker                                                 = pubsub.NewBroker[MCPEvent]()
+	toolsMaker      func(string, []mcp.Tool) []tools.BaseTool = nil
 )
 
 type McpTool struct {
@@ -217,7 +216,6 @@ func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTo
 		slog.Error("error listing tools", "error", err)
 		updateMCPState(name, MCPStateError, err, nil, 0)
 		c.Close()
-		mcpClients.Del(name)
 		return nil
 	}
 	return toolsMaker(name, result.Tools)
@@ -247,8 +245,12 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
 		Client:    client,
 		ToolCount: toolCount,
 	}
-	if state == MCPStateConnected {
+	switch state {
+	case MCPStateConnected:
 		info.ConnectedAt = time.Now()
+	case MCPStateError:
+		updateMcpTools(name, nil)
+		mcpClients.Del(name)
 	}
 	mcpStates.Set(name, info)
 
@@ -346,28 +348,16 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 
 // updateMcpTools updates the global mcpTools and mcpClientTools maps
 func updateMcpTools(mcpName string, tools []tools.BaseTool) {
-	toolNames := make([]string, 0, len(tools))
-	for _, tool := range tools {
-		name := tool.Name()
-		if _, ok := mcpTools.Get(name); !ok {
-			slog.Info("Added MCP tool", "name", name, "mcp", mcpName)
-		}
-		mcpTools.Set(name, tool)
-		toolNames = append(toolNames, name)
+	if len(tools) == 0 {
+		mcpClient2Tools.Del(mcpName)
+	} else {
+		mcpClient2Tools.Set(mcpName, tools)
 	}
-
-	// remove the tools that are no longer available
-	old, ok := mcpClientTools.Get(mcpName)
-	if ok {
-		slices.Sort(toolNames)
-		for _, name := range old {
-			if _, ok := slices.BinarySearch(toolNames, name); !ok {
-				mcpTools.Del(name)
-				slog.Info("Removed MCP tool", "name", name, "mcp", mcpName)
-			}
+	for _, tools := range mcpClient2Tools.Seq2() {
+		for _, t := range tools {
+			mcpTools.Set(t.Name(), t)
 		}
 	}
-	mcpClientTools.Set(mcpName, toolNames)
 }
 
 func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {