1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8
  9	"charm.land/fantasy"
 10
 11	"github.com/charmbracelet/crush/internal/agent/prompt"
 12	"github.com/charmbracelet/crush/internal/agent/tools"
 13	"github.com/charmbracelet/crush/internal/config"
 14)
 15
// agentToolDescription holds the contents of templates/agent_tool.md,
// embedded into the binary at build time. It is used verbatim as the
// description of the agent tool.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
 18
 19type AgentParams struct {
 20	Prompt string `json:"prompt" description:"The task for the agent to perform"`
 21}
 22
 23const (
 24	AgentToolName = "agent"
 25)
 26
 27func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
 28	agentCfg, ok := c.cfg.Agents[config.AgentTask]
 29	if !ok {
 30		return nil, errors.New("task agent not configured")
 31	}
 32	prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 33	if err != nil {
 34		return nil, err
 35	}
 36
 37	agent, err := c.buildAgent(ctx, prompt, agentCfg)
 38	if err != nil {
 39		return nil, err
 40	}
 41	return fantasy.NewAgentTool(
 42		AgentToolName,
 43		string(agentToolDescription),
 44		func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 45			if params.Prompt == "" {
 46				return fantasy.NewTextErrorResponse("prompt is required"), nil
 47			}
 48
 49			sessionID := tools.GetSessionFromContext(ctx)
 50			if sessionID == "" {
 51				return fantasy.ToolResponse{}, errors.New("session id missing from context")
 52			}
 53
 54			agentMessageID := tools.GetMessageFromContext(ctx)
 55			if agentMessageID == "" {
 56				return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
 57			}
 58
 59			agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
 60			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
 61			if err != nil {
 62				return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
 63			}
 64			model := agent.Model()
 65			maxTokens := model.CatwalkCfg.DefaultMaxTokens
 66			if model.ModelCfg.MaxTokens != 0 {
 67				maxTokens = model.ModelCfg.MaxTokens
 68			}
 69
 70			providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
 71			if !ok {
 72				return fantasy.ToolResponse{}, errors.New("model provider not configured")
 73			}
 74			result, err := agent.Run(ctx, SessionAgentCall{
 75				SessionID:        session.ID,
 76				Prompt:           params.Prompt,
 77				MaxOutputTokens:  maxTokens,
 78				ProviderOptions:  getProviderOptions(model, providerCfg),
 79				Temperature:      model.ModelCfg.Temperature,
 80				TopP:             model.ModelCfg.TopP,
 81				TopK:             model.ModelCfg.TopK,
 82				FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
 83				PresencePenalty:  model.ModelCfg.PresencePenalty,
 84			})
 85			if err != nil {
 86				return fantasy.NewTextErrorResponse("error generating response"), nil
 87			}
 88			updatedSession, err := c.sessions.Get(ctx, session.ID)
 89			if err != nil {
 90				return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
 91			}
 92			parentSession, err := c.sessions.Get(ctx, sessionID)
 93			if err != nil {
 94				return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
 95			}
 96
 97			parentSession.Cost += updatedSession.Cost
 98
 99			_, err = c.sessions.Save(ctx, parentSession)
100			if err != nil {
101				return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
102			}
103			return fantasy.NewTextResponse(result.Response.Content.Text()), nil
104		}), nil
105}