refactor: extract common sub-agent execution logic

wanghuaiyu@qiniu.com created

- Add runSubAgent and runSubAgentWithOptions methods to coordinator
- Simplify agent_tool.go from 106 to 60 lines (43.4% reduction)
- Simplify agentic_fetch_tool.go from 235 to 200 lines (14.9% reduction)
- Eliminate 81 lines of duplicated session management code
- Support custom session setup via callback for special cases
- Improve error handling with proper error wrapping
- Add updateParentSessionCost helper for consistent cost propagation

Change summary

internal/agent/agent_tool.go         | 47 --------------
internal/agent/agentic_fetch_tool.go | 58 +++--------------
internal/agent/coordinator.go        | 92 +++++++++++++++++++++++++++++
3 files changed, 105 insertions(+), 92 deletions(-)

Detailed changes

internal/agent/agent_tool.go 🔗

@@ -4,7 +4,6 @@ import (
 	"context"
 	_ "embed"
 	"errors"
-	"fmt"
 
 	"charm.land/fantasy"
 
@@ -56,50 +55,6 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error)
 				return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
 			}
 
-			agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
-			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
-			}
-			model := agent.Model()
-			maxTokens := model.CatwalkCfg.DefaultMaxTokens
-			if model.ModelCfg.MaxTokens != 0 {
-				maxTokens = model.ModelCfg.MaxTokens
-			}
-
-			providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
-			if !ok {
-				return fantasy.ToolResponse{}, errors.New("model provider not configured")
-			}
-			result, err := agent.Run(ctx, SessionAgentCall{
-				SessionID:        session.ID,
-				Prompt:           params.Prompt,
-				MaxOutputTokens:  maxTokens,
-				ProviderOptions:  getProviderOptions(model, providerCfg),
-				Temperature:      model.ModelCfg.Temperature,
-				TopP:             model.ModelCfg.TopP,
-				TopK:             model.ModelCfg.TopK,
-				FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
-				PresencePenalty:  model.ModelCfg.PresencePenalty,
-			})
-			if err != nil {
-				return fantasy.NewTextErrorResponse("error generating response"), nil
-			}
-			updatedSession, err := c.sessions.Get(ctx, session.ID)
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
-			}
-			parentSession, err := c.sessions.Get(ctx, sessionID)
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
-			}
-
-			parentSession.Cost += updatedSession.Cost
-
-			_, err = c.sessions.Save(ctx, parentSession)
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
-			}
-			return fantasy.NewTextResponse(result.Response.Content.Text()), nil
+			return c.runSubAgent(ctx, agent, sessionID, agentMessageID, call.ID, params.Prompt, "New Agent Session")
 		}), nil
 }

internal/agent/agentic_fetch_tool.go 🔗

@@ -184,51 +184,17 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
 				Tools:                fetchTools,
 			})
 
-			agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
-			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Fetch Analysis")
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
-			}
-
-			c.permissions.AutoApproveSession(session.ID)
-
-			// Use small model for web content analysis (faster and cheaper)
-			maxTokens := small.CatwalkCfg.DefaultMaxTokens
-			if small.ModelCfg.MaxTokens != 0 {
-				maxTokens = small.ModelCfg.MaxTokens
-			}
-
-			result, err := agent.Run(ctx, SessionAgentCall{
-				SessionID:        session.ID,
-				Prompt:           fullPrompt,
-				MaxOutputTokens:  maxTokens,
-				ProviderOptions:  getProviderOptions(small, smallProviderCfg),
-				Temperature:      small.ModelCfg.Temperature,
-				TopP:             small.ModelCfg.TopP,
-				TopK:             small.ModelCfg.TopK,
-				FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
-				PresencePenalty:  small.ModelCfg.PresencePenalty,
-			})
-			if err != nil {
-				return fantasy.NewTextErrorResponse("error generating response"), nil
-			}
-
-			updatedSession, err := c.sessions.Get(ctx, session.ID)
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
-			}
-			parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
-			}
-
-			parentSession.Cost += updatedSession.Cost
-
-			_, err = c.sessions.Save(ctx, parentSession)
-			if err != nil {
-				return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
-			}
-
-			return fantasy.NewTextResponse(result.Response.Content.Text()), nil
+			return c.runSubAgentWithOptions(
+				ctx,
+				agent,
+				validationResult.SessionID,
+				validationResult.AgentMessageID,
+				call.ID,
+				fullPrompt,
+				"Fetch Analysis",
+				func(sessionID string) {
+					c.permissions.AutoApproveSession(sessionID)
+				},
+			)
 		}), nil
 }

