diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 29566b1c5a00d00c1254a3f07cdcef71ba55d59e..5c9a95fb7f210625c9a4a04a803dcfc634f471a3 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -4,7 +4,6 @@ import ( "context" _ "embed" "errors" - "fmt" "charm.land/fantasy" @@ -56,50 +55,6 @@ func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) return fantasy.ToolResponse{}, errors.New("agent message id missing from context") } - agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID) - session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session") - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err) - } - model := agent.Model() - maxTokens := model.CatwalkCfg.DefaultMaxTokens - if model.ModelCfg.MaxTokens != 0 { - maxTokens = model.ModelCfg.MaxTokens - } - - providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider) - if !ok { - return fantasy.ToolResponse{}, errors.New("model provider not configured") - } - result, err := agent.Run(ctx, SessionAgentCall{ - SessionID: session.ID, - Prompt: params.Prompt, - MaxOutputTokens: maxTokens, - ProviderOptions: getProviderOptions(model, providerCfg), - Temperature: model.ModelCfg.Temperature, - TopP: model.ModelCfg.TopP, - TopK: model.ModelCfg.TopK, - FrequencyPenalty: model.ModelCfg.FrequencyPenalty, - PresencePenalty: model.ModelCfg.PresencePenalty, - }) - if err != nil { - return fantasy.NewTextErrorResponse("error generating response"), nil - } - updatedSession, err := c.sessions.Get(ctx, session.ID) - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err) - } - parentSession, err := c.sessions.Get(ctx, sessionID) - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err) - } - - parentSession.Cost += updatedSession.Cost - - _, err = c.sessions.Save(ctx, parentSession) - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err) - } - return fantasy.NewTextResponse(result.Response.Content.Text()), nil + return c.runSubAgent(ctx, agent, sessionID, agentMessageID, call.ID, params.Prompt, "New Agent Session") }), nil } diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index 2d52814d446581fca0e7a98368ffaae465aedf2c..26ed301eaff7955dac95bd2e0043490730e46f85 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -184,51 +184,17 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( Tools: fetchTools, }) - agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID) - session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Fetch Analysis") - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err) - } - - c.permissions.AutoApproveSession(session.ID) - - // Use small model for web content analysis (faster and cheaper) - maxTokens := small.CatwalkCfg.DefaultMaxTokens - if small.ModelCfg.MaxTokens != 0 { - maxTokens = small.ModelCfg.MaxTokens - } - - result, err := agent.Run(ctx, SessionAgentCall{ - SessionID: session.ID, - Prompt: fullPrompt, - MaxOutputTokens: maxTokens, - ProviderOptions: getProviderOptions(small, smallProviderCfg), - Temperature: small.ModelCfg.Temperature, - TopP: small.ModelCfg.TopP, - TopK: small.ModelCfg.TopK, - FrequencyPenalty: small.ModelCfg.FrequencyPenalty, - PresencePenalty: small.ModelCfg.PresencePenalty, - }) - if err != nil { - return fantasy.NewTextErrorResponse("error generating response"), nil - } - - updatedSession, err := c.sessions.Get(ctx, session.ID) - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err) - } - parentSession, err := c.sessions.Get(ctx, validationResult.SessionID) - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err) - } - - parentSession.Cost += updatedSession.Cost - - _, err = c.sessions.Save(ctx, parentSession) - if err != nil { - return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err) - } - - return fantasy.NewTextResponse(result.Response.Content.Text()), nil + return c.runSubAgentWithOptions( + ctx, + agent, + validationResult.SessionID, + validationResult.AgentMessageID, + call.ID, + fullPrompt, + "Fetch Analysis", + func(sessionID string) { + c.permissions.AutoApproveSession(sessionID) + }, + ) }), nil } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 0b070a24d346ecd649459e11dc71430873bf2788..57b42d12de65e3d194ee9b1adaf96392082e3f55 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -940,3 +940,95 @@ func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg con } return nil } + +// runSubAgent runs a sub-agent and handles session management and cost accumulation. +// It creates a sub-session, runs the agent with the given prompt, and propagates +// the cost to the parent session. +func (c *coordinator) runSubAgent( + ctx context.Context, + agent SessionAgent, + sessionID, agentMessageID, toolCallID string, + prompt string, + sessionTitle string, +) (fantasy.ToolResponse, error) { + return c.runSubAgentWithOptions(ctx, agent, sessionID, agentMessageID, toolCallID, prompt, sessionTitle, nil) +} + +// runSubAgentWithOptions runs a sub-agent with additional session configuration options. +// The sessionSetup function is called after session creation but before agent execution. +func (c *coordinator) runSubAgentWithOptions( + ctx context.Context, + agent SessionAgent, + sessionID, agentMessageID, toolCallID string, + prompt string, + sessionTitle string, + sessionSetup func(sessionID string), +) (fantasy.ToolResponse, error) { + // Create sub-session + agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, toolCallID) + session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, sessionTitle) + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err) + } + + // Call session setup function if provided + if sessionSetup != nil { + sessionSetup(session.ID) + } + + // Get model configuration + model := agent.Model() + maxTokens := model.CatwalkCfg.DefaultMaxTokens + if model.ModelCfg.MaxTokens != 0 { + maxTokens = model.ModelCfg.MaxTokens + } + + providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider) + if !ok { + return fantasy.ToolResponse{}, errors.New("model provider not configured") + } + + // Run the agent + result, err := agent.Run(ctx, SessionAgentCall{ + SessionID: session.ID, + Prompt: prompt, + MaxOutputTokens: maxTokens, + ProviderOptions: getProviderOptions(model, providerCfg), + Temperature: model.ModelCfg.Temperature, + TopP: model.ModelCfg.TopP, + TopK: model.ModelCfg.TopK, + FrequencyPenalty: model.ModelCfg.FrequencyPenalty, + PresencePenalty: model.ModelCfg.PresencePenalty, + }) + if err != nil { + return fantasy.NewTextErrorResponse("error generating response"), nil + } + + // Update parent session cost + if err := c.updateParentSessionCost(ctx, session.ID, sessionID); err != nil { + return fantasy.ToolResponse{}, err + } + + return fantasy.NewTextResponse(result.Response.Content.Text()), nil +} + +// updateParentSessionCost accumulates the cost from a child session to its parent session. +func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error { + childSession, err := c.sessions.Get(ctx, childSessionID) + if err != nil { + return fmt.Errorf("get child session: %w", err) + } + + parentSession, err := c.sessions.Get(ctx, parentSessionID) + if err != nil { + return fmt.Errorf("get parent session: %w", err) + } + + parentSession.Cost += childSession.Cost + + if _, err := c.sessions.Save(ctx, parentSession); err != nil { + return fmt.Errorf("save parent session: %w", err) + } + + return nil +}