refactor: use params struct for runSubAgent and add unit tests

wanghuaiyu@qiniu.com created

Change summary

internal/agent/agent_tool.go         |   9 
internal/agent/agentic_fetch_tool.go |  19 
internal/agent/coordinator.go        |  50 +--
internal/agent/coordinator_test.go   | 385 ++++++++++++++++++++++++++++++
4 files changed, 424 insertions(+), 39 deletions(-)

Detailed changes

internal/agent/agent_tool.go 🔗

@@ -55,6 +55,13 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error)
 				return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
 			}
 
-			return c.runSubAgent(ctx, agent, sessionID, agentMessageID, call.ID, params.Prompt, "New Agent Session")
+			return c.runSubAgent(ctx, subAgentParams{
+				Agent:          agent,
+				SessionID:      sessionID,
+				AgentMessageID: agentMessageID,
+				ToolCallID:     call.ID,
+				Prompt:         params.Prompt,
+				SessionTitle:   "New Agent Session",
+			})
 		}), nil
 }

internal/agent/agentic_fetch_tool.go 🔗

@@ -184,17 +184,16 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
 				Tools:                fetchTools,
 			})
 
-			return c.runSubAgentWithOptions(
-				ctx,
-				agent,
-				validationResult.SessionID,
-				validationResult.AgentMessageID,
-				call.ID,
-				fullPrompt,
-				"Fetch Analysis",
-				func(sessionID string) {
+			return c.runSubAgent(ctx, subAgentParams{
+				Agent:          agent,
+				SessionID:      validationResult.SessionID,
+				AgentMessageID: validationResult.AgentMessageID,
+				ToolCallID:     call.ID,
+				Prompt:         fullPrompt,
+				SessionTitle:   "Fetch Analysis",
+				SessionSetup: func(sessionID string) {
 					c.permissions.AutoApproveSession(sessionID)
 				},
-			)
+			})
 		}), nil
 }

internal/agent/coordinator.go 🔗

@@ -941,43 +941,37 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con
 	return nil
 }
 
+// subAgentParams holds the parameters for running a sub-agent.
+type subAgentParams struct {
+	Agent          SessionAgent
+	SessionID      string
+	AgentMessageID string
+	ToolCallID     string
+	Prompt         string
+	SessionTitle   string
+	// SessionSetup is an optional callback invoked after session creation
+	// but before agent execution, for custom session configuration.
+	SessionSetup func(sessionID string)
+}
+
 // runSubAgent runs a sub-agent and handles session management and cost accumulation.
 // It creates a sub-session, runs the agent with the given prompt, and propagates
 // the cost to the parent session.
-func (c *coordinator) runSubAgent(
-	ctx context.Context,
-	agent SessionAgent,
-	sessionID, agentMessageID, toolCallID string,
-	prompt string,
-	sessionTitle string,
-) (fantasy.ToolResponse, error) {
-	return c.runSubAgentWithOptions(ctx, agent, sessionID, agentMessageID, toolCallID, prompt, sessionTitle, nil)
-}
-
-// runSubAgentWithOptions runs a sub-agent with additional session configuration options.
-// The sessionSetup function is called after session creation but before agent execution.
-func (c *coordinator) runSubAgentWithOptions(
-	ctx context.Context,
-	agent SessionAgent,
-	sessionID, agentMessageID, toolCallID string,
-	prompt string,
-	sessionTitle string,
-	sessionSetup func(sessionID string),
-) (fantasy.ToolResponse, error) {
+func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
 	// Create sub-session
-	agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, toolCallID)
-	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, sessionTitle)
+	agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
+	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
 	if err != nil {
 		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
 	}
 
 	// Call session setup function if provided
