diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 0015e498f986c67dd4477a6fb35e8846c8442b9e..6f3218be830ec326156f1cfae3a40f4e94ec767d 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -71,7 +71,7 @@ type agent struct { messages message.Service mcpTools []McpTool - tools *csync.LazySlice[tools.BaseTool] + tools *csync.Map[string, tools.BaseTool] provider provider.Provider providerID string @@ -173,14 +173,17 @@ func NewAgent( return nil, err } - toolFn := func() []tools.BaseTool { + toolFn := func() *csync.Map[string, tools.BaseTool] { slog.Info("Initializing agent tools", "agent", agentCfg.ID) defer func() { slog.Info("Initialized agent tools", "agent", agentCfg.ID) }() cwd := cfg.WorkingDir() - allTools := []tools.BaseTool{ + toolMap := csync.NewMap[string, tools.BaseTool]() + + // Base tools available to all agents + baseTools := []tools.BaseTool{ tools.NewBashTool(permissions, cwd), tools.NewDownloadTool(permissions, cwd), tools.NewEditTool(lspClients, permissions, history, cwd), @@ -193,31 +196,35 @@ func NewAgent( tools.NewViewTool(lspClients, permissions, cwd), tools.NewWriteTool(lspClients, permissions, history, cwd), } + for _, tool := range baseTools { + toolMap.Set(tool.Name(), tool) + } mcpToolsOnce.Do(func() { mcpTools = doGetMCPTools(ctx, permissions, cfg) }) - allTools = append(allTools, mcpTools...) + for _, mcpTool := range mcpTools { + toolMap.Set(mcpTool.Name(), mcpTool) + } if len(lspClients) > 0 { - allTools = append(allTools, tools.NewDiagnosticsTool(lspClients)) + diagnosticsTool := tools.NewDiagnosticsTool(lspClients) + toolMap.Set(diagnosticsTool.Name(), diagnosticsTool) } if agentTool != nil { - allTools = append(allTools, agentTool) - } - - if agentCfg.AllowedTools == nil { - return allTools + toolMap.Set(agentTool.Name(), agentTool) } - var filteredTools []tools.BaseTool - for _, tool := range allTools { - if slices.Contains(agentCfg.AllowedTools, tool.Name()) { - filteredTools = append(filteredTools, tool) + if agentCfg.AllowedTools != nil { + // Filter tools based on allowed tools list + for toolName := range toolMap.Seq2() { + if !slices.Contains(agentCfg.AllowedTools, toolName) { + toolMap.Del(toolName) + } } } - return filteredTools + return toolMap } return &agent{ @@ -231,7 +238,7 @@ func NewAgent( summarizeProvider: summarizeProvider, summarizeProviderID: string(providerCfg.ID), activeRequests: csync.NewMap[string, context.CancelFunc](), - tools: csync.NewLazySlice(toolFn), + tools: toolFn(), promptQueue: csync.NewMap[string, []string](), }, nil } @@ -556,16 +563,10 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg goto out default: // Continue processing - var tool tools.BaseTool - for availableTool := range a.tools.Seq() { - if availableTool.Info().Name == toolCall.Name { - tool = availableTool - break - } - } + tool, ok := a.tools.Get(toolCall.Name) // Tool not found - if tool == nil { + if !ok { toolResults[i] = message.ToolResult{ ToolCallID: toolCall.ID, Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),