internal/agent/coordinator.go 🔗

@@ -940,3 +940,95 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con
 	}
 	return nil
 }
+
+// runSubAgent runs a sub-agent and handles session management and cost accumulation.
+// It creates a sub-session, runs the agent with the given prompt, and propagates
+// the cost to the parent session.
+func (c *coordinator) runSubAgent(
+	ctx context.Context,
+	agent SessionAgent,
+	sessionID, agentMessageID, toolCallID string,
+	prompt string,
+	sessionTitle string,
+) (fantasy.ToolResponse, error) {
+	return c.runSubAgentWithOptions(ctx, agent, sessionID, agentMessageID, toolCallID, prompt, sessionTitle, nil)
+}
+
+// runSubAgentWithOptions runs a sub-agent with additional session configuration options.
+// The sessionSetup function is called after session creation but before agent execution.
+func (c *coordinator) runSubAgentWithOptions(
+	ctx context.Context,
+	agent SessionAgent,
+	sessionID, agentMessageID, toolCallID string,
+	prompt string,
+	sessionTitle string,
+	sessionSetup func(sessionID string),
+) (fantasy.ToolResponse, error) {
+	// Create sub-session
+	agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, toolCallID)
+	session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, sessionTitle)
+	if err != nil {
+		return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
+	}
+
+	// Call session setup function if provided
+	if sessionSetup != nil {
+		sessionSetup(session.ID)
+	}
+
+	// Get model configuration
+	model := agent.Model()
+	maxTokens := model.CatwalkCfg.DefaultMaxTokens
+	if model.ModelCfg.MaxTokens != 0 {
+		maxTokens = model.ModelCfg.MaxTokens
+	}
+
+	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
+	if !ok {
+		return fantasy.ToolResponse{}, errors.New("model provider not configured")
+	}
+
+	// Run the agent
+	result, err := agent.Run(ctx, SessionAgentCall{
+		SessionID:        session.ID,
+		Prompt:           prompt,
+		MaxOutputTokens:  maxTokens,
+		ProviderOptions:  getProviderOptions(model, providerCfg),
+		Temperature:      model.ModelCfg.Temperature,
+		TopP:             model.ModelCfg.TopP,
+		TopK:             model.ModelCfg.TopK,
+		FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
+		PresencePenalty:  model.ModelCfg.PresencePenalty,
+	})
+	if err != nil {
+		return fantasy.NewTextErrorResponse("error generating response"), nil
+	}
+
+	// Update parent session cost
+	if err := c.updateParentSessionCost(ctx, session.ID, sessionID); err != nil {
+		return fantasy.ToolResponse{}, err
+	}
+
+	return fantasy.NewTextResponse(result.Response.Content.Text()), nil
+}
+
+// updateParentSessionCost accumulates the cost from a child session to its parent session.
+func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
+	childSession, err := c.sessions.Get(ctx, childSessionID)
+	if err != nil {
+		return fmt.Errorf("get child session: %w", err)
+	}
+
+	parentSession, err := c.sessions.Get(ctx, parentSessionID)
+	if err != nil {
+		return fmt.Errorf("get parent session: %w", err)
+	}
+
+	parentSession.Cost += childSession.Cost
+
+	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
+		return fmt.Errorf("save parent session: %w", err)
+	}
+
+	return nil
+}