@@ -8,6 +8,7 @@ import (
"slices"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/charmbracelet/crush/internal/config"
@@ -67,7 +68,9 @@ type agent struct {
sessions session.Service
messages message.Service
- tools []tools.BaseTool
+ toolsDone atomic.Bool
+ tools []tools.BaseTool
+
provider provider.Provider
providerID string
@@ -94,46 +97,7 @@ func NewAgent(
) (Service, error) {
ctx := context.Background()
cfg := config.Get()
- otherTools := GetMCPTools(ctx, permissions, cfg)
- if len(lspClients) > 0 {
- otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
- }
- cwd := cfg.WorkingDir()
- allTools := []tools.BaseTool{
- tools.NewBashTool(permissions, cwd),
- tools.NewDownloadTool(permissions, cwd),
- tools.NewEditTool(lspClients, permissions, history, cwd),
- tools.NewFetchTool(permissions, cwd),
- tools.NewGlobTool(cwd),
- tools.NewGrepTool(cwd),
- tools.NewLsTool(cwd),
- tools.NewSourcegraphTool(),
- tools.NewViewTool(lspClients, cwd),
- tools.NewWriteTool(lspClients, permissions, history, cwd),
- }
-
- if agentCfg.ID == "coder" {
- taskAgentCfg := config.Get().Agents["task"]
- if taskAgentCfg.ID == "" {
- return nil, fmt.Errorf("task agent not found in config")
- }
- taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
- if err != nil {
- return nil, fmt.Errorf("failed to create task agent: %w", err)
- }
-
- allTools = append(
- allTools,
- NewAgentTool(
- taskAgent,
- sessions,
- messages,
- ),
- )
- }
-
- allTools = append(allTools, otherTools...)
providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
if providerCfg == nil {
return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
@@ -190,15 +154,22 @@ func NewAgent(
return nil, err
}
- agentTools := []tools.BaseTool{}
- if agentCfg.AllowedTools == nil {
- agentTools = allTools
- } else {
- for _, tool := range allTools {
- if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
- agentTools = append(agentTools, tool)
- }
+ var agentTool tools.BaseTool
+ if agentCfg.ID == "coder" {
+ taskAgentCfg := config.Get().Agents["task"]
+ if taskAgentCfg.ID == "" {
+ return nil, fmt.Errorf("task agent not found in config")
}
+ taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create task agent: %w", err)
+ }
+
+ agentTool = NewAgentTool(
+ taskAgent,
+ sessions,
+ messages,
+ )
}
agent := &agent{
@@ -208,13 +179,55 @@ func NewAgent(
providerID: string(providerCfg.ID),
messages: messages,
sessions: sessions,
- tools: agentTools,
titleProvider: titleProvider,
summarizeProvider: summarizeProvider,
summarizeProviderID: string(smallModelProviderCfg.ID),
activeRequests: sync.Map{},
}
+ go func() {
+ slog.Info("Initializing agent tools", "agent", agentCfg.ID)
+
+ cwd := cfg.WorkingDir()
+ allTools := []tools.BaseTool{
+ tools.NewBashTool(permissions, cwd),
+ tools.NewDownloadTool(permissions, cwd),
+ tools.NewEditTool(lspClients, permissions, history, cwd),
+ tools.NewFetchTool(permissions, cwd),
+ tools.NewGlobTool(cwd),
+ tools.NewGrepTool(cwd),
+ tools.NewLsTool(cwd),
+ tools.NewSourcegraphTool(),
+ tools.NewViewTool(lspClients, cwd),
+ tools.NewWriteTool(lspClients, permissions, history, cwd),
+ }
+
+ mcpTools := GetMCPTools(ctx, permissions, cfg)
+ if len(lspClients) > 0 {
+ mcpTools = append(mcpTools, tools.NewDiagnosticsTool(lspClients))
+ }
+ allTools = append(allTools, mcpTools...)
+
+ if agentTool != nil {
+ allTools = append(allTools, agentTool)
+ }
+
+ agentTools := []tools.BaseTool{}
+ if agentCfg.AllowedTools == nil {
+ agentTools = allTools
+ } else {
+ for _, tool := range allTools {
+ if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
+ agentTools = append(agentTools, tool)
+ }
+ }
+ }
+
+ slog.Info("Initialized agent tools", "agent", agentCfg.ID)
+ agent.tools = agentTools
+ agent.toolsDone.Store(true)
+ }()
+
return agent, nil
}
@@ -437,6 +450,9 @@ func (a *agent) createUserMessage(ctx context.Context, sessionID, content string
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
+ if !a.toolsDone.Load() {
+ return message.Message{}, nil, fmt.Errorf("tools not initialized yet")
+ }
eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{