From 8c7c0db22606db51910c2a24dc15319369ce60f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Wed, 1 Oct 2025 20:03:17 +0800 Subject: [PATCH] feat(mcp): notifications support - tools/list_changed (#967) Signed-off-by: Carlos Alexandro Becker Co-authored-by: Carlos Alexandro Becker --- internal/csync/maps.go | 19 +++++ internal/csync/maps_test.go | 52 +++++++++++++ internal/llm/agent/agent.go | 128 +++++++++++++++++++++++--------- internal/llm/agent/mcp-tools.go | 65 ++++++++++++---- 4 files changed, 217 insertions(+), 47 deletions(-) diff --git a/internal/csync/maps.go b/internal/csync/maps.go index b7a1f3109f6c15e7e5592cb538943a2d9e340819..1fd2005790014b2ce4bd5a78dbb7931d54cbe66c 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -27,6 +27,25 @@ func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] { } } +// NewLazyMap creates a new lazy-loaded map. The provided load function is +// executed in a separate goroutine to populate the map. +func NewLazyMap[K comparable, V any](load func() map[K]V) *Map[K, V] { + m := &Map[K, V]{} + m.mu.Lock() + go func() { + m.inner = load() + m.mu.Unlock() + }() + return m +} + +// Reset replaces the inner map with the new one. +func (m *Map[K, V]) Reset(input map[K]V) { + m.mu.Lock() + defer m.mu.Unlock() + m.inner = input +} + // Set sets the value for the specified key in the map. func (m *Map[K, V]) Set(key K, value V) { m.mu.Lock() diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 4a8019260a2610b7f5ae0d854029207c6b945d04..4c590f008dad91e8dcbc40d1b90d87ef1b3e5750 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -5,6 +5,8 @@ import ( "maps" "sync" "testing" + "testing/synctest" + "time" "github.com/stretchr/testify/require" ) @@ -36,6 +38,56 @@ func TestNewMapFrom(t *testing.T) { require.Equal(t, 1, value) } +func TestNewLazyMap(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + t.Helper() + + waiter := sync.Mutex{} + waiter.Lock() + loadCalled := false + + loadFunc := func() map[string]int { + waiter.Lock() + defer waiter.Unlock() + loadCalled = true + return map[string]int{ + "key1": 1, + "key2": 2, + } + } + + m := NewLazyMap(loadFunc) + require.NotNil(t, m) + + waiter.Unlock() // Allow the load function to proceed + time.Sleep(100 * time.Millisecond) + require.True(t, loadCalled) + require.Equal(t, 2, m.Len()) + + value, ok := m.Get("key1") + require.True(t, ok) + require.Equal(t, 1, value) + }) +} + +func TestMap_Reset(t *testing.T) { + t.Parallel() + + m := NewMapFrom(map[string]int{ + "a": 10, + }) + + m.Reset(map[string]int{ + "b": 20, + }) + value, ok := m.Get("b") + require.True(t, ok) + require.Equal(t, 20, value) + require.Equal(t, 1, m.Len()) +} + func TestMap_Set(t *testing.T) { t.Parallel() diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 9bae6e5b8092b987b1c8146460cef946e595beb5..1efc3fc268392c06481d61ae6e11c9d67cdc13e8 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "maps" "slices" "strings" "time" @@ -65,11 +66,13 @@ type agent struct { sessions session.Service messages message.Service permissions permission.Service - mcpTools []McpTool + baseTools *csync.Map[string, tools.BaseTool] + mcpTools *csync.Map[string, tools.BaseTool] + lspClients *csync.Map[string, *lsp.Client] - tools *csync.LazySlice[tools.BaseTool] // We need this to be able to update it when model changes - agentToolFn func() (tools.BaseTool, error) + agentToolFn func() (tools.BaseTool, error) + cleanupFuncs []func() provider provider.Provider providerID string @@ -171,14 +174,16 @@ func NewAgent( return nil, err } - toolFn := func() []tools.BaseTool { - slog.Info("Initializing agent tools", "agent", agentCfg.ID) + baseToolsFn := func() map[string]tools.BaseTool { + slog.Info("Initializing agent base tools", "agent", agentCfg.ID) defer func() { - slog.Info("Initialized agent tools", "agent", agentCfg.ID) + slog.Info("Initialized agent base tools", "agent", agentCfg.ID) }() + // Base tools available to all agents cwd := cfg.WorkingDir() - allTools := []tools.BaseTool{ + result := make(map[string]tools.BaseTool) + for _, tool := range []tools.BaseTool{ tools.NewBashTool(permissions, cwd, cfg.Options.Attribution), tools.NewDownloadTool(permissions, cwd), tools.NewEditTool(lspClients, permissions, history, cwd), @@ -190,36 +195,25 @@ func NewAgent( tools.NewSourcegraphTool(), tools.NewViewTool(lspClients, permissions, cwd), tools.NewWriteTool(lspClients, permissions, history, cwd), + } { + result[tool.Name()] = tool } + return result + } + mcpToolsFn := func() map[string]tools.BaseTool { + slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID) + defer func() { + slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID) + }() mcpToolsOnce.Do(func() { - mcpTools = doGetMCPTools(ctx, permissions, cfg) + doGetMCPTools(ctx, permissions, cfg) }) - withCoderTools := func(t []tools.BaseTool) []tools.BaseTool { - if agentCfg.ID == "coder" { - t = append(t, mcpTools...) - if lspClients.Len() > 0 { - t = append(t, tools.NewDiagnosticsTool(lspClients)) - } - } - return t - } - - if agentCfg.AllowedTools == nil { - return withCoderTools(allTools) - } - - var filteredTools []tools.BaseTool - for _, tool := range allTools { - if slices.Contains(agentCfg.AllowedTools, tool.Name()) { - filteredTools = append(filteredTools, tool) - } - } - return withCoderTools(filteredTools) + return maps.Collect(mcpTools.Seq2()) } - return &agent{ + a := &agent{ Broker: pubsub.NewBroker[AgentEvent](), agentCfg: agentCfg, provider: agentProvider, @@ -231,10 +225,14 @@ func NewAgent( summarizeProviderID: string(providerCfg.ID), agentToolFn: agentToolFn, activeRequests: csync.NewMap[string, context.CancelFunc](), - tools: csync.NewLazySlice(toolFn), + mcpTools: csync.NewLazyMap(mcpToolsFn), + baseTools: csync.NewLazyMap(baseToolsFn), promptQueue: csync.NewMap[string, []string](), permissions: permissions, - }, nil + lspClients: lspClients, + } + a.setupEvents(ctx) + return a, nil } func (a *agent) Model() catwalk.Model { @@ -517,7 +515,18 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string } func (a *agent) getAllTools() ([]tools.BaseTool, error) { - allTools := slices.Collect(a.tools.Seq()) + var allTools []tools.BaseTool + for tool := range a.baseTools.Seq() { + if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) { + allTools = append(allTools, tool) + } + } + if a.agentCfg.ID == "coder" { + allTools = slices.AppendSeq(allTools, a.mcpTools.Seq()) + if a.lspClients.Len() > 0 { + allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients)) + } + } if a.agentToolFn != nil { agentTool, agentToolErr := a.agentToolFn() if agentToolErr != nil { @@ -591,7 +600,7 @@ loop: default: // Continue processing var tool tools.BaseTool - allTools, _ := a.getAllTools() + allTools, _ = a.getAllTools() for _, availableTool := range allTools { if availableTool.Info().Name == toolCall.Name { tool = availableTool @@ -960,6 +969,12 @@ func (a *agent) CancelAll() { a.Cancel(key) // key is sessionID } + for _, cleanup := range a.cleanupFuncs { + if cleanup != nil { + cleanup() + } + } + timeout := time.After(5 * time.Second) for a.IsBusy() { select { @@ -1071,3 +1086,48 @@ func (a *agent) UpdateModel() error { return nil } + +func (a *agent) setupEvents(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + + go func() { + subCh := SubscribeMCPEvents(ctx) + + for { + select { + case event, ok := <-subCh: + if !ok { + slog.Debug("MCPEvents subscription channel closed") + return + } + switch event.Payload.Type { + case MCPEventToolsListChanged: + name := event.Payload.Name + c, ok := mcpClients.Get(name) + if !ok { + slog.Warn("MCP client not found for tools update", "name", name) + continue + } + cfg := config.Get() + tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir()) + if err != nil { + slog.Error("error listing tools", "error", err) + updateMCPState(name, MCPStateError, err, nil, 0) + _ = c.Close() + continue + } + updateMcpTools(name, tools) + a.mcpTools.Reset(maps.Collect(mcpTools.Seq2())) + updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len()) + default: + continue + } + case <-ctx.Done(): + slog.Debug("MCPEvents subscription cancelled") + return + } + } + }() + + a.cleanupFuncs = append(a.cleanupFuncs, cancel) +} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index ebd1698f2f7bf45ecda15c9160464e3d295ce3d6..181f32b7280faf3eb36040d2ebecf3f892350f53 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -8,7 +8,6 @@ import ( "fmt" "log/slog" "maps" - "slices" "strings" "sync" "time" @@ -54,7 +53,8 @@ func (s MCPState) String() string { type MCPEventType string const ( - MCPEventStateChanged MCPEventType = "state_changed" + MCPEventStateChanged MCPEventType = "state_changed" + MCPEventToolsListChanged MCPEventType = "tools_list_changed" ) // MCPEvent represents an event in the MCP system @@ -77,11 +77,12 @@ type MCPClientInfo struct { } var ( - mcpToolsOnce sync.Once - mcpTools []tools.BaseTool - mcpClients = csync.NewMap[string, *client.Client]() - mcpStates = csync.NewMap[string, MCPClientInfo]() - mcpBroker = pubsub.NewBroker[MCPEvent]() + mcpToolsOnce sync.Once + mcpTools = csync.NewMap[string, tools.BaseTool]() + mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]() + mcpClients = csync.NewMap[string, *client.Client]() + mcpStates = csync.NewMap[string, MCPClientInfo]() + mcpBroker = pubsub.NewBroker[MCPEvent]() ) type McpTool struct { @@ -237,8 +238,12 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien Client: client, ToolCount: toolCount, } - if state == MCPStateConnected { + switch state { + case MCPStateConnected: info.ConnectedAt = time.Now() + case MCPStateError: + updateMcpTools(name, nil) + mcpClients.Del(name) } mcpStates.Set(name, info) @@ -252,6 +257,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien }) } +// publishMCPEventToolsListChanged publishes a tool list changed event +func publishMCPEventToolsListChanged(name string) { + mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ + Type: MCPEventToolsListChanged, + Name: name, + }) +} + // CloseMCPClients closes all MCP clients. This should be called during application shutdown. func CloseMCPClients() error { var errs []error @@ -274,10 +287,8 @@ var mcpInitRequest = mcp.InitializeRequest{ }, } -func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { +func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) { var wg sync.WaitGroup - result := csync.NewSlice[tools.BaseTool]() - // Initialize states for all configured MCPs for name, m := range cfg.MCP { if m.Disabled { @@ -316,6 +327,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con return } + mcpClients.Set(name, c) + tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir()) if err != nil { slog.Error("error listing tools", "error", err) @@ -324,13 +337,26 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con return } + updateMcpTools(name, tools) mcpClients.Set(name, c) updateMCPState(name, MCPStateConnected, nil, c, len(tools)) - result.Append(tools...) }(name, m) } wg.Wait() - return slices.Collect(result.Seq()) +} + +// updateMcpTools updates the global mcpTools and mcpClient2Tools maps +func updateMcpTools(mcpName string, tools []tools.BaseTool) { + if len(tools) == 0 { + mcpClient2Tools.Del(mcpName) + } else { + mcpClient2Tools.Set(mcpName, tools) + } + for _, tools := range mcpClient2Tools.Seq2() { + for _, t := range tools { + mcpTools.Set(t.Name(), t) + } + } } func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) { @@ -341,11 +367,22 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon return nil, err } + c.OnNotification(func(n mcp.JSONRPCNotification) { + slog.Debug("Received MCP notification", "name", name, "notification", n) + switch n.Method { + case "notifications/tools/list_changed": + publishMCPEventToolsListChanged(name) + default: + slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method) + } + }) + // XXX: ideally we should be able to use context.WithTimeout here, but, // the SSE MCP client will start failing once that context is canceled. timeout := mcpTimeout(m) mcpCtx, cancel := context.WithCancel(ctx) cancelTimer := time.AfterFunc(timeout, cancel) + if err := c.Start(mcpCtx); err != nil { updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error starting mcp client", "error", err, "name", name) @@ -353,6 +390,7 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon cancel() return nil, err } + if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil { updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0) slog.Error("error initializing mcp client", "error", err, "name", name) @@ -360,6 +398,7 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon cancel() return nil, err } + cancelTimer.Stop() slog.Info("Initialized mcp client", "name", name) return c, nil