diff --git a/internal/app/app.go b/internal/app/app.go index 21ddcd25eff1c9aeebb9d6700f9340ab0932e7ab..39e94feecb1f738fc46ff3012edac415c3759266 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -278,7 +278,6 @@ func (app *App) InitCoderAgent() error { } var err error app.CoderAgent, err = agent.NewAgent( - app.globalCtx, coderAgentCfg, app.Permissions, app.Sessions, diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 02ff02c2df5c85b688b892971472e22fa4aed0b7..cda31cfd5c747a620ce174f6e62a02a01ea3feb5 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -72,6 +72,8 @@ type agent struct { mcpTools []McpTool tools *csync.LazySlice[tools.BaseTool] + // We need this to be able to update it when model changes + agentToolFn func() (tools.BaseTool, error) provider provider.Provider providerID string @@ -91,7 +93,6 @@ var agentPromptMap = map[string]prompt.PromptID{ } func NewAgent( - ctx context.Context, agentCfg config.Agent, // These services are needed in the tools permissions permission.Service, @@ -102,18 +103,19 @@ func NewAgent( ) (Service, error) { cfg := config.Get() - var agentTool tools.BaseTool + var agentToolFn func() (tools.BaseTool, error) if agentCfg.ID == "coder" { - taskAgentCfg := config.Get().Agents["task"] - if taskAgentCfg.ID == "" { - return nil, fmt.Errorf("task agent not found in config") - } - taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients) - if err != nil { - return nil, fmt.Errorf("failed to create task agent: %w", err) + agentToolFn = func() (tools.BaseTool, error) { + taskAgentCfg := config.Get().Agents["task"] + if taskAgentCfg.ID == "" { + return nil, fmt.Errorf("task agent not found in config") + } + taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients) + if err != nil { + return nil, fmt.Errorf("failed to create task agent: %w", err) + } + return NewAgentTool(taskAgent, sessions, messages), nil } - - agentTool = NewAgentTool(taskAgent, sessions, messages) } providerCfg := config.Get().GetProviderForModel(agentCfg.Model) @@ -195,7 +197,7 @@ func NewAgent( } mcpToolsOnce.Do(func() { - mcpTools = doGetMCPTools(ctx, permissions, cfg) + mcpTools = doGetMCPTools(permissions, cfg) }) allTools = append(allTools, mcpTools...) @@ -203,10 +205,6 @@ func NewAgent( allTools = append(allTools, tools.NewDiagnosticsTool(lspClients)) } - if agentTool != nil { - allTools = append(allTools, agentTool) - } - if agentCfg.AllowedTools == nil { return allTools } @@ -230,6 +228,7 @@ func NewAgent( titleProvider: titleProvider, summarizeProvider: summarizeProvider, summarizeProviderID: string(providerCfg.ID), + agentToolFn: agentToolFn, activeRequests: csync.NewMap[string, context.CancelFunc](), tools: csync.NewLazySlice(toolFn), promptQueue: csync.NewMap[string, []string](), @@ -500,6 +499,18 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string }) } +func (a *agent) getAllTools() ([]tools.BaseTool, error) { + allTools := slices.Collect(a.tools.Seq()) + if a.agentToolFn != nil { + agentTool, agentToolErr := a.agentToolFn() + if agentToolErr != nil { + return nil, agentToolErr + } + allTools = append(allTools, agentTool) + } + return allTools, nil +} + func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) { ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID) @@ -514,8 +525,12 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err) } + allTools, toolsErr := a.getAllTools() + if toolsErr != nil { + return assistantMsg, nil, toolsErr + } // Now collect tools (which may block on MCP initialization) - eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq())) + eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools) // Add the session and message ID into the context if needed by tools. ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID) @@ -554,7 +569,8 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg default: // Continue processing var tool tools.BaseTool - for availableTool := range a.tools.Seq() { + allTools, _ := a.getAllTools() + for _, availableTool := range allTools { if availableTool.Info().Name == toolCall.Name { tool = availableTool break diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index bb50231da028e714c783f50cc7ebd8a1f4b595db..d0389ee321b55181c9e38546da8e256422fdc34f 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -275,7 +275,7 @@ var mcpInitRequest = mcp.InitializeRequest{ }, } -func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool { +func doGetMCPTools(permissions permission.Service, cfg *config.Config) []tools.BaseTool { var wg sync.WaitGroup result := csync.NewSlice[tools.BaseTool]() @@ -309,7 +309,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con } }() - ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m)) + ctx, cancel := context.WithTimeout(context.Background(), mcpTimeout(m)) defer cancel() c, err := createAndInitializeClient(ctx, name, m) if err != nil {