-	if sessionSetup != nil {
-		sessionSetup(session.ID)
+	if params.SessionSetup != nil {
+		params.SessionSetup(session.ID)
 	}
 
 	// Get model configuration
-	model := agent.Model()
+	model := params.Agent.Model()
 	maxTokens := model.CatwalkCfg.DefaultMaxTokens
 	if model.ModelCfg.MaxTokens != 0 {
 		maxTokens = model.ModelCfg.MaxTokens
@@ -989,9 +983,9 @@ func (c *coordinator) runSubAgentWithOptions(
 	}
 
 	// Run the agent
-	result, err := agent.Run(ctx, SessionAgentCall{
+	result, err := params.Agent.Run(ctx, SessionAgentCall{
 		SessionID:        session.ID,
-		Prompt:           prompt,
+		Prompt:           params.Prompt,
 		MaxOutputTokens:  maxTokens,
 		ProviderOptions:  getProviderOptions(model, providerCfg),
 		Temperature:      model.ModelCfg.Temperature,
@@ -1005,7 +999,7 @@ func (c *coordinator) runSubAgentWithOptions(
 	}
 
 	// Update parent session cost
-	if err := c.updateParentSessionCost(ctx, session.ID, sessionID); err != nil {
+	if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
 		return fantasy.ToolResponse{}, err
 	}
 

internal/agent/coordinator_test.go 🔗

@@ -0,0 +1,385 @@
+package agent
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"charm.land/catwalk/pkg/catwalk"
+	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// mockSessionAgent is a minimal mock for the SessionAgent interface.
+type mockSessionAgent struct {
+	model     Model
+	runFunc   func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
+	cancelled []string
+}
+
+func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+	return m.runFunc(ctx, call)
+}
+
+func (m *mockSessionAgent) Model() Model                        { return m.model }
+func (m *mockSessionAgent) SetModels(large, small Model)        {}
+func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool)  {}
+func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
+func (m *mockSessionAgent) Cancel(sessionID string) {
+	m.cancelled = append(m.cancelled, sessionID)
+}
+func (m *mockSessionAgent) CancelAll()                                  {}
+func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool         { return false }
+func (m *mockSessionAgent) IsBusy() bool                                { return false }
+func (m *mockSessionAgent) QueuedPrompts(sessionID string) int          { return 0 }
+func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
+func (m *mockSessionAgent) ClearQueue(sessionID string)                 {}
+func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
+	return nil
+}
+
+// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
+func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
+	cfg, err := config.Init(env.workingDir, "", false)
+	require.NoError(t, err)
+	cfg.Providers.Set(providerID, providerCfg)
+	return &coordinator{
+		cfg:      cfg,
+		sessions: env.sessions,
+	}
+}
+
+// newMockAgent creates a mockSessionAgent with the given provider and run function.
+func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
+	return &mockSessionAgent{
+		model: Model{
+			CatwalkCfg: catwalk.Model{
+				DefaultMaxTokens: maxTokens,
+			},
+			ModelCfg: config.SelectedModel{
+				Provider: providerID,
+			},
+		},
+		runFunc: runFunc,
+	}
+}
+
+// agentResultWithText creates a minimal AgentResult with the given text response.
+func agentResultWithText(text string) *fantasy.AgentResult {
+	return &fantasy.AgentResult{
+		Response: fantasy.Response{
+			Content: fantasy.ResponseContent{
+				fantasy.TextContent{Text: text},
+			},
+		},
+	}
+}
+
+func TestRunSubAgent(t *testing.T) {
+	const providerID = "test-provider"
+	providerCfg := config.ProviderConfig{ID: providerID}
+
+	t.Run("happy path", func(t *testing.T) {
+		env := testEnv(t)
+		coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+		parentSession, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+			assert.Equal(t, "do something", call.Prompt)
+			assert.Equal(t, int64(4096), call.MaxOutputTokens)
+			return agentResultWithText("done"), nil
+		})
+
+		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
+			Agent:          agent,
+			SessionID:      parentSession.ID,
+			AgentMessageID: "msg-1",
+			ToolCallID:     "call-1",
+			Prompt:         "do something",
+			SessionTitle:   "Test Session",
+		})
+		require.NoError(t, err)
+		assert.Equal(t, "done", resp.Content)
+		assert.False(t, resp.IsError)
+	})
+
+	t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
+		env := testEnv(t)
+		coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+		parentSession, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		agent := &mockSessionAgent{
+			model: Model{
+				CatwalkCfg: catwalk.Model{
+					DefaultMaxTokens: 4096,
+				},
+				ModelCfg: config.SelectedModel{
+					Provider:  providerID,
+					MaxTokens: 8192,
+				},
+			},
+			runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+				assert.Equal(t, int64(8192), call.MaxOutputTokens)
+				return agentResultWithText("ok"), nil
+			},
+		}
+
+		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
+			Agent:          agent,
+			SessionID:      parentSession.ID,
+			AgentMessageID: "msg-1",
+			ToolCallID:     "call-1",
+			Prompt:         "test",
+			SessionTitle:   "Test",
+		})
+		require.NoError(t, err)
+		assert.Equal(t, "ok", resp.Content)
+	})
+
+	t.Run("session creation failure with canceled context", func(t *testing.T) {
+		env := testEnv(t)
+		coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+		parentSession, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		agent := newMockAgent(providerID, 4096, nil)
+
+		// Use a canceled context to trigger CreateTaskSession failure.
+		ctx, cancel := context.WithCancel(t.Context())
+		cancel()
+
+		_, err = coord.runSubAgent(ctx, subAgentParams{
+			Agent:          agent,
+			SessionID:      parentSession.ID,
+			AgentMessageID: "msg-1",
+			ToolCallID:     "call-1",
+			Prompt:         "test",
+			SessionTitle:   "Test",
+		})
+		require.Error(t, err)
+	})
+
+	t.Run("provider not configured", func(t *testing.T) {
+		env := testEnv(t)
+		coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+		parentSession, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		// Agent references a provider that doesn't exist in config.
+		agent := newMockAgent("unknown-provider", 4096, nil)
+
+		_, err = coord.runSubAgent(t.Context(), subAgentParams{
+			Agent:          agent,
+			SessionID:      parentSession.ID,
+			AgentMessageID: "msg-1",
+			ToolCallID:     "call-1",
+			Prompt:         "test",
+			SessionTitle:   "Test",
+		})
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "model provider not configured")
+	})
+
+	t.Run("agent run error returns error response", func(t *testing.T) {
+		env := testEnv(t)
+		coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+		parentSession, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
+			return nil, errors.New("agent exploded")
+		})
+
+		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
+			Agent:          agent,
+			SessionID:      parentSession.ID,
+			AgentMessageID: "msg-1",
+			ToolCallID:     "call-1",
+			Prompt:         "test",
+			SessionTitle:   "Test",
+		})
+		// runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
+		require.NoError(t, err)
+		assert.True(t, resp.IsError)
+		assert.Equal(t, "error generating response", resp.Content)
+	})
+
+	t.Run("session setup callback is invoked", func(t *testing.T) {
+		env := testEnv(t)
+		coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+		parentSession, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		var setupCalledWith string
+		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
+			return agentResultWithText("ok"), nil
+		})
+
+		_, err = coord.runSubAgent(t.Context(), subAgentParams{
+			Agent:          agent,
+			SessionID:      parentSession.ID,
+			AgentMessageID: "msg-1",
+			ToolCallID:     "call-1",
+			Prompt:         "test",
+			SessionTitle:   "Test",
+			SessionSetup: func(sessionID string) {
+				setupCalledWith = sessionID
+			},
+		})
+		require.NoError(t, err)
+		assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
+	})
+
+	t.Run("cost propagation to parent session", func(t *testing.T) {
+		env := testEnv(t)
+		coord := newTestCoordinator(t, env, providerID, providerCfg)
+
+		parentSession, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+			// Simulate the agent incurring cost by updating the child session.
+			childSession, err := env.sessions.Get(ctx, call.SessionID)
+			if err != nil {
+				return nil, err
+			}
+			childSession.Cost = 0.05
+			_, err = env.sessions.Save(ctx, childSession)
+			if err != nil {
+				return nil, err
+			}
+			return agentResultWithText("ok"), nil
+		})
+
+		_, err = coord.runSubAgent(t.Context(), subAgentParams{
+			Agent:          agent,
+			SessionID:      parentSession.ID,
+			AgentMessageID: "msg-1",
+			ToolCallID:     "call-1",
+			Prompt:         "test",
+			SessionTitle:   "Test",
+		})
+		require.NoError(t, err)
+
+		updated, err := env.sessions.Get(t.Context(), parentSession.ID)
+		require.NoError(t, err)
+		assert.InDelta(t, 0.05, updated.Cost, 1e-9)
+	})
+}
+
+func TestUpdateParentSessionCost(t *testing.T) {
+	t.Run("accumulates cost correctly", func(t *testing.T) {
+		env := testEnv(t)
+		cfg, err := config.Init(env.workingDir, "", false)
+		require.NoError(t, err)
+		coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+		parent, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
+		require.NoError(t, err)
+
+		// Set child cost.
+		child.Cost = 0.10
+		_, err = env.sessions.Save(t.Context(), child)
+		require.NoError(t, err)
+
+		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
+		require.NoError(t, err)
+
+		updated, err := env.sessions.Get(t.Context(), parent.ID)
+		require.NoError(t, err)
+		assert.InDelta(t, 0.10, updated.Cost, 1e-9)
+	})
+
+	t.Run("accumulates multiple child costs", func(t *testing.T) {
+		env := testEnv(t)
+		cfg, err := config.Init(env.workingDir, "", false)
+		require.NoError(t, err)
+		coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+		parent, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
+		require.NoError(t, err)
+		child1.Cost = 0.05
+		_, err = env.sessions.Save(t.Context(), child1)
+		require.NoError(t, err)
+
+		child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
+		require.NoError(t, err)
+		child2.Cost = 0.03
+		_, err = env.sessions.Save(t.Context(), child2)
+		require.NoError(t, err)
+
+		err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
+		require.NoError(t, err)
+		err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
+		require.NoError(t, err)
+
+		updated, err := env.sessions.Get(t.Context(), parent.ID)
+		require.NoError(t, err)
+		assert.InDelta(t, 0.08, updated.Cost, 1e-9)
+	})
+
+	t.Run("child session not found", func(t *testing.T) {
+		env := testEnv(t)
+		cfg, err := config.Init(env.workingDir, "", false)
+		require.NoError(t, err)
+		coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+		parent, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+
+		err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "get child session")
+	})
+
+	t.Run("parent session not found", func(t *testing.T) {
+		env := testEnv(t)
+		cfg, err := config.Init(env.workingDir, "", false)
+		require.NoError(t, err)
+		coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+		parent, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
+		require.NoError(t, err)
+
+		err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
+		require.Error(t, err)
+		assert.Contains(t, err.Error(), "get parent session")
+	})
+
+	t.Run("zero cost handled correctly", func(t *testing.T) {
+		env := testEnv(t)
+		cfg, err := config.Init(env.workingDir, "", false)
+		require.NoError(t, err)
+		coord := &coordinator{cfg: cfg, sessions: env.sessions}
+
+		parent, err := env.sessions.Create(t.Context(), "Parent")
+		require.NoError(t, err)
+		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
+		require.NoError(t, err)
+
+		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
+		require.NoError(t, err)
+
+		updated, err := env.sessions.Get(t.Context(), parent.ID)
+		require.NoError(t, err)
+		assert.InDelta(t, 0.0, updated.Cost, 1e-9)
+	})
+}