feat(agent): add MCP tool dynamic updates and lazy loading

林玮 (Jade Lin) created

- Update mcp-go dependency from v0.38.0 to v0.39.1
- Add NewMapFromSeq and NewLazyMap functions to csync.Map
- Refactor agent tool management to use separate base and MCP tool maps
- Implement dynamic MCP tool updates via notifications/tools/list_changed
- Add event handling system for real-time MCP tool synchronization
- Improve tool initialization with lazy loading for better performance

Change summary

go.mod                          |   2 
go.sum                          |   2 
internal/csync/maps.go          |  22 ++++++
internal/csync/maps_test.go     |  56 +++++++++++++++
internal/llm/agent/agent.go     | 123 +++++++++++++++++++++++++---------
internal/llm/agent/mcp-tools.go | 108 +++++++++++++++++++++++-------
6 files changed, 253 insertions(+), 60 deletions(-)

Detailed changes

go.mod 🔗

@@ -27,7 +27,7 @@ require (
 	github.com/google/uuid v1.6.0
 	github.com/invopop/jsonschema v0.13.0
 	github.com/joho/godotenv v1.5.1
-	github.com/mark3labs/mcp-go v0.38.0
+	github.com/mark3labs/mcp-go v0.39.1
 	github.com/muesli/termenv v0.16.0
 	github.com/ncruces/go-sqlite3 v0.28.0
 	github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646

go.sum 🔗

@@ -189,6 +189,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
 github.com/mark3labs/mcp-go v0.38.0 h1:E5tmJiIXkhwlV0pLAwAT0O5ZjUZSISE/2Jxg+6vpq4I=
 github.com/mark3labs/mcp-go v0.38.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
+github.com/mark3labs/mcp-go v0.39.1 h1:2oPxk7aDbQhouakkYyKl2T4hKFU1c6FDaubWyGyVE1k=
+github.com/mark3labs/mcp-go v0.39.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
 github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
 github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
 github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=

internal/csync/maps.go 🔗

@@ -27,6 +27,28 @@ func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] {
 	}
 }
 
