Detailed changes
@@ -27,6 +27,25 @@ func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] {
}
}
+// NewLazyMap creates a new lazy-loaded map. The provided load function is
+// executed in a separate goroutine to populate the map.
+func NewLazyMap[K comparable, V any](load func() map[K]V) *Map[K, V] {
+ m := &Map[K, V]{}
+ m.mu.Lock()
+ go func() {
+ m.inner = load()
+ m.mu.Unlock()
+ }()
+ return m
+}
+
+// Reset replaces the inner map with the new one.
+func (m *Map[K, V]) Reset(input map[K]V) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+ m.inner = input
+}
+
// Set sets the value for the specified key in the map.
func (m *Map[K, V]) Set(key K, value V) {
m.mu.Lock()
@@ -5,6 +5,8 @@ import (
"maps"
"sync"
"testing"
+ "testing/synctest"
+ "time"
"github.com/stretchr/testify/require"
)
@@ -36,6 +38,56 @@ func TestNewMapFrom(t *testing.T) {
require.Equal(t, 1, value)
}
+func TestNewLazyMap(t *testing.T) {
+ t.Parallel()
+
+ synctest.Test(t, func(t *testing.T) {
+ t.Helper()
+
+ waiter := sync.Mutex{}
+ waiter.Lock()
+ loadCalled := false
+
+ loadFunc := func() map[string]int {
+ waiter.Lock()
+ defer waiter.Unlock()
+ loadCalled = true
+ return map[string]int{
+ "key1": 1,
+ "key2": 2,
+ }
+ }
+
+ m := NewLazyMap(loadFunc)
+ require.NotNil(t, m)
+
+ waiter.Unlock() // Allow the load function to proceed
+ time.Sleep(100 * time.Millisecond)
+ require.True(t, loadCalled)
+ require.Equal(t, 2, m.Len())
+
+ value, ok := m.Get("key1")
+ require.True(t, ok)
+ require.Equal(t, 1, value)
+ })
+}
+
+func TestMap_Reset(t *testing.T) {
+ t.Parallel()
+
+ m := NewMapFrom(map[string]int{
+ "a": 10,
+ })
+
+ m.Reset(map[string]int{
+ "b": 20,
+ })
+ value, ok := m.Get("b")
+ require.True(t, ok)
+ require.Equal(t, 20, value)
+ require.Equal(t, 1, m.Len())
+}
+
func TestMap_Set(t *testing.T) {
t.Parallel()
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log/slog"
+ "maps"
"slices"
"strings"
"time"
@@ -65,11 +66,13 @@ type agent struct {
sessions session.Service
messages message.Service
permissions permission.Service
- mcpTools []McpTool
+ baseTools *csync.Map[string, tools.BaseTool]
+ mcpTools *csync.Map[string, tools.BaseTool]
+ lspClients *csync.Map[string, *lsp.Client]
- tools *csync.LazySlice[tools.BaseTool]
// We need this to be able to update it when model changes
- agentToolFn func() (tools.BaseTool, error)
+ agentToolFn func() (tools.BaseTool, error)
+ cleanupFuncs []func()
provider provider.Provider
providerID string
@@ -171,14 +174,16 @@ func NewAgent(
return nil, err
}
- toolFn := func() []tools.BaseTool {
- slog.Info("Initializing agent tools", "agent", agentCfg.ID)
+ baseToolsFn := func() map[string]tools.BaseTool {
+ slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
defer func() {
- slog.Info("Initialized agent tools", "agent", agentCfg.ID)
+ slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
}()
+ // Base tools available to all agents
cwd := cfg.WorkingDir()
- allTools := []tools.BaseTool{
+ result := make(map[string]tools.BaseTool)
+ for _, tool := range []tools.BaseTool{
tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
tools.NewDownloadTool(permissions, cwd),
tools.NewEditTool(lspClients, permissions, history, cwd),
@@ -190,36 +195,25 @@ func NewAgent(
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients, permissions, cwd),
tools.NewWriteTool(lspClients, permissions, history, cwd),
+ } {
+ result[tool.Name()] = tool
}
+ return result
+ }
+ mcpToolsFn := func() map[string]tools.BaseTool {
+ slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
+ defer func() {
+ slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
+ }()
mcpToolsOnce.Do(func() {
- mcpTools = doGetMCPTools(ctx, permissions, cfg)
+ doGetMCPTools(ctx, permissions, cfg)
})
- withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
- if agentCfg.ID == "coder" {
- t = append(t, mcpTools...)
- if lspClients.Len() > 0 {
- t = append(t, tools.NewDiagnosticsTool(lspClients))
- }
- }
- return t
- }
-
- if agentCfg.AllowedTools == nil {
- return withCoderTools(allTools)
- }
-
- var filteredTools []tools.BaseTool
- for _, tool := range allTools {
- if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
- filteredTools = append(filteredTools, tool)
- }
- }
- return withCoderTools(filteredTools)
+ return maps.Collect(mcpTools.Seq2())
}
- return &agent{
+ a := &agent{
Broker: pubsub.NewBroker[AgentEvent](),
agentCfg: agentCfg,
provider: agentProvider,
@@ -231,10 +225,14 @@ func NewAgent(
summarizeProviderID: string(providerCfg.ID),
agentToolFn: agentToolFn,
activeRequests: csync.NewMap[string, context.CancelFunc](),
- tools: csync.NewLazySlice(toolFn),
+ mcpTools: csync.NewLazyMap(mcpToolsFn),
+ baseTools: csync.NewLazyMap(baseToolsFn),
promptQueue: csync.NewMap[string, []string](),
permissions: permissions,
- }, nil
+ lspClients: lspClients,
+ }
+ a.setupEvents(ctx)
+ return a, nil
}
func (a *agent) Model() catwalk.Model {
@@ -517,7 +515,18 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
}
func (a *agent) getAllTools() ([]tools.BaseTool, error) {
- allTools := slices.Collect(a.tools.Seq())
+ var allTools []tools.BaseTool
+ for tool := range a.baseTools.Seq() {
+ if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
+ allTools = append(allTools, tool)
+ }
+ }
+ if a.agentCfg.ID == "coder" {
+ allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
+ if a.lspClients.Len() > 0 {
+ allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
+ }
+ }
if a.agentToolFn != nil {
agentTool, agentToolErr := a.agentToolFn()
if agentToolErr != nil {
@@ -591,7 +600,7 @@ loop:
default:
// Continue processing
var tool tools.BaseTool
- allTools, _ := a.getAllTools()
+ allTools, _ = a.getAllTools()
for _, availableTool := range allTools {
if availableTool.Info().Name == toolCall.Name {
tool = availableTool
@@ -960,6 +969,12 @@ func (a *agent) CancelAll() {
a.Cancel(key) // key is sessionID
}
+ for _, cleanup := range a.cleanupFuncs {
+ if cleanup != nil {
+ cleanup()
+ }
+ }
+
timeout := time.After(5 * time.Second)
for a.IsBusy() {
select {
@@ -1071,3 +1086,48 @@ func (a *agent) UpdateModel() error {
return nil
}
+
+func (a *agent) setupEvents(ctx context.Context) {
+ ctx, cancel := context.WithCancel(ctx)
+
+ go func() {
+ subCh := SubscribeMCPEvents(ctx)
+
+ for {
+ select {
+ case event, ok := <-subCh:
+ if !ok {
+ slog.Debug("MCPEvents subscription channel closed")
+ return
+ }
+ switch event.Payload.Type {
+ case MCPEventToolsListChanged:
+ name := event.Payload.Name
+ c, ok := mcpClients.Get(name)
+ if !ok {
+ slog.Warn("MCP client not found for tools update", "name", name)
+ continue
+ }
+ cfg := config.Get()
+ tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
+ if err != nil {
+ slog.Error("error listing tools", "error", err)
+ updateMCPState(name, MCPStateError, err, nil, 0)
+ _ = c.Close()
+ continue
+ }
+ updateMcpTools(name, tools)
+ a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
+ updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
+ default:
+ continue
+ }
+ case <-ctx.Done():
+ slog.Debug("MCPEvents subscription cancelled")
+ return
+ }
+ }
+ }()
+
+ a.cleanupFuncs = append(a.cleanupFuncs, cancel)
+}
@@ -8,7 +8,6 @@ import (
"fmt"
"log/slog"
"maps"
- "slices"
"strings"
"sync"
"time"
@@ -54,7 +53,8 @@ func (s MCPState) String() string {
type MCPEventType string
const (
- MCPEventStateChanged MCPEventType = "state_changed"
+ MCPEventStateChanged MCPEventType = "state_changed"
+ MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
// MCPEvent represents an event in the MCP system
@@ -77,11 +77,12 @@ type MCPClientInfo struct {
}
var (
- mcpToolsOnce sync.Once
- mcpTools []tools.BaseTool
- mcpClients = csync.NewMap[string, *client.Client]()
- mcpStates = csync.NewMap[string, MCPClientInfo]()
- mcpBroker = pubsub.NewBroker[MCPEvent]()
+ mcpToolsOnce sync.Once
+ mcpTools = csync.NewMap[string, tools.BaseTool]()
+ mcpClient2Tools = csync.NewMap[string, []tools.BaseTool]()
+ mcpClients = csync.NewMap[string, *client.Client]()
+ mcpStates = csync.NewMap[string, MCPClientInfo]()
+ mcpBroker = pubsub.NewBroker[MCPEvent]()
)
type McpTool struct {
@@ -237,8 +238,12 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
Client: client,
ToolCount: toolCount,
}
- if state == MCPStateConnected {
+ switch state {
+ case MCPStateConnected:
info.ConnectedAt = time.Now()
+ case MCPStateError:
+ updateMcpTools(name, nil)
+ mcpClients.Del(name)
}
mcpStates.Set(name, info)
@@ -252,6 +257,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
})
}
+// publishMCPEventToolsListChanged publishes a tool list changed event
+func publishMCPEventToolsListChanged(name string) {
+ mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
+ Type: MCPEventToolsListChanged,
+ Name: name,
+ })
+}
+
// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
func CloseMCPClients() error {
var errs []error
@@ -274,10 +287,8 @@ var mcpInitRequest = mcp.InitializeRequest{
},
}
-func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) {
var wg sync.WaitGroup
- result := csync.NewSlice[tools.BaseTool]()
-
// Initialize states for all configured MCPs
for name, m := range cfg.MCP {
if m.Disabled {
@@ -316,6 +327,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
return
}
+ mcpClients.Set(name, c)
+
tools, err := getTools(ctx, name, permissions, c, cfg.WorkingDir())
if err != nil {
slog.Error("error listing tools", "error", err)
@@ -324,13 +337,26 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
return
}
+ updateMcpTools(name, tools)
mcpClients.Set(name, c)
updateMCPState(name, MCPStateConnected, nil, c, len(tools))
- result.Append(tools...)
}(name, m)
}
wg.Wait()
- return slices.Collect(result.Seq())
+}
+
+// updateMcpTools updates the global mcpTools and mcpClient2Tools maps
+func updateMcpTools(mcpName string, tools []tools.BaseTool) {
+ if len(tools) == 0 {
+ mcpClient2Tools.Del(mcpName)
+ } else {
+ mcpClient2Tools.Set(mcpName, tools)
+ }
+ for _, tools := range mcpClient2Tools.Seq2() {
+ for _, t := range tools {
+ mcpTools.Set(t.Name(), t)
+ }
+ }
}
func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig, resolver config.VariableResolver) (*client.Client, error) {
@@ -341,11 +367,22 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
return nil, err
}
+ c.OnNotification(func(n mcp.JSONRPCNotification) {
+ slog.Debug("Received MCP notification", "name", name, "notification", n)
+ switch n.Method {
+ case "notifications/tools/list_changed":
+ publishMCPEventToolsListChanged(name)
+ default:
+ slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
+ }
+ })
+
// XXX: ideally we should be able to use context.WithTimeout here, but,
// the SSE MCP client will start failing once that context is canceled.
timeout := mcpTimeout(m)
mcpCtx, cancel := context.WithCancel(ctx)
cancelTimer := time.AfterFunc(timeout, cancel)
+
if err := c.Start(mcpCtx); err != nil {
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
slog.Error("error starting mcp client", "error", err, "name", name)
@@ -353,6 +390,7 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
cancel()
return nil, err
}
+
if _, err := c.Initialize(mcpCtx, mcpInitRequest); err != nil {
updateMCPState(name, MCPStateError, maybeTimeoutErr(err, timeout), nil, 0)
slog.Error("error initializing mcp client", "error", err, "name", name)
@@ -360,6 +398,7 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
cancel()
return nil, err
}
+
cancelTimer.Stop()
slog.Info("Initialized mcp client", "name", name)
return c, nil