1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8
9 "charm.land/fantasy"
10
11 "github.com/charmbracelet/crush/internal/agent/prompt"
12 "github.com/charmbracelet/crush/internal/agent/tools"
13 "github.com/charmbracelet/crush/internal/config"
14)
15
// agentToolDescription is the description text of the agent tool, embedded at
// build time from templates/agent_tool.md. agentTool passes it (as a string)
// to fantasy.NewAgentTool.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
18
// AgentParams is the input accepted by the agent tool.
type AgentParams struct {
	// Prompt is the task the sub-agent is asked to perform; the tool rejects
	// an empty value.
	Prompt string `json:"prompt" description:"The task for the agent to perform"`
}

// AgentToolName is the name the agent tool is registered under.
const AgentToolName = "agent"
26
27func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
28 agentCfg, ok := c.cfg.Agents[config.AgentTask]
29 if !ok {
30 return nil, errors.New("task agent not configured")
31 }
32 prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
33 if err != nil {
34 return nil, err
35 }
36
37 agent, err := c.buildAgent(ctx, prompt, agentCfg)
38 if err != nil {
39 return nil, err
40 }
41 return fantasy.NewAgentTool(
42 AgentToolName,
43 string(agentToolDescription),
44 func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
45 if params.Prompt == "" {
46 return fantasy.NewTextErrorResponse("prompt is required"), nil
47 }
48
49 sessionID := tools.GetSessionFromContext(ctx)
50 if sessionID == "" {
51 return fantasy.ToolResponse{}, errors.New("session id missing from context")
52 }
53
54 agentMessageID := tools.GetMessageFromContext(ctx)
55 if agentMessageID == "" {
56 return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
57 }
58
59 agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
60 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
61 if err != nil {
62 return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
63 }
64 model := agent.Model()
65 maxTokens := model.CatwalkCfg.DefaultMaxTokens
66 if model.ModelCfg.MaxTokens != 0 {
67 maxTokens = model.ModelCfg.MaxTokens
68 }
69
70 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
71 if !ok {
72 return fantasy.ToolResponse{}, errors.New("model provider not configured")
73 }
74 result, err := agent.Run(ctx, SessionAgentCall{
75 SessionID: session.ID,
76 Prompt: params.Prompt,
77 MaxOutputTokens: maxTokens,
78 ProviderOptions: getProviderOptions(model, providerCfg),
79 Temperature: model.ModelCfg.Temperature,
80 TopP: model.ModelCfg.TopP,
81 TopK: model.ModelCfg.TopK,
82 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
83 PresencePenalty: model.ModelCfg.PresencePenalty,
84 })
85 if err != nil {
86 return fantasy.NewTextErrorResponse("error generating response"), nil
87 }
88 updatedSession, err := c.sessions.Get(ctx, session.ID)
89 if err != nil {
90 return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
91 }
92 parentSession, err := c.sessions.Get(ctx, sessionID)
93 if err != nil {
94 return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
95 }
96
97 parentSession.Cost += updatedSession.Cost
98
99 _, err = c.sessions.Save(ctx, parentSession)
100 if err != nil {
101 return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
102 }
103 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
104 }), nil
105}