+// NewMapFromSeq creates a new thread-safe map from an iter.Seq2 of key-value pairs.
+func NewMapFromSeq[k comparable, v any](seq iter.Seq2[k, v]) *Map[k, v] {
+	m := make(map[k]v)
+	seq(func(kk k, vv v) bool {
+		m[kk] = vv
+		return true
+	})
+	return NewMapFrom(m)
+}
+
+// NewLazyMap creates a new lazy-loaded map. The provided load function is
+// executed in a separate goroutine to populate the map.
+func NewLazyMap[K comparable, V any](load func() map[K]V) *Map[K, V] {
+	m := &Map[K, V]{}
+	m.mu.Lock()
+	go func() {
+		m.inner = load()
+		m.mu.Unlock()
+	}()
+	return m
+}
+
 // Set sets the value for the specified key in the map.
 func (m *Map[K, V]) Set(key K, value V) {
 	m.mu.Lock()

internal/csync/maps_test.go 🔗

@@ -5,6 +5,7 @@ import (
 	"maps"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 )
@@ -36,6 +37,61 @@ func TestNewMapFrom(t *testing.T) {
 	require.Equal(t, 1, value)
 }
 
+func TestNewMapFromSeq(t *testing.T) {
+	t.Parallel()
+
+	original := map[string]int{
+		"key1": 1,
+		"key2": 2,
+	}
+	seq := func(f func(string, int) bool) {
+		for k, v := range original {
+			if !f(k, v) {
+				break
+			}
+		}
+	}
+
+	m := NewMapFromSeq(seq)
+	require.NotNil(t, m)
+	require.Equal(t, original, m.inner)
+	require.Equal(t, 2, m.Len())
+
+	value, ok := m.Get("key2")
+	require.True(t, ok)
+	require.Equal(t, 2, value)
+}
+
+func TestNewLazyMap(t *testing.T) {
+	t.Parallel()
+
+	waiter := sync.Mutex{}
+	waiter.Lock()
+	loadCalled := false
+
+	loadFunc := func() map[string]int {
+		waiter.Lock()
+		defer waiter.Unlock()
+		loadCalled = true
+		return map[string]int{
+			"key1": 1,
+			"key2": 2,
+		}
+	}
+
+	m := NewLazyMap(loadFunc)
+	require.NotNil(t, m)
+
+	waiter.Unlock() // Allow the load function to proceed
+	time.Sleep(100 * time.Millisecond)
+	require.True(t, loadCalled)
+	require.Equal(t, 2, m.Len())
+
+	value, ok := m.Get("key1")
+	require.True(t, ok)
+	require.Equal(t, 1, value)
+}
+
 func TestMap_Set(t *testing.T) {
 	t.Parallel()
 

internal/llm/agent/agent.go 🔗

@@ -65,13 +65,16 @@ type Service interface {
 }
 
 type agent struct {
+	globalCtx    context.Context
+	cleanupFuncs []func()
+
 	*pubsub.Broker[AgentEvent]
 	agentCfg config.Agent
 	sessions session.Service
 	messages message.Service
-	mcpTools []McpTool
 
-	tools *csync.Map[string, tools.BaseTool]
+	baseTools *csync.Map[string, tools.BaseTool]
+	mcpTools  *csync.Map[string, tools.BaseTool]
 
 	provider   provider.Provider
 	providerID string
@@ -173,17 +176,16 @@ func NewAgent(
 		return nil, err
 	}
 
-	toolFn := func() *csync.Map[string, tools.BaseTool] {
-		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
+	baseToolsFn := func() map[string]tools.BaseTool {
+		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 		defer func() {
-			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
+			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 		}()
 
-		cwd := cfg.WorkingDir()
-		toolMap := csync.NewMap[string, tools.BaseTool]()
-
 		// Base tools available to all agents
-		baseTools := []tools.BaseTool{
+		cwd := cfg.WorkingDir()
+		result := make(map[string]tools.BaseTool)
+		for _, tool := range []tools.BaseTool{
 			tools.NewBashTool(permissions, cwd),
 			tools.NewDownloadTool(permissions, cwd),
 			tools.NewEditTool(lspClients, permissions, history, cwd),
@@ -195,39 +197,40 @@ func NewAgent(
 			tools.NewSourcegraphTool(),
 			tools.NewViewTool(lspClients, permissions, cwd),
 			tools.NewWriteTool(lspClients, permissions, history, cwd),
-		}
-		for _, tool := range baseTools {
-			toolMap.Set(tool.Name(), tool)
-		}
-
-		mcpToolsOnce.Do(func() {
-			mcpTools = doGetMCPTools(ctx, permissions, cfg)
-		})
-		for _, mcpTool := range mcpTools {
-			toolMap.Set(mcpTool.Name(), mcpTool)
+		} {
+			result[tool.Name()] = tool
 		}
 
 		if len(lspClients) > 0 {
 			diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
-			toolMap.Set(diagnosticsTool.Name(), diagnosticsTool)
+			result[diagnosticsTool.Name()] = diagnosticsTool
 		}
 
 		if agentTool != nil {
-			toolMap.Set(agentTool.Name(), agentTool)
+			result[agentTool.Name()] = agentTool
 		}
 
-		if agentCfg.AllowedTools != nil {
-			// Filter tools based on allowed tools list
-			for toolName := range toolMap.Seq2() {
-				if !slices.Contains(agentCfg.AllowedTools, toolName) {
-					toolMap.Del(toolName)
-				}
-			}
+		return result
+	}
+	mcpToolsFn := func() map[string]tools.BaseTool {
+		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
+		defer func() {
+			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
+		}()
+
+		mcpToolsOnce.Do(func() {
+			doGetMCPTools(ctx, permissions, cfg)
+		})
+
+		result := make(map[string]tools.BaseTool)
+		for _, mcpTool := range mcpTools.Seq2() {
+			result[mcpTool.Name()] = mcpTool
 		}
-		return toolMap
+		return result
 	}
 
-	return &agent{
+	a := &agent{
+		globalCtx:           ctx,
 		Broker:              pubsub.NewBroker[AgentEvent](),
 		agentCfg:            agentCfg,
 		provider:            agentProvider,
@@ -238,9 +241,12 @@ func NewAgent(
 		summarizeProvider:   summarizeProvider,
 		summarizeProviderID: string(providerCfg.ID),
 		activeRequests:      csync.NewMap[string, context.CancelFunc](),
-		tools:               toolFn(),
+		mcpTools:            csync.NewLazyMap(mcpToolsFn),
+		baseTools:           csync.NewLazyMap(baseToolsFn),
 		promptQueue:         csync.NewMap[string, []string](),
-	}, nil
+	}
+	a.setupEvents()
+	return a, nil
 }
 
 func (a *agent) Model() catwalk.Model {
@@ -525,7 +531,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 	}
 
 	// Now collect tools (which may block on MCP initialization)
-	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
+	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
 
 	// Add the session and message ID into the context if needed by tools.
 	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
@@ -563,7 +569,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 			goto out
 		default:
 			// Continue processing
-			tool, ok := a.tools.Get(toolCall.Name)
+			tool, ok := a.getToolByName(toolCall.Name)
 
 			// Tool not found
 			if !ok {
@@ -924,6 +930,12 @@ func (a *agent) CancelAll() {
 		a.Cancel(key) // key is sessionID
 	}
 
+	for _, cleanup := range a.cleanupFuncs {
+		if cleanup != nil {
+			cleanup()
+		}
+	}
+
 	timeout := time.After(5 * time.Second)
 	for a.IsBusy() {
 		select {
@@ -1029,3 +1041,46 @@ func (a *agent) UpdateModel() error {
 
 	return nil
 }
+
+func (a *agent) allTools() []tools.BaseTool {
+	result := slices.Collect(a.baseTools.Seq())
+	result = slices.AppendSeq(result, a.mcpTools.Seq())
+	return result
+}
+
+func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
+	tool, ok := a.baseTools.Get(name)
+	if !ok {
+		tool, ok = a.mcpTools.Get(name)
+	}
+	return tool, ok
+}
+
+func (a *agent) setupEvents() {
+	ctx, cancel := context.WithCancel(a.globalCtx)
+
+	go func() {
+		for event := range SubscribeMCPEvents(ctx) {
+			switch event.Payload.Type {
+			case MCPEventToolsListChanged:
+				name := event.Payload.Name
+				c, ok := mcpClients.Get(name)
+				if !ok {
+					slog.Warn("MCP client not found for tools update", "name", name)
+					continue
+				}
+				tools := getTools(ctx, name, c)
+				updateMcpTools(name, tools)
+				// Update the lazy map with the new tools
+				a.mcpTools = csync.NewMapFromSeq(mcpTools.Seq2())
+				updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
+			default:
+				continue
+			}
+		}
+	}()
+
+	a.cleanupFuncs = append(a.cleanupFuncs, func() {
+		cancel()
+	})
+}

internal/llm/agent/mcp-tools.go 🔗

@@ -52,7 +52,8 @@ func (s MCPState) String() string {
 type MCPEventType string
 
 const (
-	MCPEventStateChanged MCPEventType = "state_changed"
+	MCPEventStateChanged     MCPEventType = "state_changed"
+	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
 )
 
 // MCPEvent represents an event in the MCP system
@@ -76,10 +77,13 @@ type MCPClientInfo struct {
 
 var (
 	mcpToolsOnce sync.Once
-	mcpTools     []tools.BaseTool
-	mcpClients   = csync.NewMap[string, *client.Client]()
-	mcpStates    = csync.NewMap[string, MCPClientInfo]()
-	mcpBroker    = pubsub.NewBroker[MCPEvent]()
+	mcpTools     = csync.NewMap[string, tools.BaseTool]()
+	// mcpClientTools maps MCP name to tool names
+	mcpClientTools                                           = csync.NewMap[string, []string]()
+	mcpClients                                               = csync.NewMap[string, *client.Client]()
+	mcpStates                                                = csync.NewMap[string, MCPClientInfo]()
+	mcpBroker                                                = pubsub.NewBroker[MCPEvent]()
+	toolsMaker     func(string, []mcp.Tool) []tools.BaseTool = nil
 )
 
 type McpTool struct {
@@ -192,7 +196,22 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
 	return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
 }
 
-func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
+func createToolsMaker(ctx context.Context, permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
+	return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
+		mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
+		for _, tool := range mcpToolsList {
+			mcpTools = append(mcpTools, &McpTool{
+				mcpName:     name,
+				tool:        tool,
+				permissions: permissions,
+				workingDir:  workingDir,
+			})
+		}
+		return mcpTools
+	}
+}
+
+func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
 	result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
 	if err != nil {
 		slog.Error("error listing tools", "error", err)
@@ -201,16 +220,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service,
 		mcpClients.Del(name)
 		return nil
 	}
-	mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
-	for _, tool := range result.Tools {
-		mcpTools = append(mcpTools, &McpTool{
-			mcpName:     name,
-			tool:        tool,
-			permissions: permissions,
-			workingDir:  workingDir,
-		})
-	}
-	return mcpTools
+	return toolsMaker(name, result.Tools)
 }
 
 // SubscribeMCPEvents returns a channel for MCP events
@@ -252,6 +262,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
 	})
 }
 
+// publishMCPEventToolsListChanged publishes a tool list changed event
+func publishMCPEventToolsListChanged(name string) {
+	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
+		Type: MCPEventToolsListChanged,
+		Name: name,
+	})
+}
+
 // CloseMCPClients closes all MCP clients. This should be called during application shutdown.
 func CloseMCPClients() {
 	for c := range mcpClients.Seq() {
@@ -274,6 +292,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 	var wg sync.WaitGroup
 	result := csync.NewSlice[tools.BaseTool]()
 
+	toolsMaker = createToolsMaker(ctx, permissions, cfg.WorkingDir())
+
 	// Initialize states for all configured MCPs
 	for name, m := range cfg.MCP {
 		if m.Disabled {
@@ -310,17 +330,46 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 			if err != nil {
 				return
 			}
+
 			mcpClients.Set(name, c)
 
-			tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
-			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
+			tools := getTools(ctx, name, c)
 			result.Append(tools...)
+			updateMcpTools(name, tools)
+			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
 		}(name, m)
 	}
 	wg.Wait()
+
 	return slices.Collect(result.Seq())
 }
 
+// updateMcpTools updates the global mcpTools and mcpClientTools maps
+func updateMcpTools(mcpName string, tools []tools.BaseTool) {
+	toolNames := make([]string, 0, len(tools))
+	for _, tool := range tools {
+		name := tool.Name()
+		if _, ok := mcpTools.Get(name); !ok {
+			slog.Info("Added MCP tool", "name", name, "mcp", mcpName)
+		}
+		mcpTools.Set(name, tool)
+		toolNames = append(toolNames, name)
+	}
+
+	// remove the tools that are no longer available
+	old, ok := mcpClientTools.Get(mcpName)
+	if ok {
+		slices.Sort(toolNames)
+		for _, name := range old {
+			if _, ok := slices.BinarySearch(toolNames, name); !ok {
+				mcpTools.Del(name)
+				slog.Info("Removed MCP tool", "name", name, "mcp", mcpName)
+			}
+		}
+	}
+	mcpClientTools.Set(mcpName, toolNames)
+}
+
 func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
 	c, err := createMcpClient(m)
 	if err != nil {
@@ -328,15 +377,24 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
 		slog.Error("error creating mcp client", "error", err, "name", name)
 		return nil, err
 	}
-	// Only call Start() for non-stdio clients, as stdio clients auto-start
-	if m.Type != config.MCPStdio {
-		if err := c.Start(ctx); err != nil {
-			updateMCPState(name, MCPStateError, err, nil, 0)
-			slog.Error("error starting mcp client", "error", err, "name", name)
-			_ = c.Close()
-			return nil, err
+
+	c.OnNotification(func(n mcp.JSONRPCNotification) {
+		slog.Debug("Received MCP notification", "name", name, "notification", n)
+		switch n.Method {
+		case "notifications/tools/list_changed":
+			publishMCPEventToolsListChanged(name)
+		default:
+			slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
 		}
+	})
+
+	if err := c.Start(ctx); err != nil {
+		updateMCPState(name, MCPStateError, err, nil, 0)
+		slog.Error("error starting mcp client", "error", err, "name", name)
+		_ = c.Close()
+		return nil, err
 	}
+
 	if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
 		updateMCPState(name, MCPStateError, err, nil, 0)
 		slog.Error("error initializing mcp client", "error", err, "name", name)