internal/app/app.go 🔗
@@ -278,7 +278,6 @@ func (app *App) InitCoderAgent() error {
}
var err error
app.CoderAgent, err = agent.NewAgent(
- app.globalCtx,
coderAgentCfg,
app.Permissions,
app.Sessions,
kujtimiihoxha created
internal/app/app.go | 1
internal/llm/agent/agent.go | 52 ++++++++++++++++++++++------------
internal/llm/agent/mcp-tools.go | 4 +-
3 files changed, 36 insertions(+), 21 deletions(-)
@@ -278,7 +278,6 @@ func (app *App) InitCoderAgent() error {
}
var err error
app.CoderAgent, err = agent.NewAgent(
- app.globalCtx,
coderAgentCfg,
app.Permissions,
app.Sessions,
@@ -72,6 +72,8 @@ type agent struct {
mcpTools []McpTool
tools *csync.LazySlice[tools.BaseTool]
+ // We need this to be able to update it when model changes
+ agentToolFn func() (tools.BaseTool, error)
provider provider.Provider
providerID string
@@ -91,7 +93,6 @@ var agentPromptMap = map[string]prompt.PromptID{
}
func NewAgent(
- ctx context.Context,
agentCfg config.Agent,
// These services are needed in the tools
permissions permission.Service,
@@ -102,18 +103,19 @@ func NewAgent(
) (Service, error) {
cfg := config.Get()
- var agentTool tools.BaseTool
+ var agentToolFn func() (tools.BaseTool, error)
if agentCfg.ID == "coder" {
- taskAgentCfg := config.Get().Agents["task"]
- if taskAgentCfg.ID == "" {
- return nil, fmt.Errorf("task agent not found in config")
- }
- taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
- if err != nil {
- return nil, fmt.Errorf("failed to create task agent: %w", err)
+ agentToolFn = func() (tools.BaseTool, error) {
+ taskAgentCfg := config.Get().Agents["task"]
+ if taskAgentCfg.ID == "" {
+ return nil, fmt.Errorf("task agent not found in config")
+ }
+ taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create task agent: %w", err)
+ }
+ return NewAgentTool(taskAgent, sessions, messages), nil
}
-
- agentTool = NewAgentTool(taskAgent, sessions, messages)
}
providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
@@ -195,7 +197,7 @@ func NewAgent(
}
mcpToolsOnce.Do(func() {
- mcpTools = doGetMCPTools(ctx, permissions, cfg)
+ mcpTools = doGetMCPTools(permissions, cfg)
})
allTools = append(allTools, mcpTools...)
@@ -203,10 +205,6 @@ func NewAgent(
allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
}
- if agentTool != nil {
- allTools = append(allTools, agentTool)
- }
-
if agentCfg.AllowedTools == nil {
return allTools
}
@@ -230,6 +228,7 @@ func NewAgent(
titleProvider: titleProvider,
summarizeProvider: summarizeProvider,
summarizeProviderID: string(providerCfg.ID),
+ agentToolFn: agentToolFn,
activeRequests: csync.NewMap[string, context.CancelFunc](),
tools: csync.NewLazySlice(toolFn),
promptQueue: csync.NewMap[string, []string](),
@@ -500,6 +499,18 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
})
}
+func (a *agent) getAllTools() ([]tools.BaseTool, error) {
+ allTools := slices.Collect(a.tools.Seq())
+ if a.agentToolFn != nil {
+ agentTool, agentToolErr := a.agentToolFn()
+ if agentToolErr != nil {
+ return nil, agentToolErr
+ }
+ allTools = append(allTools, agentTool)
+ }
+ return allTools, nil
+}
+
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
@@ -514,8 +525,12 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
}
+ allTools, toolsErr := a.getAllTools()
+ if toolsErr != nil {
+ return assistantMsg, nil, toolsErr
+ }
// Now collect tools (which may block on MCP initialization)
- eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
+ eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
// Add the session and message ID into the context if needed by tools.
ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
@@ -554,7 +569,8 @@ func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msg
default:
// Continue processing
var tool tools.BaseTool
- for availableTool := range a.tools.Seq() {
+ allTools, _ := a.getAllTools()
+ for _, availableTool := range allTools {
if availableTool.Info().Name == toolCall.Name {
tool = availableTool
break
@@ -275,7 +275,7 @@ var mcpInitRequest = mcp.InitializeRequest{
},
}
-func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *config.Config) []tools.BaseTool {
+func doGetMCPTools(permissions permission.Service, cfg *config.Config) []tools.BaseTool {
var wg sync.WaitGroup
result := csync.NewSlice[tools.BaseTool]()
@@ -309,7 +309,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
}
}()
- ctx, cancel := context.WithTimeout(ctx, mcpTimeout(m))
+ ctx, cancel := context.WithTimeout(context.Background(), mcpTimeout(m))
defer cancel()
c, err := createAndInitializeClient(ctx, name, m)
if err != nil {