feat(mcp): notifications support - tools/list_changed (#967)

林玮 (Jade Lin) and Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>
Co-authored-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/csync/maps.go          |  19 +++++
internal/csync/maps_test.go     |  52 ++++++++++++++
internal/llm/agent/agent.go     | 128 +++++++++++++++++++++++++---------
internal/llm/agent/mcp-tools.go |  65 ++++++++++++++---
4 files changed, 217 insertions(+), 47 deletions(-)

Detailed changes

internal/csync/maps.go 🔗

@@ -27,6 +27,25 @@ func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] {
 	}
 }
 
+// NewLazyMap creates a new lazy-loaded map. The provided load function is
+// executed in a separate goroutine to populate the map.
+func NewLazyMap[K comparable, V any](load func() map[K]V) *Map[K, V] {
+	m := &Map[K, V]{}
+	m.mu.Lock()
+	go func() {
+		m.inner = load()
+		m.mu.Unlock()
+	}()
+	return m
+}
+
+// Reset replaces the inner map with the new one.
+func (m *Map[K, V]) Reset(input map[K]V) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	m.inner = input
+}
+
 // Set sets the value for the specified key in the map.
 func (m *Map[K, V]) Set(key K, value V) {
 	m.mu.Lock()

internal/csync/maps_test.go 🔗

@@ -5,6 +5,8 @@ import (
 	"maps"
 	"sync"
 	"testing"
+	"testing/synctest"
+	"time"
 
 	"github.com/stretchr/testify/require"
 )
@@ -36,6 +38,56 @@ func TestNewMapFrom(t *testing.T) {
 	require.Equal(t, 1, value)
 }
 
+func TestNewLazyMap(t *testing.T) {
+	t.Parallel()
+
+	synctest.Test(t, func(t *testing.T) {
+		t.Helper()
+
+		waiter := sync.Mutex{}
+		waiter.Lock()
+		loadCalled := false
+
+		loadFunc := func() map[string]int {
+			waiter.Lock()
+			defer waiter.Unlock()
+			loadCalled = true
+			return map[string]int{
+				"key1": 1,
+				"key2": 2,
+			}
+		}
+
+		m := NewLazyMap(loadFunc)
+		require.NotNil(t, m)
+
+		waiter.Unlock() // Allow the load function to proceed
+		time.Sleep(100 * time.Millisecond)
+		require.True(t, loadCalled)
+		require.Equal(t, 2, m.Len())
+
+		value, ok := m.Get("key1")
+		require.True(t, ok)
+		require.Equal(t, 1, value)
+	})
+}
+
+func TestMap_Reset(t *testing.T) {
+	t.Parallel()
+
+	m := NewMapFrom(map[string]int{
+		"a": 10,
+	})
+
+	m.Reset(map[string]int{
+		"b": 20,
+	})
+	value, ok := m.Get("b")
+	require.True(t, ok)
+	require.Equal(t, 20, value)
+	require.Equal(t, 1, m.Len())
+}
+
 func TestMap_Set(t *testing.T) {
 	t.Parallel()
 

internal/llm/agent/agent.go 🔗

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"log/slog"
+	"maps"
 	"slices"
 	"strings"
 	"time"
@@ -65,11 +66,13 @@ type agent struct {
 	sessions    session.Service
 	messages    message.Service
 	permissions permission.Service
-	mcpTools    []McpTool
+	baseTools   *csync.Map[string, tools.BaseTool]
+	mcpTools    *csync.Map[string, tools.BaseTool]
+	lspClients  *csync.Map[string, *lsp.Client]
 
-	tools *csync.LazySlice[tools.BaseTool]
 	// We need this to be able to update it when model changes
-	agentToolFn func() (tools.BaseTool, error)
+	agentToolFn  func() (tools.BaseTool, error)
+	cleanupFuncs []func()
 
 	provider   provider.Provider
 	providerID string
@@ -171,14 +174,16 @@ func NewAgent(
 		return nil, err
 	}
 
-	toolFn := func() []tools.BaseTool {
-		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
+	baseToolsFn := func() map[string]tools.BaseTool {
+		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 		defer func() {
-			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
+			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 		}()
 
+		// Base tools available to all agents
 		cwd := cfg.WorkingDir()
-		allTools := []tools.BaseTool{
+		result := make(map[string]tools.BaseTool)
+		for _, tool := range []tools.BaseTool{
 			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 			tools.NewDownloadTool(permissions, cwd),
 			tools.NewEditTool(lspClients, permissions, history, cwd),
@@ -190,36 +195,25 @@ func NewAgent(
 			tools.NewSourcegraphTool(),
 			tools.NewViewTool(lspClients, permissions, cwd),
 			tools.NewWriteTool(lspClients, permissions, history, cwd),
+		} {
+			result[tool.Name()] = tool
 		}
+		return result
+	}
+	mcpToolsFn := func() map[string]tools.BaseTool {
+		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
+		defer func() {
+			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
+		}()
 
 		mcpToolsOnce.Do(func() {
-			mcpTools = doGetMCPTools(ctx, permissions, cfg)
+			doGetMCPTools(ctx, permissions, cfg)
 		})
 
-		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
-			if agentCfg.ID == "coder" {
-				t = append(t, mcpTools...)
-				if lspClients.Len() > 0 {
-					t = append(t, tools.NewDiagnosticsTool(lspClients))
-				}
-			}
-			return t
-		}
-
-		if agentCfg.AllowedTools == nil {
-			return withCoderTools(allTools)
-		}
-
-		var filteredTools []tools.BaseTool
-		for _, tool := range allTools {
-			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
-				filteredTools = append(filteredTools, tool)
-			}
-		}
-		return withCoderTools(filteredTools)
+		return maps.Collect(mcpTools.Seq2())
 	}
 
-	return &agent{
+	a := &agent{
 		Broker:              pubsub.NewBroker[AgentEvent](),
 		agentCfg:            agentCfg,
 		provider:            agentProvider,
@@ -231,10 +225,14 @@ func NewAgent(
 		summarizeProviderID: string(providerCfg.ID),
 		agentToolFn:         agentToolFn,
 		activeRequests:      csync.NewMap[string, context.CancelFunc](),
-		tools:               csync.NewLazySlice(toolFn),
+		mcpTools:            csync.NewLazyMap(mcpToolsFn),
+		baseTools:           csync.NewLazyMap(baseToolsFn),
 		promptQueue:         csync.NewMap[string, []string](),
 		permissions:         permissions,
-	}, nil
+		lspClients:          lspClients,
+	}
+	a.setupEvents(ctx)
+	return a, nil
 }
 
 func (a *agent) Model() catwalk.Model {
@@ -517,7 +515,18 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
 }
 
 func (a *agent) getAllTools() ([]tools.BaseTool, error) {
-	allTools := slices.Collect(a.tools.Seq())
+	var allTools []tools.BaseTool
+	for tool := range a.baseTools.Seq() {
+		if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
+			allTools = append(allTools, tool)
+		}
+	}
+	if a.agentCfg.ID == "coder" {
+		allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
+		if a.lspClients.Len() > 0 {
+			allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
+		}
+	}
 	if a.agentToolFn != nil {
 		agentTool, agentToolErr := a.agentToolFn()
 		if agentToolErr != nil {
@@ -591,7 +600,7 @@ loop:
 		default:
 			// Continue processing
 			var tool tools.BaseTool
-			allTools, _ := a.getAllTools()
+			allTools, _ = a.getAllTools()
 			for _, availableTool := range allTools {
 				if availableTool.Info().Name == toolCall.Name {
 					tool = availableTool
@@ -960,6 +969,12 @@ func (a *agent) CancelAll() {
 		a.Cancel(key) // key is sessionID
 	}
 
+	for _, cleanup := range a.cleanupFuncs {
+		if cleanup != nil {
+			cleanup()
+		}
+	}
+
 	timeout := time.After(5 * time.Second)
 	for a.IsBusy() {
 		select {
@@ -1071,3 +1086,48 @@ func (a *agent) UpdateModel() error {
 
 	return nil
 }
+
+func (a *agent) setupEvents(ctx context.Context) {
+	ctx, cancel := context.WithCancel(ctx)
+
+	go func() {
+		subCh := SubscribeMCPEvents(ctx)
+
+		for {
+			select {
+			case event, ok := <-subCh:
+				if !ok {
+					slog.Debug("MCPEvents subscription channel closed")
+					return
+				}
+				switch event.Payload.Type {
+				case MCPEventToolsListChanged:
+					name := event.Payload.Name
+					c, ok := mcpClients.Get(name)
+					if !ok {
+						slog.Warn("MCP client not found for tools update", "name", name)
+						continue
+					}
+					cfg := config.Get()
+					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
+					if err != nil {
+						slog.Error("error listing tools", "error", err)
+						updateMCPState(name, MCPStateError, err, nil, 0)
+						_ = c.Close()
+						continue
+					}
+					updateMcpTools(name, tools)
+					a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
+					updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
+				default:
+					continue
+				}
+			case <-ctx.Done():
+				slog.Debug("MCPEvents subscription cancelled")
+				return
+			}
+		}
+	}()
+
+	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
+}

internal/llm/agent/mcp-tools.go 🔗

@@ -8,7 +8,6 @@ import (
 	"fmt"
 	"log/slog"
 	"maps"
-	"slices"
 	"strings"
 	"sync"
 	"time"
@@ -54,7 +53,8 @@ func (s MCPState) String() string {
 type MCPEventType string
 
 const (
-	MCPEventStateChanged MCPEventType = "state_changed"
+	MCPEventStateChanged     MCPEventType = "state_changed"
+	MCPEventToolsListChanged MCPEventType = "tools_list_changed"
 )
 
 // MCPEvent represents an event in the MCP system
@@ -77,11 +77,12 @@ type MCPClientInfo struct {
 }
 
 var (
-	mcpToolsOnce sync.Once
-	mcpTools     []tools.BaseTool
-	mcpClients   = csync.NewMap[string, *client.Client]()
-	mcpStates    = csync.NewMap[string, MCPClientInfo]()
-	mcpBroker    = pubsub.NewBroker[MCPEvent]()
+	mcpToolsOnce    sync.Once
+	mcpTools        = csync.NewMap[string, tools.BaseTool]()
+	mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
+	mcpClients      = csync.NewMap[string, *client.Client]()
+	mcpStates       = csync.NewMap[string, MCPClientInfo]()
+	mcpBroker       = pubsub.NewBroker[MCPEvent]()
 )
 
 type McpTool struct {
@@ -237,8 +238,12 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
 		Client:    client,
 		ToolCount: toolCount,
 	}
-	if state == MCPStateConnected {
+	switch state {
+	case MCPStateConnected:
 		info.ConnectedAt = time.Now()
+	case MCPStateError:
+		updateMcpTools(name, nil)
+		mcpClients.Del(name)
 	}
 	mcpStates.Set(name, info)
 
@@ -252,6 +257,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
 	})
 }
 
+// publishMCPEventToolsListChanged publishes a tool list changed event
+func publishMCPEventToolsListChanged(name string) {
+	mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
+		Type: MCPEventToolsListChanged,
+		Name: name,
+	})
+}
+
 // CloseMCPClients closes all MCP clients. This should be called during application shutdown.
 func CloseMCPClients() error {
 	var errs []error
@@ -274,10 +287,8 @@ var mcpInitRequest = mcp.InitializeRequest{
 	},
 }
 
-func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
 	var wg sync.WaitGroup
-	result := csync.NewSlice[tools.BaseTool]()
-
 	// Initialize states for all configured MCPs
 	for name, m := range cfg.MCP {
 		if m.Disabled {
@@ -316,6 +327,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 				return
 			}
 
+			mcpClients.Set(name, c)
+
 			tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
 			if err != nil {
 				slog.Error("error listing tools", "error", err)
@@ -324,13 +337,26 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 				return
 			}
 
+			updateMcpTools(name, tools)
 			mcpClients.Set(name, c)
 			updateMCPState(name, MCPStateConnected, nil, c, len(tools))
-			result.Append(tools...)
 		}(name, m)
 	}
 	wg.Wait()
-	return slices.Collect(result.Seq())
+}
+
+// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
+func updateMcpTools(mcpName string, tools []tools.BaseTool) {
+	if len(tools) == 0 {
+		mcpClient2Tools.Del(mcpName)
+	} else {
+		mcpClient2Tools.Set(mcpName, tools)
+	}
+	for _, tools := range mcpClient2Tools.Seq2() {
+		for _, t := range tools {
+			mcpTools.Set(t.Name(), t)
+		}
+	}
 }
 
 func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
@@ -341,11 +367,22 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
 		return nil, err
 	}
 
+	c.OnNotification(func(n mcp.JSONRPCNotification) {
+		slog.Debug("Received MCP notification", "name", name, "notification", n)
+		switch n.Method {
+		case "notifications/tools/list_changed":
+			publishMCPEventToolsListChanged(name)
+		default:
+			slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
+		}
+	})
+
 	// XXX: ideally we should be able to use context.WithTimeout here, but,
 	// the SSE MCP client will start failing once that context is canceled.
 	timeout := mcpTimeout(m)
 	mcpCtx, cancel := context.WithCancel(ctx)
 	cancelTimer := time.AfterFunc(timeout, cancel)
+
 	if err := c.Start(mcpCtx); err != nil {
 		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
 		slog.Error("error starting mcp client", "error", err, "name", name)
@@ -353,6 +390,7 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
 		cancel()
 		return nil, err
 	}
+
 	if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
 		updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
 		slog.Error("error initializing mcp client", "error", err, "name", name)
@@ -360,6 +398,7 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
 		cancel()
 		return nil, err
 	}
+
 	cancelTimer.Stop()
 	slog.Info("Initialized mcp client", "name", name)
 	return c, nil