fix: agent tool not working when switching models

kujtimiihoxha created

Change summary

internal/app/app.go             |  1 
internal/llm/agent/agent.go     | 52 ++++++++++++++++++++++------------
internal/llm/agent/mcp-tools.go |  4 +-
3 files changed, 36 insertions(+), 21 deletions(-)

Detailed changes

internal/app/app.go 🔗

@@ -278,7 +278,6 @@ func (app *App) InitCoderAgent() error {
 	}
 	var err error
 	app.CoderAgent, err = agent.NewAgent(
-		app.globalCtx,
 		coderAgentCfg,
 		app.Permissions,
 		app.Sessions,

internal/llm/agent/agent.go 🔗

@@ -72,6 +72,8 @@ type agent struct {
 	mcpTools []McpTool
 
 	tools *csync.LazySlice[tools.BaseTool]
+	// We need this to be able to update it when model changes
+	agentToolFn func() (tools.BaseTool, error)
 
 	provider   provider.Provider
 	providerID string
@@ -91,7 +93,6 @@ var agentPromptMap = map[string]prompt.PromptID{
 }
 
 func NewAgent(
-	ctx context.Context,
 	agentCfg config.Agent,
 	// These services are needed in the tools
 	permissions permission.Service,
@@ -102,18 +103,19 @@ func NewAgent(
 ) (Service, error) {
 	cfg := config.Get()
 
-	var agentTool tools.BaseTool
+	var agentToolFn func() (tools.BaseTool, error)
 	if agentCfg.ID == "coder" {
-		taskAgentCfg := config.Get().Agents["task"]
-		if taskAgentCfg.ID == "" {
-			return nil, fmt.Errorf("task agent not found in config")
-		}
-		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
-		if err != nil {
-			return nil, fmt.Errorf("failed to create task agent: %w", err)
+		agentToolFn = func() (tools.BaseTool, error) {
+			taskAgentCfg := config.Get().Agents["task"]
+			if taskAgentCfg.ID == "" {
+				return nil, fmt.Errorf("task agent not found in config")
+			}
+			taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
+			if err != nil {
+				return nil, fmt.Errorf("failed to create task agent: %w", err)
+			}
+			return NewAgentTool(taskAgent, sessions, messages), nil
 		}
-
-		agentTool = NewAgentTool(taskAgent, sessions, messages)
 	}
 
 	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
@@ -195,7 +197,7 @@ func NewAgent(
 		}
 
 		mcpToolsOnce.Do(func() {
-			mcpTools = doGetMCPTools(ctx, permissions, cfg)
+			mcpTools = doGetMCPTools(permissions, cfg)
 		})
 		allTools = append(allTools, mcpTools...)
 
@@ -203,10 +205,6 @@ func NewAgent(
 			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 		}
 
-		if agentTool != nil {
-			allTools = append(allTools, agentTool)
-		}
-
 		if agentCfg.AllowedTools == nil {
 			return allTools
 		}
@@ -230,6 +228,7 @@ func NewAgent(
 		titleProvider:       titleProvider,
 		summarizeProvider:   summarizeProvider,
 		summarizeProviderID: string(providerCfg.ID),
+		agentToolFn:         agentToolFn,
 		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 		tools:               csync.NewLazySlice(toolFn),
 		promptQueue:         csync.NewMap[string, []string](),
@@ -500,6 +499,18 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
 	})
 }
 
+func (a *agent) getAllTools() ([]tools.BaseTool, error) {
+	allTools := slices.Collect(a.tools.Seq())
+	if a.agentToolFn != nil {
+		agentTool, agentToolErr := a.agentToolFn()
+		if agentToolErr != nil {
+			return nil, agentToolErr
+		}
+		allTools = append(allTools, agentTool)
+	}
+	return allTools, nil
+}
+
 func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 
@@ -514,8 +525,12 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 	}
 
+	allTools, toolsErr := a.getAllTools()
+	if toolsErr != nil {
+		return assistantMsg, nil, toolsErr
+	}
 	// Now collect tools (which may block on MCP initialization)
-	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
+	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 
 	// Add the session and message ID into the context if needed by tools.
 	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
@@ -554,7 +569,8 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
 		default:
 			// Continue processing
 			var tool tools.BaseTool
-			for availableTool := range a.tools.Seq() {
+			allTools, _ := a.getAllTools()
+			for _, availableTool := range allTools {
 				if availableTool.Info().Name == toolCall.Name {
 					tool = availableTool
 					break

internal/llm/agent/mcp-tools.go 🔗

@@ -275,7 +275,7 @@ var mcpInitRequest = mcp.InitializeRequest{
 	},
 }
 
-func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func doGetMCPTools(permissions permission.Service, cfg *config.Config) []tools.BaseTool {
 	var wg sync.WaitGroup
 	result := csync.NewSlice[tools.BaseTool]()
 
@@ -309,7 +309,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
 				}
 			}()
 
-			ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
+			ctx, cancel := context.WithTimeout(context.Background(), mcpTimeout(m))
 			defer cancel()
 			c, err := createAndInitializeClient(ctx, name, m)
 			if err != nil {