1package agent
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "errors"
8 "fmt"
9
10 "github.com/charmbracelet/fantasy/ai"
11
12 "github.com/charmbracelet/crush/internal/agent/prompt"
13 "github.com/charmbracelet/crush/internal/agent/tools"
14 "github.com/charmbracelet/crush/internal/config"
15)
16
// agentToolDescription is the description text passed to ai.NewAgentTool for
// the agent tool. It is embedded at build time from templates/agent_tool.md.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
19
// AgentParams holds the arguments accepted by the agent tool.
type AgentParams struct {
	// Prompt is the task handed to the sub-agent. The tool rejects an empty
	// prompt.
	Prompt string `json:"prompt" description:"The task for the agent to perform"`
}
23
// AgentToolName is the name under which the sub-agent tool is registered.
const AgentToolName = "agent"
27
28func (c *coordinator) agentTool() (ai.AgentTool, error) {
29 agentCfg, ok := c.cfg.Agents[config.AgentTask]
30 if !ok {
31 return nil, errors.New("task agent not configured")
32 }
33 prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
34 if err != nil {
35 return nil, err
36 }
37
38 agent, err := c.buildAgent(prompt, agentCfg)
39 if err != nil {
40 return nil, err
41 }
42 return ai.NewAgentTool(
43 AgentToolName,
44 string(agentToolDescription),
45 func(ctx context.Context, params AgentParams, call ai.ToolCall) (ai.ToolResponse, error) {
46 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
47 return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
48 }
49 if params.Prompt == "" {
50 return ai.NewTextErrorResponse("prompt is required"), nil
51 }
52
53 sessionID := tools.GetSessionFromContext(ctx)
54 if sessionID == "" {
55 return ai.ToolResponse{}, errors.New("session id missing from context")
56 }
57
58 agentMessageID := tools.GetMessageFromContext(ctx)
59 if agentMessageID == "" {
60 return ai.ToolResponse{}, errors.New("agent message id missing from context")
61 }
62
63 agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
64 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
65 if err != nil {
66 return ai.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
67 }
68 model := agent.Model()
69 maxTokens := model.CatwalkCfg.DefaultMaxTokens
70 if model.ModelCfg.MaxTokens != 0 {
71 maxTokens = model.ModelCfg.MaxTokens
72 }
73 result, err := agent.Run(ctx, SessionAgentCall{
74 SessionID: session.ID,
75 Prompt: params.Prompt,
76 MaxOutputTokens: maxTokens,
77 ProviderOptions: getProviderOptions(model),
78 Temperature: model.ModelCfg.Temperature,
79 TopP: model.ModelCfg.TopP,
80 TopK: model.ModelCfg.TopK,
81 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
82 PresencePenalty: model.ModelCfg.PresencePenalty,
83 })
84 if err != nil {
85 return ai.NewTextErrorResponse("error generating response"), nil
86 }
87 updatedSession, err := c.sessions.Get(ctx, session.ID)
88 if err != nil {
89 return ai.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
90 }
91 parentSession, err := c.sessions.Get(ctx, sessionID)
92 if err != nil {
93 return ai.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
94 }
95
96 parentSession.Cost += updatedSession.Cost
97
98 _, err = c.sessions.Save(ctx, parentSession)
99 if err != nil {
100 return ai.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
101 }
102 return ai.NewTextResponse(result.Response.Content.Text()), nil
103 }), nil
104}