Detailed changes
@@ -27,7 +27,7 @@ require (
github.com/google/uuid v1.6.0
github.com/invopop/jsonschema v0.13.0
github.com/joho/godotenv v1.5.1
- github.com/mark3labs/mcp-go v0.38.0
+ github.com/mark3labs/mcp-go v0.39.1
github.com/muesli/termenv v0.16.0
github.com/ncruces/go-sqlite3 v0.28.0
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
@@ -189,6 +189,8 @@ github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mark3labs/mcp-go v0.38.0 h1:E5tmJiIXkhwlV0pLAwAT0O5ZjUZSISE/2Jxg+6vpq4I=
github.com/mark3labs/mcp-go v0.38.0/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
+github.com/mark3labs/mcp-go v0.39.1 h1:2oPxk7aDbQhouakkYyKl2T4hKFU1c6FDaubWyGyVE1k=
+github.com/mark3labs/mcp-go v0.39.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
@@ -27,6 +27,28 @@ func NewMapFrom[K comparable, V any](m map[K]V) *Map[K, V] {
}
}
+// NewMapFromSeq creates a new thread-safe map from an iter.Seq2 of key-value pairs.
+func NewMapFromSeq[k comparable, v any](seq iter.Seq2[k, v]) *Map[k, v] {
+ m := make(map[k]v)
+ seq(func(kk k, vv v) bool {
+ m[kk] = vv
+ return true
+ })
+ return NewMapFrom(m)
+}
+
+// NewLazyMap creates a new lazy-loaded map. The provided load function is
+// executed in a separate goroutine to populate the map.
+func NewLazyMap[K comparable, V any](load func() map[K]V) *Map[K, V] {
+ m := &Map[K, V]{}
+ m.mu.Lock()
+ go func() {
+ m.inner = load()
+ m.mu.Unlock()
+ }()
+ return m
+}
+
// Set sets the value for the specified key in the map.
func (m *Map[K, V]) Set(key K, value V) {
m.mu.Lock()
@@ -5,6 +5,7 @@ import (
"maps"
"sync"
"testing"
+ "time"
"github.com/stretchr/testify/require"
)
@@ -36,6 +37,61 @@ func TestNewMapFrom(t *testing.T) {
require.Equal(t, 1, value)
}
+func TestNewMapFromSeq(t *testing.T) {
+ t.Parallel()
+
+ original := map[string]int{
+ "key1": 1,
+ "key2": 2,
+ }
+ seq := func(f func(string, int) bool) {
+ for k, v := range original {
+ if !f(k, v) {
+ break
+ }
+ }
+ }
+
+ m := NewMapFromSeq(seq)
+ require.NotNil(t, m)
+ require.Equal(t, original, m.inner)
+ require.Equal(t, 2, m.Len())
+
+ value, ok := m.Get("key2")
+ require.True(t, ok)
+ require.Equal(t, 2, value)
+}
+
+func TestNewLazyMap(t *testing.T) {
+ t.Parallel()
+
+ waiter := sync.Mutex{}
+ waiter.Lock()
+ loadCalled := false
+
+ loadFunc := func() map[string]int {
+ waiter.Lock()
+ defer waiter.Unlock()
+ loadCalled = true
+ return map[string]int{
+ "key1": 1,
+ "key2": 2,
+ }
+ }
+
+ m := NewLazyMap(loadFunc)
+ require.NotNil(t, m)
+
+ waiter.Unlock() // Allow the load function to proceed
+ time.Sleep(100 * time.Millisecond)
+ require.True(t, loadCalled)
+ require.Equal(t, 2, m.Len())
+
+ value, ok := m.Get("key1")
+ require.True(t, ok)
+ require.Equal(t, 1, value)
+}
+
func TestMap_Set(t *testing.T) {
t.Parallel()
@@ -65,13 +65,16 @@ type Service interface {
}
type agent struct {
+ globalCtx context.Context
+ cleanupFuncs []func()
+
*pubsub.Broker[AgentEvent]
agentCfg config.Agent
sessions session.Service
messages message.Service
- mcpTools []McpTool
- tools *csync.Map[string, tools.BaseTool]
+ baseTools *csync.Map[string, tools.BaseTool]
+ mcpTools *csync.Map[string, tools.BaseTool]
provider provider.Provider
providerID string
@@ -173,17 +176,16 @@ func NewAgent(
return nil, err
}
- toolFn := func() *csync.Map[string, tools.BaseTool] {
- slog.Info("Initializing agent tools", "agent", agentCfg.ID)
+ baseToolsFn := func() map[string]tools.BaseTool {
+ slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
defer func() {
- slog.Info("Initialized agent tools", "agent", agentCfg.ID)
+ slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
}()
- cwd := cfg.WorkingDir()
- toolMap := csync.NewMap[string, tools.BaseTool]()
-
// Base tools available to all agents
- baseTools := []tools.BaseTool{
+ cwd := cfg.WorkingDir()
+ result := make(map[string]tools.BaseTool)
+ for _, tool := range []tools.BaseTool{
tools.NewBashTool(permissions, cwd),
tools.NewDownloadTool(permissions, cwd),
tools.NewEditTool(lspClients, permissions, history, cwd),
@@ -195,39 +197,40 @@ func NewAgent(
tools.NewSourcegraphTool(),
tools.NewViewTool(lspClients, permissions, cwd),
tools.NewWriteTool(lspClients, permissions, history, cwd),
- }
- for _, tool := range baseTools {
- toolMap.Set(tool.Name(), tool)
- }
-
- mcpToolsOnce.Do(func() {
- mcpTools = doGetMCPTools(ctx, permissions, cfg)
- })
- for _, mcpTool := range mcpTools {
- toolMap.Set(mcpTool.Name(), mcpTool)
+ } {
+ result[tool.Name()] = tool
}
if len(lspClients) > 0 {
diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
- toolMap.Set(diagnosticsTool.Name(), diagnosticsTool)
+ result[diagnosticsTool.Name()] = diagnosticsTool
}
if agentTool != nil {
- toolMap.Set(agentTool.Name(), agentTool)
+ result[agentTool.Name()] = agentTool
}
- if agentCfg.AllowedTools != nil {
- // Filter tools based on allowed tools list
- for toolName := range toolMap.Seq2() {
- if !slices.Contains(agentCfg.AllowedTools, toolName) {
- toolMap.Del(toolName)
- }
- }
+ return result
+ }
+ mcpToolsFn := func() map[string]tools.BaseTool {
+ slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
+ defer func() {
+ slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
+ }()
+
+ mcpToolsOnce.Do(func() {
+ doGetMCPTools(ctx, permissions, cfg)
+ })
+
+ result := make(map[string]tools.BaseTool)
+ for _, mcpTool := range mcpTools.Seq2() {
+ result[mcpTool.Name()] = mcpTool
}
- return toolMap
+ return result
}
- return &agent{
+ a := &agent{
+ globalCtx: ctx,
Broker: pubsub.NewBroker[AgentEvent](),
agentCfg: agentCfg,
provider: agentProvider,
@@ -238,9 +241,12 @@ func NewAgent(
summarizeProvider: summarizeProvider,
summarizeProviderID: string(providerCfg.ID),
activeRequests: csync.NewMap[string, context.CancelFunc](),
- tools: toolFn(),
+ mcpTools: csync.NewLazyMap(mcpToolsFn),
+ baseTools: csync.NewLazyMap(baseToolsFn),
promptQueue: csync.NewMap[string, []string](),
- }, nil
+ }
+ a.setupEvents()
+ return a, nil
}
func (a *agent) Model() catwalk.Model {
@@ -525,7 +531,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
}
// Now collect tools (which may block on MCP initialization)
- eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
+ eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
// Add the session and message ID into the context if needed by tools.
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
@@ -563,7 +569,7 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
goto out
default:
// Continue processing
- tool, ok := a.tools.Get(toolCall.Name)
+ tool, ok := a.getToolByName(toolCall.Name)
// Tool not found
if !ok {
@@ -924,6 +930,12 @@ func (a *agent) CancelAll() {
a.Cancel(key) // key is sessionID
}
+ for _, cleanup := range a.cleanupFuncs {
+ if cleanup != nil {
+ cleanup()
+ }
+ }
+
timeout := time.After(5 * time.Second)
for a.IsBusy() {
select {
@@ -1029,3 +1041,46 @@ func (a *agent) UpdateModel() error {
return nil
}
+
+func (a *agent) allTools() []tools.BaseTool {
+ result := slices.Collect(a.baseTools.Seq())
+ result = slices.AppendSeq(result, a.mcpTools.Seq())
+ return result
+}
+
+func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
+ tool, ok := a.baseTools.Get(name)
+ if !ok {
+ tool, ok = a.mcpTools.Get(name)
+ }
+ return tool, ok
+}
+
+func (a *agent) setupEvents() {
+ ctx, cancel := context.WithCancel(a.globalCtx)
+
+ go func() {
+ for event := range SubscribeMCPEvents(ctx) {
+ switch event.Payload.Type {
+ case MCPEventToolsListChanged:
+ name := event.Payload.Name
+ c, ok := mcpClients.Get(name)
+ if !ok {
+ slog.Warn("MCP client not found for tools update", "name", name)
+ continue
+ }
+ tools := getTools(ctx, name, c)
+ updateMcpTools(name, tools)
+ // Update the lazy map with the new tools
+ a.mcpTools = csync.NewMapFromSeq(mcpTools.Seq2())
+ updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
+ default:
+ continue
+ }
+ }
+ }()
+
+ a.cleanupFuncs = append(a.cleanupFuncs, func() {
+ cancel()
+ })
+}
@@ -52,7 +52,8 @@ func (s MCPState) String() string {
type MCPEventType string
const (
- MCPEventStateChanged MCPEventType = "state_changed"
+ MCPEventStateChanged MCPEventType = "state_changed"
+ MCPEventToolsListChanged MCPEventType = "tools_list_changed"
)
// MCPEvent represents an event in the MCP system
@@ -76,10 +77,13 @@ type MCPClientInfo struct {
var (
mcpToolsOnce sync.Once
- mcpTools []tools.BaseTool
- mcpClients = csync.NewMap[string, *client.Client]()
- mcpStates = csync.NewMap[string, MCPClientInfo]()
- mcpBroker = pubsub.NewBroker[MCPEvent]()
+ mcpTools = csync.NewMap[string, tools.BaseTool]()
+ // mcpClientTools maps MCP name to tool names
+ mcpClientTools = csync.NewMap[string, []string]()
+ mcpClients = csync.NewMap[string, *client.Client]()
+ mcpStates = csync.NewMap[string, MCPClientInfo]()
+ mcpBroker = pubsub.NewBroker[MCPEvent]()
+ toolsMaker func(string, []mcp.Tool) []tools.BaseTool = nil
)
type McpTool struct {
@@ -192,7 +196,22 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
}
-func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
+func createToolsMaker(ctx context.Context, permissions permission.Service, workingDir string) func(string, []mcp.Tool) []tools.BaseTool {
+ return func(name string, mcpToolsList []mcp.Tool) []tools.BaseTool {
+ mcpTools := make([]tools.BaseTool, 0, len(mcpToolsList))
+ for _, tool := range mcpToolsList {
+ mcpTools = append(mcpTools, &McpTool{
+ mcpName: name,
+ tool: tool,
+ permissions: permissions,
+ workingDir: workingDir,
+ })
+ }
+ return mcpTools
+ }
+}
+
+func getTools(ctx context.Context, name string, c *client.Client) []tools.BaseTool {
result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
slog.Error("error listing tools", "error", err)
@@ -201,16 +220,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service,
mcpClients.Del(name)
return nil
}
- mcpTools := make([]tools.BaseTool, 0, len(result.Tools))
- for _, tool := range result.Tools {
- mcpTools = append(mcpTools, &McpTool{
- mcpName: name,
- tool: tool,
- permissions: permissions,
- workingDir: workingDir,
- })
- }
- return mcpTools
+ return toolsMaker(name, result.Tools)
}
// SubscribeMCPEvents returns a channel for MCP events
@@ -252,6 +262,14 @@ func updateMCPState(name string, state MCPState, err error, client *client.Clien
})
}
+// publishMCPEventToolsListChanged publishes a tool list changed event
+func publishMCPEventToolsListChanged(name string) {
+ mcpBroker.Publish(pubsub.UpdatedEvent, MCPEvent{
+ Type: MCPEventToolsListChanged,
+ Name: name,
+ })
+}
+
// CloseMCPClients closes all MCP clients. This should be called during application shutdown.
func CloseMCPClients() {
for c := range mcpClients.Seq() {
@@ -274,6 +292,8 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
var wg sync.WaitGroup
result := csync.NewSlice[tools.BaseTool]()
+ toolsMaker = createToolsMaker(ctx, permissions, cfg.WorkingDir())
+
// Initialize states for all configured MCPs
for name, m := range cfg.MCP {
if m.Disabled {
@@ -310,17 +330,46 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
if err != nil {
return
}
+
mcpClients.Set(name, c)
- tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
- updateMCPState(name, MCPStateConnected, nil, c, len(tools))
+ tools := getTools(ctx, name, c)
result.Append(tools...)
+ updateMcpTools(name, tools)
+ updateMCPState(name, MCPStateConnected, nil, c, len(tools))
}(name, m)
}
wg.Wait()
+
return slices.Collect(result.Seq())
}
+// updateMcpTools updates the global mcpTools and mcpClientTools maps
+func updateMcpTools(mcpName string, tools []tools.BaseTool) {
+ toolNames := make([]string, 0, len(tools))
+ for _, tool := range tools {
+ name := tool.Name()
+ if _, ok := mcpTools.Get(name); !ok {
+ slog.Info("Added MCP tool", "name", name, "mcp", mcpName)
+ }
+ mcpTools.Set(name, tool)
+ toolNames = append(toolNames, name)
+ }
+
+ // remove the tools that are no longer available
+ old, ok := mcpClientTools.Get(mcpName)
+ if ok {
+ slices.Sort(toolNames)
+ for _, name := range old {
+ if _, ok := slices.BinarySearch(toolNames, name); !ok {
+ mcpTools.Del(name)
+ slog.Info("Removed MCP tool", "name", name, "mcp", mcpName)
+ }
+ }
+ }
+ mcpClientTools.Set(mcpName, toolNames)
+}
+
func createAndInitializeClient(ctx context.Context, name string, m config.MCPConfig) (*client.Client, error) {
c, err := createMcpClient(m)
if err != nil {
@@ -328,15 +377,24 @@ func createAndInitializeClient(ctx context.Context, name string, m config.MCPCon
slog.Error("error creating mcp client", "error", err, "name", name)
return nil, err
}
- // Only call Start() for non-stdio clients, as stdio clients auto-start
- if m.Type != config.MCPStdio {
- if err := c.Start(ctx); err != nil {
- updateMCPState(name, MCPStateError, err, nil, 0)
- slog.Error("error starting mcp client", "error", err, "name", name)
- _ = c.Close()
- return nil, err
+
+ c.OnNotification(func(n mcp.JSONRPCNotification) {
+ slog.Debug("Received MCP notification", "name", name, "notification", n)
+ switch n.Method {
+ case "notifications/tools/list_changed":
+ publishMCPEventToolsListChanged(name)
+ default:
+ slog.Debug("Unhandled MCP notification", "name", name, "method", n.Method)
}
+ })
+
+ if err := c.Start(ctx); err != nil {
+ updateMCPState(name, MCPStateError, err, nil, 0)
+ slog.Error("error starting mcp client", "error", err, "name", name)
+ _ = c.Close()
+ return nil, err
}
+
if _, err := c.Initialize(ctx, mcpInitRequest); err != nil {
updateMCPState(name, MCPStateError, err, nil, 0)
slog.Error("error initializing mcp client", "error", err, "name", name)