diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 91f02397bd7ad2610c950719b738965577649f6f..1efc3fc268392c06481d61ae6e11c9d67cdc13e8 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -515,30 +515,18 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string } func (a *agent) getAllTools() ([]tools.BaseTool, error) { - allTools := slices.Collect(a.baseTools.Seq()) - - withCoderTools := func(t []tools.BaseTool) []tools.BaseTool { - if a.agentCfg.ID == "coder" { - t = append(t, slices.Collect(a.mcpTools.Seq())...) - if a.lspClients.Len() > 0 { - t = append(t, tools.NewDiagnosticsTool(a.lspClients)) - } + var allTools []tools.BaseTool + for tool := range a.baseTools.Seq() { + if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) { + allTools = append(allTools, tool) } - return t } - - if a.agentCfg.AllowedTools == nil { - allTools = withCoderTools(allTools) - } else { - var filteredTools []tools.BaseTool - for _, tool := range allTools { - if slices.Contains(a.agentCfg.AllowedTools, tool.Name()) { - filteredTools = append(filteredTools, tool) - } + if a.agentCfg.ID == "coder" { + allTools = slices.AppendSeq(allTools, a.mcpTools.Seq()) + if a.lspClients.Len() > 0 { + allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients)) } - allTools = withCoderTools(filteredTools) } - if a.agentToolFn != nil { agentTool, agentToolErr := a.agentToolFn() if agentToolErr != nil {