Detailed changes
@@ -4,7 +4,6 @@ import (
"context"
_ "embed"
"errors"
- "fmt"
"charm.land/fantasy"
@@ -56,50 +55,6 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error)
return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
}
- agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
- session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
- }
- model := agent.Model()
- maxTokens := model.CatwalkCfg.DefaultMaxTokens
- if model.ModelCfg.MaxTokens != 0 {
- maxTokens = model.ModelCfg.MaxTokens
- }
-
- providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
- if !ok {
- return fantasy.ToolResponse{}, errors.New("model provider not configured")
- }
- result, err := agent.Run(ctx, SessionAgentCall{
- SessionID: session.ID,
- Prompt: params.Prompt,
- MaxOutputTokens: maxTokens,
- ProviderOptions: getProviderOptions(model, providerCfg),
- Temperature: model.ModelCfg.Temperature,
- TopP: model.ModelCfg.TopP,
- TopK: model.ModelCfg.TopK,
- FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
- PresencePenalty: model.ModelCfg.PresencePenalty,
- })
- if err != nil {
- return fantasy.NewTextErrorResponse("error generating response"), nil
- }
- updatedSession, err := c.sessions.Get(ctx, session.ID)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
- }
- parentSession, err := c.sessions.Get(ctx, sessionID)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
- }
-
- parentSession.Cost += updatedSession.Cost
-
- _, err = c.sessions.Save(ctx, parentSession)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
- }
- return fantasy.NewTextResponse(result.Response.Content.Text()), nil
+ return c.runSubAgent(ctx, agent, sessionID, agentMessageID, call.ID, params.Prompt, "New Agent Session")
}), nil
}
@@ -184,51 +184,17 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
Tools: fetchTools,
})
- agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
- session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Fetch Analysis")
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
- }
-
- c.permissions.AutoApproveSession(session.ID)
-
- // Use small model for web content analysis (faster and cheaper)
- maxTokens := small.CatwalkCfg.DefaultMaxTokens
- if small.ModelCfg.MaxTokens != 0 {
- maxTokens = small.ModelCfg.MaxTokens
- }
-
- result, err := agent.Run(ctx, SessionAgentCall{
- SessionID: session.ID,
- Prompt: fullPrompt,
- MaxOutputTokens: maxTokens,
- ProviderOptions: getProviderOptions(small, smallProviderCfg),
- Temperature: small.ModelCfg.Temperature,
- TopP: small.ModelCfg.TopP,
- TopK: small.ModelCfg.TopK,
- FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
- PresencePenalty: small.ModelCfg.PresencePenalty,
- })
- if err != nil {
- return fantasy.NewTextErrorResponse("error generating response"), nil
- }
-
- updatedSession, err := c.sessions.Get(ctx, session.ID)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
- }
- parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
- }
-
- parentSession.Cost += updatedSession.Cost
-
- _, err = c.sessions.Save(ctx, parentSession)
- if err != nil {
- return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
- }
-
- return fantasy.NewTextResponse(result.Response.Content.Text()), nil
+ return c.runSubAgentWithOptions(
+ ctx,
+ agent,
+ validationResult.SessionID,
+ validationResult.AgentMessageID,
+ call.ID,
+ fullPrompt,
+ "Fetch Analysis",
+ func(sessionID string) {
+ c.permissions.AutoApproveSession(sessionID)
+ },
+ )
}), nil
}
@@ -940,3 +940,95 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con
}
return nil
}
+
+// runSubAgent runs a sub-agent and handles session management and cost accumulation.
+// It creates a sub-session, runs the agent with the given prompt, and propagates
+// the cost to the parent session.
+func (c *coordinator) runSubAgent(
+ ctx context.Context,
+ agent SessionAgent,
+ sessionID, agentMessageID, toolCallID string,
+ prompt string,
+ sessionTitle string,
+) (fantasy.ToolResponse, error) {
+ return c.runSubAgentWithOptions(ctx, agent, sessionID, agentMessageID, toolCallID, prompt, sessionTitle, nil)
+}
+
+// runSubAgentWithOptions runs a sub-agent with additional session configuration options.
+// The sessionSetup function is called after session creation but before agent execution.
+func (c *coordinator) runSubAgentWithOptions(
+ ctx context.Context,
+ agent SessionAgent,
+ sessionID, agentMessageID, toolCallID string,
+ prompt string,
+ sessionTitle string,
+ sessionSetup func(sessionID string),
+) (fantasy.ToolResponse, error) {
+ // Create sub-session
+ agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, toolCallID)
+ session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, sessionTitle)
+ if err != nil {
+ return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
+ }
+
+ // Call session setup function if provided
+ if sessionSetup != nil {
+ sessionSetup(session.ID)
+ }
+
+ // Get model configuration
+ model := agent.Model()
+ maxTokens := model.CatwalkCfg.DefaultMaxTokens
+ if model.ModelCfg.MaxTokens != 0 {
+ maxTokens = model.ModelCfg.MaxTokens
+ }
+
+ providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
+ if !ok {
+ return fantasy.ToolResponse{}, errors.New("model provider not configured")
+ }
+
+ // Run the agent
+ result, err := agent.Run(ctx, SessionAgentCall{
+ SessionID: session.ID,
+ Prompt: prompt,
+ MaxOutputTokens: maxTokens,
+ ProviderOptions: getProviderOptions(model, providerCfg),
+ Temperature: model.ModelCfg.Temperature,
+ TopP: model.ModelCfg.TopP,
+ TopK: model.ModelCfg.TopK,
+ FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
+ PresencePenalty: model.ModelCfg.PresencePenalty,
+ })
+ if err != nil {
+ return fantasy.NewTextErrorResponse("error generating response"), nil
+ }
+
+ // Update parent session cost
+ if err := c.updateParentSessionCost(ctx, session.ID, sessionID); err != nil {
+ return fantasy.ToolResponse{}, err
+ }
+
+ return fantasy.NewTextResponse(result.Response.Content.Text()), nil
+}
+
+// updateParentSessionCost accumulates the cost from a child session to its parent session.
+func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
+ childSession, err := c.sessions.Get(ctx, childSessionID)
+ if err != nil {
+ return fmt.Errorf("get child session: %w", err)
+ }
+
+ parentSession, err := c.sessions.Get(ctx, parentSessionID)
+ if err != nil {
+ return fmt.Errorf("get parent session: %w", err)
+ }
+
+ parentSession.Cost += childSession.Cost
+
+ if _, err := c.sessions.Save(ctx, parentSession); err != nil {
+ return fmt.Errorf("save parent session: %w", err)
+ }
+
+ return nil
+}