1package agent
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10
11 "charm.land/fantasy"
12
13 "github.com/charmbracelet/crush/internal/agent/prompt"
14 "github.com/charmbracelet/crush/internal/agent/tools"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/hooks"
17)
18
// agentToolDescription holds the contents of templates/agent_tool.md, embedded
// at build time. agentTool passes it (as a string) as the description of the
// agent tool.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
21
// AgentParams is the input accepted by the agent tool. The struct tags are
// read at runtime (JSON decoding of the tool call input and the parameter
// description), so they must not be changed casually.
type AgentParams struct {
	// Prompt is the task delegated to the sub-agent. agentTool rejects an
	// empty value with a "prompt is required" tool error.
	Prompt string `json:"prompt" description:"The task for the agent to perform"`
}
25
// AgentToolName is the name under which the sub-agent delegation tool is
// registered and reported to hooks.
const AgentToolName = "agent"
29
30func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
31 agentCfg, ok := c.cfg.Agents[config.AgentTask]
32 if !ok {
33 return nil, errors.New("task agent not configured")
34 }
35 prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
36 if err != nil {
37 return nil, err
38 }
39
40 agent, err := c.buildAgent(ctx, prompt, agentCfg)
41 if err != nil {
42 return nil, err
43 }
44 return fantasy.NewAgentTool(
45 AgentToolName,
46 string(agentToolDescription),
47 func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
48 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
49 return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
50 }
51 if params.Prompt == "" {
52 return fantasy.NewTextErrorResponse("prompt is required"), nil
53 }
54
55 sessionID := tools.GetSessionFromContext(ctx)
56 if sessionID == "" {
57 return fantasy.ToolResponse{}, errors.New("session id missing from context")
58 }
59
60 agentMessageID := tools.GetMessageFromContext(ctx)
61 if agentMessageID == "" {
62 return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
63 }
64
65 agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
66 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
67 if err != nil {
68 return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
69 }
70 model := agent.Model()
71 maxTokens := model.CatwalkCfg.DefaultMaxTokens
72 if model.ModelCfg.MaxTokens != 0 {
73 maxTokens = model.ModelCfg.MaxTokens
74 }
75
76 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
77 if !ok {
78 return fantasy.ToolResponse{}, errors.New("model provider not configured")
79 }
80 result, err := agent.Run(ctx, SessionAgentCall{
81 SessionID: session.ID,
82 Prompt: params.Prompt,
83 MaxOutputTokens: maxTokens,
84 ProviderOptions: getProviderOptions(model, providerCfg),
85 Temperature: model.ModelCfg.Temperature,
86 TopP: model.ModelCfg.TopP,
87 TopK: model.ModelCfg.TopK,
88 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
89 PresencePenalty: model.ModelCfg.PresencePenalty,
90 })
91 if err != nil {
92 return fantasy.NewTextErrorResponse("error generating response"), nil
93 }
94 updatedSession, err := c.sessions.Get(ctx, session.ID)
95 if err != nil {
96 return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
97 }
98 parentSession, err := c.sessions.Get(ctx, sessionID)
99 if err != nil {
100 return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
101 }
102
103 parentSession.Cost += updatedSession.Cost
104
105 _, err = c.sessions.Save(ctx, parentSession)
106 if err != nil {
107 return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
108 }
109
110 // Execute SubagentStop hook
111 if c.hooks != nil {
112 if err := c.hooks.Execute(ctx, hooks.HookContext{
113 EventType: config.SubagentStop,
114 SessionID: sessionID,
115 ToolName: AgentToolName,
116 MessageID: agentMessageID,
117 Provider: model.ModelCfg.Provider,
118 Model: model.ModelCfg.Model,
119 }); err != nil {
120 slog.Debug("subagent_stop hook execution failed", "error", err)
121 }
122 }
123
124 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
125 }), nil
126}