From 585f459318b16b94cdaa8b72fbea1b4375a400ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E7=8E=AE=20=28Jade=20Lin=29?= Date: Fri, 5 Sep 2025 21:24:51 +0800 Subject: [PATCH] feat(agent): add MCP tool dynamic updates and lazy loading - Update mcp-go dependency from v0.38.0 to v0.39.1 - Add NewMapFromSeq and NewLazyMap functions to csync.Map - Refactor agent tool management to use separate base and MCP tool maps - Implement dynamic MCP tool updates via notifications/tools/list_changed - Add event handling system for real-time MCP tool synchronization - Improve tool initialization with lazy loading for better performance --- go.mod | 2 +- go.sum | 2 + internal/csync/maps.go | 22 ++++++ internal/csync/maps_test.go | 56 +++++++++++++++ internal/llm/agent/agent.go | 123 +++++++++++++++++++++++--------- internal/llm/agent/mcp-tools.go | 108 +++++++++++++++++++++------- 6 files changed, 253 insertions(+), 60 deletions(-) diff --git a/go.mod b/go.mod index e60c2fa5eb50811c258ed2e833c73083c6371465..d1e6fa03e783c9e323cfddad14059dd229948f8c 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,7 @@ require ( github.com/google/uuid v1.6.0 github.com/invopop/jsonschema v0.13.0 github.com/joho/godotenv v1.5.1 - github.com/mark3labs/mcp-go v0.38.0 + github.com/mark3labs/mcp-go v0.39.1 github.com/muesli/termenv v0.16.0 github.com/ncruces/go-sqlite3 v0.28.0 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index 668cc533e5d8b33c7a21de01f0608cc075b18307..705c77169acede4fca27306c1891b7d17eb10ca9 100644 --- a/go.sum +++ b/go.sum @@ -189,6 +189,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0 github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mark3labs/mcp-go v0.38.0 h1:E5tmJiIXkhwlV0pLAwAT0O5ZjUZSISE/2Jxg+6vpq4I= github.com/mark3labs/mcp-go v0.38.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= +github.com/mark3labs/mcp-go v0.39.1 h1:2oPxk7aDbQhouakkYyKl2T4hKFU1c6FDaubWyGyVE1k= +github.com/mark3labs/mcp-go v0.39.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= diff --git a/internal/csync/maps.go b/internal/csync/maps.go index b7a1f3109f6c15e7e5592cb538943a2d9e340819..c24968825140fe134a00e2f09c1df06781c96799 100644 --- a/internal/csync/maps.go +++ b/internal/csync/maps.go @@ -27,6 +27,28 @@ func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] { } } +// NewMapFromSeq creates a new thread-safe map from an iter.Seq2 of key-value pairs. +func NewMapFromSeq[k comparable, v any](seq iter.Seq2[k, v]) *Map[k, v] { + m := make(map[k]v) + seq(func(kk k, vv v) bool { + m[kk] = vv + return true + }) + return NewMapFrom(m) +} + +// NewLazyMap creates a new lazy-loaded map. The provided load function is +// executed in a separate goroutine to populate the map. +func NewLazyMap[K comparable, V any](load func() map[K]V) *Map[K, V] { + m := &Map[K, V]{} + m.mu.Lock() + go func() { + m.inner = load() + m.mu.Unlock() + }() + return m +} + // Set sets the value for the specified key in the map. func (m *Map[K, V]) Set(key K, value V) { m.mu.Lock() diff --git a/internal/csync/maps_test.go b/internal/csync/maps_test.go index 4a8019260a2610b7f5ae0d854029207c6b945d04..4fd9d5bfbf06db4f2787de6bc8f55dc4e1df44ef 100644 --- a/internal/csync/maps_test.go +++ b/internal/csync/maps_test.go @@ -5,6 +5,7 @@ import ( "maps" "sync" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -36,6 +37,61 @@ func TestNewMapFrom(t *testing.T) { require.Equal(t, 1, value) } +func TestNewMapFromSeq(t *testing.T) { + t.Parallel() + + original := map[string]int{ + "key1": 1, + "key2": 2, + } + seq := func(f func(string, int) bool) { + for k, v := range original { + if !f(k, v) { + break + } + } + } + + m := NewMapFromSeq(seq) + require.NotNil(t, m) + require.Equal(t, original, m.inner) + require.Equal(t, 2, m.Len()) + + value, ok := m.Get("key2") + require.True(t, ok) + require.Equal(t, 2, value) +} + +func TestNewLazyMap(t *testing.T) { + t.Parallel() + + waiter := sync.Mutex{} + waiter.Lock() + loadCalled := false + + loadFunc := func() map[string]int { + waiter.Lock() + defer waiter.Unlock() + loadCalled = true + return map[string]int{ + "key1": 1, + "key2": 2, + } + } + + m := NewLazyMap(loadFunc) + require.NotNil(t, m) + + waiter.Unlock() // Allow the load function to proceed + time.Sleep(100 * time.Millisecond) + require.True(t, loadCalled) + require.Equal(t, 2, m.Len()) + + value, ok := m.Get("key1") + require.True(t, ok) + require.Equal(t, 1, value) +} + func TestMap_Set(t *testing.T) { t.Parallel() diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 6f3218be830ec326156f1cfae3a40f4e94ec767d..5e1716a74db8fb92aebcf1c442db6b1bcf431eed 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -65,13 +65,16 @@ type Service interface { } type agent struct { + globalCtx context.Context + cleanupFuncs []func() + *pubsub.Broker[AgentEvent] agentCfg config.Agent sessions session.Service messages message.Service - mcpTools []McpTool - tools *csync.Map[string, tools.BaseTool] + baseTools *csync.Map[string, tools.BaseTool] + mcpTools *csync.Map[string, tools.BaseTool] provider provider.Provider providerID string @@ -173,17 +176,16 @@ func NewAgent( return nil, err } - toolFn := func() *csync.Map[string, tools.BaseTool] { - slog.Info("Initializing agent tools", "agent", agentCfg.ID) + baseToolsFn := func() map[string]tools.BaseTool { + slog.Info("Initializing agent base tools", "agent", agentCfg.ID) defer func() { - slog.Info("Initialized agent tools", "agent", agentCfg.ID) + slog.Info("Initialized agent base tools", "agent", agentCfg.ID) }() - cwd := cfg.WorkingDir() - toolMap := csync.NewMap[string, tools.BaseTool]() - // Base tools available to all agents - baseTools := []tools.BaseTool{ + cwd := cfg.WorkingDir() + result := make(map[string]tools.BaseTool) + for _, tool := range []tools.BaseTool{ tools.NewBashTool(permissions, cwd), tools.NewDownloadTool(permissions, cwd), tools.NewEditTool(lspClients, permissions, history, cwd), @@ -195,39 +197,40 @@ func NewAgent( tools.NewSourcegraphTool(), tools.NewViewTool(lspClients, permissions, cwd), tools.NewWriteTool(lspClients, permissions, history, cwd), - } - for _, tool := range baseTools { - toolMap.Set(tool.Name(), tool) - } - - mcpToolsOnce.Do(func() { - mcpTools = doGetMCPTools(ctx, permissions, cfg) - }) - for _, mcpTool := range mcpTools { - toolMap.Set(mcpTool.Name(), mcpTool) + } { + result[tool.Name()] = tool } if len(lspClients) > 0 { diagnosticsTool := tools.NewDiagnosticsTool(lspClients) - toolMap.Set(diagnosticsTool.Name(), diagnosticsTool) + result[diagnosticsTool.Name()] = diagnosticsTool } if agentTool != nil { - toolMap.Set(agentTool.Name(), agentTool) + result[agentTool.Name()] = agentTool } - if agentCfg.AllowedTools != nil { - // Filter tools based on allowed tools list - for toolName := range toolMap.Seq2() { - if !slices.Contains(agentCfg.AllowedTools, toolName) { - toolMap.Del(toolName) - } - } + return result + } + mcpToolsFn := func() map[string]tools.BaseTool { + slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID) + defer func() { + slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID) + }() + + mcpToolsOnce.Do(func() { + doGetMCPTools(ctx, permissions, cfg) + }) + + result := make(map[string]tools.BaseTool) + for _, mcpTool := range mcpTools.Seq2() { + result[mcpTool.Name()] = mcpTool } - return toolMap + return result } - return &agent{ + a := &agent{ + globalCtx: ctx, Broker: pubsub.NewBroker[AgentEvent](), agentCfg: agentCfg, provider: agentProvider, @@ -238,9 +241,12 @@ func NewAgent( summarizeProvider: summarizeProvider, summarizeProviderID: string(providerCfg.ID), activeRequests: csync.NewMap[string, context.CancelFunc](), - tools: toolFn(), + mcpTools: csync.NewLazyMap(mcpToolsFn), + baseTools: csync.NewLazyMap(baseToolsFn), promptQueue: csync.NewMap[string, []string](), - }, nil + } + a.setupEvents() + return a, nil } func (a *agent) Model() catwalk.Model { @@ -525,7 +531,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg } // Now collect tools (which may block on MCP initialization) - eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq())) + eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools()) // Add the session and message ID into the context if needed by tools. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) @@ -563,7 +569,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg goto out default: // Continue processing - tool, ok := a.tools.Get(toolCall.Name) + tool, ok := a.getToolByName(toolCall.Name) // Tool not found if !ok { @@ -924,6 +930,12 @@ func (a *agent) CancelAll() { a.Cancel(key) // key is sessionID } + for _, cleanup := range a.cleanupFuncs { + if cleanup != nil { + cleanup() + } + } + timeout := time.After(5 * time.Second) for a.IsBusy() { select { @@ -1029,3 +1041,46 @@ func (a *agent) UpdateModel() error { return nil } + +func (a *agent) allTools() []tools.BaseTool { + result := slices.Collect(a.baseTools.Seq()) + result = slices.AppendSeq(result, a.mcpTools.Seq()) + return result +} + +func (a *agent) getToolByName(name string) (tools.BaseTool, bool) { + tool, ok := a.baseTools.Get(name) + if !ok { + tool, ok = a.mcpTools.Get(name) + } + return tool, ok +} + +func (a *agent) setupEvents() { + ctx, cancel := context.WithCancel(a.globalCtx) + + go func() { + for event := range SubscribeMCPEvents(ctx) { + switch event.Payload.Type { + case MCPEventToolsListChanged: + name := event.Payload.Name + c, ok := mcpClients.Get(name) + if !ok { + slog.Warn("MCP client not found for tools update", "name", name) + continue + } + tools := getTools(ctx, name, c) + updateMcpTools(name, tools) + // Update the lazy map with the new tools + a.mcpTools = csync.NewMapFromSeq(mcpTools.Seq2()) + updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len()) + default: + continue + } + } + }() + + a.cleanupFuncs = append(a.cleanupFuncs, func() { + cancel() + }) +} diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 0f6d2d0ab31ec34df16c9837335425e1f3b195bb..7417c8d87e6ade4c5dc854a42e04841f09f6ee92 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -52,7 +52,8 @@ func (s MCPState) String() string { type MCPEventType string const ( - MCPEventStateChanged MCPEventType = "state_changed" + MCPEventStateChanged MCPEventType = "state_changed" + MCPEventToolsListChanged MCPEventType = "tools_list_changed" ) // MCPEvent represents an event in the MCP system @@ -76,10 +77,13 @@ type MCPClientInfo struct { var ( mcpToolsOnce sync.Once - mcpTools []tools.BaseTool - mcpClients = csync.NewMap[string, *client.Client]() - mcpStates = csync.NewMap[string, MCPClientInfo]() - mcpBroker = pubsub.NewBroker[MCPEvent]() + mcpTools = csync.NewMap[string, tools.BaseTool]() + // mcpClientTools maps MCP name to tool names + mcpClientTools = csync.NewMap[string, []string]() + mcpClients = csync.NewMap[string, *client.Client]() + mcpStates = csync.NewMap[string, MCPClientInfo]() + mcpBroker = pubsub.NewBroker[MCPEvent]() + toolsMaker func(string, []mcp.Tool) []tools.BaseTool = nil ) type McpTool struct { @@ -192,7 +196,22 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return runTool(ctx, b.mcpName, b.tool.Name, params.Input) } -func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool { +func createToolsMaker(ctx context.Context, permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool { + return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool { + mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList)) + for _, tool := range mcpToolsList { + mcpTools = append(mcpTools, &McpTool{ + mcpName: name, + tool: tool, + permissions: permissions, + workingDir: workingDir, + }) + } + return mcpTools + } +} + +func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool { result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { slog.Error("error listing tools", "error", err) @@ -201,16 +220,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service, mcpClients.Del(name) return nil } - mcpTools := make([]tools.BaseTool, 0, len(result.Tools)) - for _, tool := range result.Tools { - mcpTools = append(mcpTools, &McpTool{ - mcpName: name, - tool: tool, - permissions: permissions, - workingDir: workingDir, - }) - } - return mcpTools + return toolsMaker(name, result.Tools) } // SubscribeMCPEvents returns a channel for MCP events @@ -252,6 +262,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien }) } +// publishMCPEventToolsListChanged publishes a tool list changed event +func publishMCPEventToolsListChanged(name string) { + mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{ + Type: MCPEventToolsListChanged, + Name: name, + }) +} + // CloseMCPClients closes all MCP clients. This should be called during application shutdown. func CloseMCPClients() { for c := range mcpClients.Seq() { @@ -274,6 +292,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con var wg sync.WaitGroup result := csync.NewSlice[tools.BaseTool]() + toolsMaker = createToolsMaker(ctx, permissions, cfg.WorkingDir()) + // Initialize states for all configured MCPs for name, m := range cfg.MCP { if m.Disabled { @@ -310,17 +330,46 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con if err != nil { return } + mcpClients.Set(name, c) - tools := getTools(ctx, name, permissions, c, cfg.WorkingDir()) - updateMCPState(name, MCPStateConnected, nil, c, len(tools)) + tools := getTools(ctx, name, c) result.Append(tools...) + updateMcpTools(name, tools) + updateMCPState(name, MCPStateConnected, nil, c, len(tools)) }(name, m) } wg.Wait() + return slices.Collect(result.Seq()) } +// updateMcpTools updates the global mcpTools and mcpClientTools maps +func updateMcpTools(mcpName string, tools []tools.BaseTool) { + toolNames := make([]string, 0, len(tools)) + for _, tool := range tools { + name := tool.Name() + if _, ok := mcpTools.Get(name); !ok { + slog.Info("Added MCP tool", "name", name, "mcp", mcpName) + } + mcpTools.Set(name, tool) + toolNames = append(toolNames, name) + } + + // remove the tools that are no longer available + old, ok := mcpClientTools.Get(mcpName) + if ok { + slices.Sort(toolNames) + for _, name := range old { + if _, ok := slices.BinarySearch(toolNames, name); !ok { + mcpTools.Del(name) + slog.Info("Removed MCP tool", "name", name, "mcp", mcpName) + } + } + } + mcpClientTools.Set(mcpName, toolNames) +} + func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) { c, err := createMcpClient(m) if err != nil { @@ -328,15 +377,24 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon slog.Error("error creating mcp client", "error", err, "name", name) return nil, err } - // Only call Start() for non-stdio clients, as stdio clients auto-start - if m.Type != config.MCPStdio { - if err := c.Start(ctx); err != nil { - updateMCPState(name, MCPStateError, err, nil, 0) - slog.Error("error starting mcp client", "error", err, "name", name) - _ = c.Close() - return nil, err + + c.OnNotification(func(n mcp.JSONRPCNotification) { + slog.Debug("Received MCP notification", "name", name, "notification", n) + switch n.Method { + case "notifications/tools/list_changed": + publishMCPEventToolsListChanged(name) + default: + slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method) } + }) + + if err := c.Start(ctx); err != nil { + updateMCPState(name, MCPStateError, err, nil, 0) + slog.Error("error starting mcp client", "error", err, "name", name) + _ = c.Close() + return nil, err } + if _, err := c.Initialize(ctx, mcpInitRequest); err != nil { updateMCPState(name, MCPStateError, err, nil, 0) slog.Error("error initializing mcp client", "error", err, "name", name)