@@ -941,43 +941,37 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con
return nil
}
+// subAgentParams holds the parameters for running a sub-agent.
+type subAgentParams struct {
+ Agent SessionAgent
+ SessionID string
+ AgentMessageID string
+ ToolCallID string
+ Prompt string
+ SessionTitle string
+ // SessionSetup is an optional callback invoked after session creation
+ // but before agent execution, for custom session configuration.
+ SessionSetup func(sessionID string)
+}
+
// runSubAgent runs a sub-agent and handles session management and cost accumulation.
// It creates a sub-session, runs the agent with the given prompt, and propagates
// the cost to the parent session.
-func (c *coordinator) runSubAgent(
- ctx context.Context,
- agent SessionAgent,
- sessionID, agentMessageID, toolCallID string,
- prompt string,
- sessionTitle string,
-) (fantasy.ToolResponse, error) {
- return c.runSubAgentWithOptions(ctx, agent, sessionID, agentMessageID, toolCallID, prompt, sessionTitle, nil)
-}
-
-// runSubAgentWithOptions runs a sub-agent with additional session configuration options.
-// The sessionSetup function is called after session creation but before agent execution.
-func (c *coordinator) runSubAgentWithOptions(
- ctx context.Context,
- agent SessionAgent,
- sessionID, agentMessageID, toolCallID string,
- prompt string,
- sessionTitle string,
- sessionSetup func(sessionID string),
-) (fantasy.ToolResponse, error) {
+func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
// Create sub-session
- agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, toolCallID)
- session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, sessionTitle)
+ agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
+ session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
}
// Call session setup function if provided
- if sessionSetup != nil {
- sessionSetup(session.ID)
+ if params.SessionSetup != nil {
+ params.SessionSetup(session.ID)
}
// Get model configuration
- model := agent.Model()
+ model := params.Agent.Model()
maxTokens := model.CatwalkCfg.DefaultMaxTokens
if model.ModelCfg.MaxTokens != 0 {
maxTokens = model.ModelCfg.MaxTokens
@@ -989,9 +983,9 @@ func (c *coordinator) runSubAgentWithOptions(
}
// Run the agent
- result, err := agent.Run(ctx, SessionAgentCall{
+ result, err := params.Agent.Run(ctx, SessionAgentCall{
SessionID: session.ID,
- Prompt: prompt,
+ Prompt: params.Prompt,
MaxOutputTokens: maxTokens,
ProviderOptions: getProviderOptions(model, providerCfg),
Temperature: model.ModelCfg.Temperature,
@@ -1005,7 +999,7 @@ func (c *coordinator) runSubAgentWithOptions(
}
// Update parent session cost
- if err := c.updateParentSessionCost(ctx, session.ID, sessionID); err != nil {
+ if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
return fantasy.ToolResponse{}, err
}
@@ -0,0 +1,385 @@
+package agent
+
+import (
+ "context"
+ "errors"
+ "testing"
+
+ "charm.land/catwalk/pkg/catwalk"
+ "charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// mockSessionAgent is a minimal mock for the SessionAgent interface.
+type mockSessionAgent struct {
+ model Model
+ runFunc func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
+ cancelled []string
+}
+
+func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+ return m.runFunc(ctx, call)
+}
+
+func (m *mockSessionAgent) Model() Model { return m.model }
+func (m *mockSessionAgent) SetModels(large, small Model) {}
+func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {}
+func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
+func (m *mockSessionAgent) Cancel(sessionID string) {
+ m.cancelled = append(m.cancelled, sessionID)
+}
+func (m *mockSessionAgent) CancelAll() {}
+func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool { return false }
+func (m *mockSessionAgent) IsBusy() bool { return false }
+func (m *mockSessionAgent) QueuedPrompts(sessionID string) int { return 0 }
+func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
+func (m *mockSessionAgent) ClearQueue(sessionID string) {}
+func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
+ return nil
+}
+
+// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
+func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
+ cfg, err := config.Init(env.workingDir, "", false)
+ require.NoError(t, err)
+ cfg.Providers.Set(providerID, providerCfg)
+ return &coordinator{
+ cfg: cfg,
+ sessions: env.sessions,
+ }
+}
+
+// newMockAgent creates a mockSessionAgent with the given provider and run function.
+func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
+ return &mockSessionAgent{
+ model: Model{
+ CatwalkCfg: catwalk.Model{
+ DefaultMaxTokens: maxTokens,
+ },
+ ModelCfg: config.SelectedModel{
+ Provider: providerID,
+ },
+ },
+ runFunc: runFunc,
+ }
+}
+
+// agentResultWithText creates a minimal AgentResult with the given text response.
+func agentResultWithText(text string) *fantasy.AgentResult {
+ return &fantasy.AgentResult{
+ Response: fantasy.Response{
+ Content: fantasy.ResponseContent{
+ fantasy.TextContent{Text: text},
+ },
+ },
+ }
+}
+
+func TestRunSubAgent(t *testing.T) {
+ const providerID = "test-provider"
+ providerCfg := config.ProviderConfig{ID: providerID}
+
+ t.Run("happy path", func(t *testing.T) {
+ env := testEnv(t)
+ coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+ parentSession, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+ assert.Equal(t, "do something", call.Prompt)
+ assert.Equal(t, int64(4096), call.MaxOutputTokens)
+ return agentResultWithText("done"), nil
+ })
+
+ resp, err := coord.runSubAgent(t.Context(), subAgentParams{
+ Agent: agent,
+ SessionID: parentSession.ID,
+ AgentMessageID: "msg-1",
+ ToolCallID: "call-1",
+ Prompt: "do something",
+ SessionTitle: "Test Session",
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "done", resp.Content)
+ assert.False(t, resp.IsError)
+ })
+
+ t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
+ env := testEnv(t)
+ coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+ parentSession, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ agent := &mockSessionAgent{
+ model: Model{
+ CatwalkCfg: catwalk.Model{
+ DefaultMaxTokens: 4096,
+ },
+ ModelCfg: config.SelectedModel{
+ Provider: providerID,
+ MaxTokens: 8192,
+ },
+ },
+ runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+ assert.Equal(t, int64(8192), call.MaxOutputTokens)
+ return agentResultWithText("ok"), nil
+ },
+ }
+
+ resp, err := coord.runSubAgent(t.Context(), subAgentParams{
+ Agent: agent,
+ SessionID: parentSession.ID,
+ AgentMessageID: "msg-1",
+ ToolCallID: "call-1",
+ Prompt: "test",
+ SessionTitle: "Test",
+ })
+ require.NoError(t, err)
+ assert.Equal(t, "ok", resp.Content)
+ })
+
+ t.Run("session creation failure with canceled context", func(t *testing.T) {
+ env := testEnv(t)
+ coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+ parentSession, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ agent := newMockAgent(providerID, 4096, nil)
+
+ // Use a canceled context to trigger CreateTaskSession failure.
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel()
+
+ _, err = coord.runSubAgent(ctx, subAgentParams{
+ Agent: agent,
+ SessionID: parentSession.ID,
+ AgentMessageID: "msg-1",
+ ToolCallID: "call-1",
+ Prompt: "test",
+ SessionTitle: "Test",
+ })
+ require.Error(t, err)
+ })
+
+ t.Run("provider not configured", func(t *testing.T) {
+ env := testEnv(t)
+ coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+ parentSession, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ // Agent references a provider that doesn't exist in config.
+ agent := newMockAgent("unknown-provider", 4096, nil)
+
+ _, err = coord.runSubAgent(t.Context(), subAgentParams{
+ Agent: agent,
+ SessionID: parentSession.ID,
+ AgentMessageID: "msg-1",
+ ToolCallID: "call-1",
+ Prompt: "test",
+ SessionTitle: "Test",
+ })
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "model provider not configured")
+ })
+
+ t.Run("agent run error returns error response", func(t *testing.T) {
+ env := testEnv(t)
+ coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+ parentSession, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
+ return nil, errors.New("agent exploded")
+ })
+
+ resp, err := coord.runSubAgent(t.Context(), subAgentParams{
+ Agent: agent,
+ SessionID: parentSession.ID,
+ AgentMessageID: "msg-1",
+ ToolCallID: "call-1",
+ Prompt: "test",
+ SessionTitle: "Test",
+ })
+ // runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
+ require.NoError(t, err)
+ assert.True(t, resp.IsError)
+ assert.Equal(t, "error generating response", resp.Content)
+ })
+
+ t.Run("session setup callback is invoked", func(t *testing.T) {
+ env := testEnv(t)
+ coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+ parentSession, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ var setupCalledWith string
+ agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
+ return agentResultWithText("ok"), nil
+ })
+
+ _, err = coord.runSubAgent(t.Context(), subAgentParams{
+ Agent: agent,
+ SessionID: parentSession.ID,
+ AgentMessageID: "msg-1",
+ ToolCallID: "call-1",
+ Prompt: "test",
+ SessionTitle: "Test",
+ SessionSetup: func(sessionID string) {
+ setupCalledWith = sessionID
+ },
+ })
+ require.NoError(t, err)
+ assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
+ })
+
+ t.Run("cost propagation to parent session", func(t *testing.T) {
+ env := testEnv(t)
+ coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+ parentSession, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+ // Simulate the agent incurring cost by updating the child session.
+ childSession, err := env.sessions.Get(ctx, call.SessionID)
+ if err != nil {
+ return nil, err
+ }
+ childSession.Cost = 0.05
+ _, err = env.sessions.Save(ctx, childSession)
+ if err != nil {
+ return nil, err
+ }
+ return agentResultWithText("ok"), nil
+ })
+
+ _, err = coord.runSubAgent(t.Context(), subAgentParams{
+ Agent: agent,
+ SessionID: parentSession.ID,
+ AgentMessageID: "msg-1",
+ ToolCallID: "call-1",
+ Prompt: "test",
+ SessionTitle: "Test",
+ })
+ require.NoError(t, err)
+
+ updated, err := env.sessions.Get(t.Context(), parentSession.ID)
+ require.NoError(t, err)
+ assert.InDelta(t, 0.05, updated.Cost, 1e-9)
+ })
+}
+
+func TestUpdateParentSessionCost(t *testing.T) {
+ t.Run("accumulates cost correctly", func(t *testing.T) {
+ env := testEnv(t)
+ cfg, err := config.Init(env.workingDir, "", false)
+ require.NoError(t, err)
+ coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+ parent, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
+ require.NoError(t, err)
+
+ // Set child cost.
+ child.Cost = 0.10
+ _, err = env.sessions.Save(t.Context(), child)
+ require.NoError(t, err)
+
+ err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
+ require.NoError(t, err)
+
+ updated, err := env.sessions.Get(t.Context(), parent.ID)
+ require.NoError(t, err)
+ assert.InDelta(t, 0.10, updated.Cost, 1e-9)
+ })
+
+ t.Run("accumulates multiple child costs", func(t *testing.T) {
+ env := testEnv(t)
+ cfg, err := config.Init(env.workingDir, "", false)
+ require.NoError(t, err)
+ coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+ parent, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
+ require.NoError(t, err)
+ child1.Cost = 0.05
+ _, err = env.sessions.Save(t.Context(), child1)
+ require.NoError(t, err)
+
+ child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
+ require.NoError(t, err)
+ child2.Cost = 0.03
+ _, err = env.sessions.Save(t.Context(), child2)
+ require.NoError(t, err)
+
+ err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
+ require.NoError(t, err)
+ err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
+ require.NoError(t, err)
+
+ updated, err := env.sessions.Get(t.Context(), parent.ID)
+ require.NoError(t, err)
+ assert.InDelta(t, 0.08, updated.Cost, 1e-9)
+ })
+
+ t.Run("child session not found", func(t *testing.T) {
+ env := testEnv(t)
+ cfg, err := config.Init(env.workingDir, "", false)
+ require.NoError(t, err)
+ coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+ parent, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+
+ err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "get child session")
+ })
+
+ t.Run("parent session not found", func(t *testing.T) {
+ env := testEnv(t)
+ cfg, err := config.Init(env.workingDir, "", false)
+ require.NoError(t, err)
+ coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+ parent, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+ child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
+ require.NoError(t, err)
+
+ err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "get parent session")
+ })
+
+ t.Run("zero cost handled correctly", func(t *testing.T) {
+ env := testEnv(t)
+ cfg, err := config.Init(env.workingDir, "", false)
+ require.NoError(t, err)
+ coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+ parent, err := env.sessions.Create(t.Context(), "Parent")
+ require.NoError(t, err)
+ child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
+ require.NoError(t, err)
+
+ err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
+ require.NoError(t, err)
+
+ updated, err := env.sessions.Get(t.Context(), parent.ID)
+ require.NoError(t, err)
+ assert.InDelta(t, 0.0, updated.Cost, 1e-9)
+ })
+}