agent_tool.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10
 11	"charm.land/fantasy"
 12
 13	"github.com/charmbracelet/crush/internal/agent/prompt"
 14	"github.com/charmbracelet/crush/internal/agent/tools"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/hooks"
 17)
 18
// agentToolDescription holds the contents of templates/agent_tool.md,
// embedded at build time. It is used as the description of the agent tool
// when the tool is constructed.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
 21
 22type AgentParams struct {
 23	Prompt string `json:"prompt" description:"The task for the agent to perform"`
 24}
 25
 26const (
 27	AgentToolName = "agent"
 28)
 29
 30func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
 31	agentCfg, ok := c.cfg.Agents[config.AgentTask]
 32	if !ok {
 33		return nil, errors.New("task agent not configured")
 34	}
 35	prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 36	if err != nil {
 37		return nil, err
 38	}
 39
 40	agent, err := c.buildAgent(ctx, prompt, agentCfg)
 41	if err != nil {
 42		return nil, err
 43	}
 44	return fantasy.NewAgentTool(
 45		AgentToolName,
 46		string(agentToolDescription),
 47		func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 48			if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 49				return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 50			}
 51			if params.Prompt == "" {
 52				return fantasy.NewTextErrorResponse("prompt is required"), nil
 53			}
 54
 55			sessionID := tools.GetSessionFromContext(ctx)
 56			if sessionID == "" {
 57				return fantasy.ToolResponse{}, errors.New("session id missing from context")
 58			}
 59
 60			agentMessageID := tools.GetMessageFromContext(ctx)
 61			if agentMessageID == "" {
 62				return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
 63			}
 64
 65			agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
 66			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
 67			if err != nil {
 68				return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
 69			}
 70			model := agent.Model()
 71			maxTokens := model.CatwalkCfg.DefaultMaxTokens
 72			if model.ModelCfg.MaxTokens != 0 {
 73				maxTokens = model.ModelCfg.MaxTokens
 74			}
 75
 76			providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
 77			if !ok {
 78				return fantasy.ToolResponse{}, errors.New("model provider not configured")
 79			}
 80			result, err := agent.Run(ctx, SessionAgentCall{
 81				SessionID:        session.ID,
 82				Prompt:           params.Prompt,
 83				MaxOutputTokens:  maxTokens,
 84				ProviderOptions:  getProviderOptions(model, providerCfg),
 85				Temperature:      model.ModelCfg.Temperature,
 86				TopP:             model.ModelCfg.TopP,
 87				TopK:             model.ModelCfg.TopK,
 88				FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
 89				PresencePenalty:  model.ModelCfg.PresencePenalty,
 90			})
 91			if err != nil {
 92				return fantasy.NewTextErrorResponse("error generating response"), nil
 93			}
 94			updatedSession, err := c.sessions.Get(ctx, session.ID)
 95			if err != nil {
 96				return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
 97			}
 98			parentSession, err := c.sessions.Get(ctx, sessionID)
 99			if err != nil {
100				return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
101			}
102
103			parentSession.Cost += updatedSession.Cost
104
105			_, err = c.sessions.Save(ctx, parentSession)
106			if err != nil {
107				return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
108			}
109
110			// Execute SubagentStop hook
111			if c.hooks != nil {
112				if err := c.hooks.Execute(ctx, hooks.HookContext{
113					EventType: config.SubagentStop,
114					SessionID: sessionID,
115					ToolName:  AgentToolName,
116					MessageID: agentMessageID,
117					Provider:  model.ModelCfg.Provider,
118					Model:     model.ModelCfg.Model,
119				}); err != nil {
120					slog.Debug("subagent_stop hook execution failed", "error", err)
121				}
122			}
123
124			return fantasy.NewTextResponse(result.Response.Content.Text()), nil
125		}), nil
126}