diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 5c9a95fb7f210625c9a4a04a803dcfc634f471a3..1a7286e342d245c7e7ac1161111d8c205300018b 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -55,6 +55,13 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) return fantasy.ToolResponse{}, errors.New("agent message id missing from context") } - return c.runSubAgent(ctx, agent, sessionID, agentMessageID, call.ID, params.Prompt, "New Agent Session") + return c.runSubAgent(ctx, subAgentParams{ + Agent: agent, + SessionID: sessionID, + AgentMessageID: agentMessageID, + ToolCallID: call.ID, + Prompt: params.Prompt, + SessionTitle: "New Agent Session", + }) }), nil } diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index 26ed301eaff7955dac95bd2e0043490730e46f85..0bd942e013b706389fb90352c891a4f2ea014f30 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -184,17 +184,16 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( Tools: fetchTools, }) - return c.runSubAgentWithOptions( - ctx, - agent, - validationResult.SessionID, - validationResult.AgentMessageID, - call.ID, - fullPrompt, - "Fetch Analysis", - func(sessionID string) { + return c.runSubAgent(ctx, subAgentParams{ + Agent: agent, + SessionID: validationResult.SessionID, + AgentMessageID: validationResult.AgentMessageID, + ToolCallID: call.ID, + Prompt: fullPrompt, + SessionTitle: "Fetch Analysis", + SessionSetup: func(sessionID string) { c.permissions.AutoApproveSession(sessionID) }, - ) + }) }), nil } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 57b42d12de65e3d194ee9b1adaf96392082e3f55..40076c34ee429816e93d5c5082f598a5dd02ec6c 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -941,43 +941,37 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con return nil } +// subAgentParams holds the parameters for running a sub-agent. +type subAgentParams struct { + Agent SessionAgent + SessionID string + AgentMessageID string + ToolCallID string + Prompt string + SessionTitle string + // SessionSetup is an optional callback invoked after session creation + // but before agent execution, for custom session configuration. + SessionSetup func(sessionID string) +} + // runSubAgent runs a sub-agent and handles session management and cost accumulation. // It creates a sub-session, runs the agent with the given prompt, and propagates // the cost to the parent session. -func (c *coordinator) runSubAgent( - ctx context.Context, - agent SessionAgent, - sessionID, agentMessageID, toolCallID string, - prompt string, - sessionTitle string, -) (fantasy.ToolResponse, error) { - return c.runSubAgentWithOptions(ctx, agent, sessionID, agentMessageID, toolCallID, prompt, sessionTitle, nil) -} - -// runSubAgentWithOptions runs a sub-agent with additional session configuration options. -// The sessionSetup function is called after session creation but before agent execution. -func (c *coordinator) runSubAgentWithOptions( - ctx context.Context, - agent SessionAgent, - sessionID, agentMessageID, toolCallID string, - prompt string, - sessionTitle string, - sessionSetup func(sessionID string), -) (fantasy.ToolResponse, error) { +func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) { // Create sub-session - agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, toolCallID) - session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, sessionTitle) + agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID) + session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle) if err != nil { return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err) } // Call session setup function if provided - if sessionSetup != nil { - sessionSetup(session.ID) + if params.SessionSetup != nil { + params.SessionSetup(session.ID) } // Get model configuration - model := agent.Model() + model := params.Agent.Model() maxTokens := model.CatwalkCfg.DefaultMaxTokens if model.ModelCfg.MaxTokens != 0 { maxTokens = model.ModelCfg.MaxTokens @@ -989,9 +983,9 @@ func (c *coordinator) runSubAgentWithOptions( } // Run the agent - result, err := agent.Run(ctx, SessionAgentCall{ + result, err := params.Agent.Run(ctx, SessionAgentCall{ SessionID: session.ID, - Prompt: prompt, + Prompt: params.Prompt, MaxOutputTokens: maxTokens, ProviderOptions: getProviderOptions(model, providerCfg), Temperature: model.ModelCfg.Temperature, @@ -1005,7 +999,7 @@ func (c *coordinator) runSubAgentWithOptions( } // Update parent session cost - if err := c.updateParentSessionCost(ctx, session.ID, sessionID); err != nil { + if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil { return fantasy.ToolResponse{}, err } diff --git a/internal/agent/coordinator_test.go b/internal/agent/coordinator_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3c270394cba9c1758e4a9029a149027af6bf36c2 --- /dev/null +++ b/internal/agent/coordinator_test.go @@ -0,0 +1,385 @@ +package agent + +import ( + "context" + "errors" + "testing" + + "charm.land/catwalk/pkg/catwalk" + "charm.land/fantasy" + "github.com/charmbracelet/crush/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockSessionAgent is a minimal mock for the SessionAgent interface. +type mockSessionAgent struct { + model Model + runFunc func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) + cancelled []string +} + +func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) { + return m.runFunc(ctx, call) +} + +func (m *mockSessionAgent) Model() Model { return m.model } +func (m *mockSessionAgent) SetModels(large, small Model) {} +func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {} +func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {} +func (m *mockSessionAgent) Cancel(sessionID string) { + m.cancelled = append(m.cancelled, sessionID) +} +func (m *mockSessionAgent) CancelAll() {} +func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool { return false } +func (m *mockSessionAgent) IsBusy() bool { return false } +func (m *mockSessionAgent) QueuedPrompts(sessionID string) int { return 0 } +func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil } +func (m *mockSessionAgent) ClearQueue(sessionID string) {} +func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error { + return nil +} + +// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent. +func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator { + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + cfg.Providers.Set(providerID, providerCfg) + return &coordinator{ + cfg: cfg, + sessions: env.sessions, + } +} + +// newMockAgent creates a mockSessionAgent with the given provider and run function. +func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent { + return &mockSessionAgent{ + model: Model{ + CatwalkCfg: catwalk.Model{ + DefaultMaxTokens: maxTokens, + }, + ModelCfg: config.SelectedModel{ + Provider: providerID, + }, + }, + runFunc: runFunc, + } +} + +// agentResultWithText creates a minimal AgentResult with the given text response. +func agentResultWithText(text string) *fantasy.AgentResult { + return &fantasy.AgentResult{ + Response: fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.TextContent{Text: text}, + }, + }, + } +} + +func TestRunSubAgent(t *testing.T) { + const providerID = "test-provider" + providerCfg := config.ProviderConfig{ID: providerID} + + t.Run("happy path", func(t *testing.T) { + env := testEnv(t) + coord := newTestCoordinator(t, env, providerID, providerCfg) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) { + assert.Equal(t, "do something", call.Prompt) + assert.Equal(t, int64(4096), call.MaxOutputTokens) + return agentResultWithText("done"), nil + }) + + resp, err := coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "do something", + SessionTitle: "Test Session", + }) + require.NoError(t, err) + assert.Equal(t, "done", resp.Content) + assert.False(t, resp.IsError) + }) + + t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) { + env := testEnv(t) + coord := newTestCoordinator(t, env, providerID, providerCfg) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + agent := &mockSessionAgent{ + model: Model{ + CatwalkCfg: catwalk.Model{ + DefaultMaxTokens: 4096, + }, + ModelCfg: config.SelectedModel{ + Provider: providerID, + MaxTokens: 8192, + }, + }, + runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) { + assert.Equal(t, int64(8192), call.MaxOutputTokens) + return agentResultWithText("ok"), nil + }, + } + + resp, err := coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "test", + SessionTitle: "Test", + }) + require.NoError(t, err) + assert.Equal(t, "ok", resp.Content) + }) + + t.Run("session creation failure with canceled context", func(t *testing.T) { + env := testEnv(t) + coord := newTestCoordinator(t, env, providerID, providerCfg) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + agent := newMockAgent(providerID, 4096, nil) + + // Use a canceled context to trigger CreateTaskSession failure. + ctx, cancel := context.WithCancel(t.Context()) + cancel() + + _, err = coord.runSubAgent(ctx, subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "test", + SessionTitle: "Test", + }) + require.Error(t, err) + }) + + t.Run("provider not configured", func(t *testing.T) { + env := testEnv(t) + coord := newTestCoordinator(t, env, providerID, providerCfg) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + // Agent references a provider that doesn't exist in config. + agent := newMockAgent("unknown-provider", 4096, nil) + + _, err = coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "test", + SessionTitle: "Test", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "model provider not configured") + }) + + t.Run("agent run error returns error response", func(t *testing.T) { + env := testEnv(t) + coord := newTestCoordinator(t, env, providerID, providerCfg) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) { + return nil, errors.New("agent exploded") + }) + + resp, err := coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "test", + SessionTitle: "Test", + }) + // runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error. + require.NoError(t, err) + assert.True(t, resp.IsError) + assert.Equal(t, "error generating response", resp.Content) + }) + + t.Run("session setup callback is invoked", func(t *testing.T) { + env := testEnv(t) + coord := newTestCoordinator(t, env, providerID, providerCfg) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + var setupCalledWith string + agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) { + return agentResultWithText("ok"), nil + }) + + _, err = coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "test", + SessionTitle: "Test", + SessionSetup: func(sessionID string) { + setupCalledWith = sessionID + }, + }) + require.NoError(t, err) + assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called") + }) + + t.Run("cost propagation to parent session", func(t *testing.T) { + env := testEnv(t) + coord := newTestCoordinator(t, env, providerID, providerCfg) + + parentSession, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) { + // Simulate the agent incurring cost by updating the child session. + childSession, err := env.sessions.Get(ctx, call.SessionID) + if err != nil { + return nil, err + } + childSession.Cost = 0.05 + _, err = env.sessions.Save(ctx, childSession) + if err != nil { + return nil, err + } + return agentResultWithText("ok"), nil + }) + + _, err = coord.runSubAgent(t.Context(), subAgentParams{ + Agent: agent, + SessionID: parentSession.ID, + AgentMessageID: "msg-1", + ToolCallID: "call-1", + Prompt: "test", + SessionTitle: "Test", + }) + require.NoError(t, err) + + updated, err := env.sessions.Get(t.Context(), parentSession.ID) + require.NoError(t, err) + assert.InDelta(t, 0.05, updated.Cost, 1e-9) + }) +} + +func TestUpdateParentSessionCost(t *testing.T) { + t.Run("accumulates cost correctly", func(t *testing.T) { + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + coord := &coordinator{cfg: cfg, sessions: env.sessions} + + parent, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child") + require.NoError(t, err) + + // Set child cost. + child.Cost = 0.10 + _, err = env.sessions.Save(t.Context(), child) + require.NoError(t, err) + + err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID) + require.NoError(t, err) + + updated, err := env.sessions.Get(t.Context(), parent.ID) + require.NoError(t, err) + assert.InDelta(t, 0.10, updated.Cost, 1e-9) + }) + + t.Run("accumulates multiple child costs", func(t *testing.T) { + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + coord := &coordinator{cfg: cfg, sessions: env.sessions} + + parent, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1") + require.NoError(t, err) + child1.Cost = 0.05 + _, err = env.sessions.Save(t.Context(), child1) + require.NoError(t, err) + + child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2") + require.NoError(t, err) + child2.Cost = 0.03 + _, err = env.sessions.Save(t.Context(), child2) + require.NoError(t, err) + + err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID) + require.NoError(t, err) + err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID) + require.NoError(t, err) + + updated, err := env.sessions.Get(t.Context(), parent.ID) + require.NoError(t, err) + assert.InDelta(t, 0.08, updated.Cost, 1e-9) + }) + + t.Run("child session not found", func(t *testing.T) { + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + coord := &coordinator{cfg: cfg, sessions: env.sessions} + + parent, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + + err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID) + require.Error(t, err) + assert.Contains(t, err.Error(), "get child session") + }) + + t.Run("parent session not found", func(t *testing.T) { + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + coord := &coordinator{cfg: cfg, sessions: env.sessions} + + parent, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child") + require.NoError(t, err) + + err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent") + require.Error(t, err) + assert.Contains(t, err.Error(), "get parent session") + }) + + t.Run("zero cost handled correctly", func(t *testing.T) { + env := testEnv(t) + cfg, err := config.Init(env.workingDir, "", false) + require.NoError(t, err) + coord := &coordinator{cfg: cfg, sessions: env.sessions} + + parent, err := env.sessions.Create(t.Context(), "Parent") + require.NoError(t, err) + child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child") + require.NoError(t, err) + + err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID) + require.NoError(t, err) + + updated, err := env.sessions.Get(t.Context(), parent.ID) + require.NoError(t, err) + assert.InDelta(t, 0.0, updated.Cost, 1e-9) + }) +}