agent-tool.go

 1package agent
 2
 3import (
 4	"context"
 5	_ "embed"
 6	"encoding/json"
 7	"errors"
 8	"fmt"
 9
10	"github.com/charmbracelet/fantasy/ai"
11
12	"github.com/charmbracelet/crush/internal/agent/prompt"
13	"github.com/charmbracelet/crush/internal/agent/tools"
14	"github.com/charmbracelet/crush/internal/config"
15)
16
// agentToolDescription is the description text for the agent tool; agentTool
// converts it to a string and passes it to ai.NewAgentTool. Its contents are
// embedded at build time from templates/agentTool.md.
//
//go:embed templates/agentTool.md
var agentToolDescription []byte
19
// AgentParams is the input accepted by the agent tool. agentTool decodes the
// raw tool-call input into it with encoding/json, so the field is read from
// the "prompt" key. The description tag is presumably consumed by the tool
// framework to describe the parameter — confirm against the ai package.
type AgentParams struct {
	// Prompt is the task handed to the sub-agent; agentTool rejects an
	// empty value.
	Prompt string `json:"prompt" description:"The task for the agent to perform"`
}
23
// AgentToolName is the name the agent tool is registered under; agentTool
// passes it as the tool name to ai.NewAgentTool.
const AgentToolName = "agent"
27
28func (c *coordinator) agentTool() (ai.AgentTool, error) {
29	agentCfg, ok := c.cfg.Agents[config.AgentTask]
30	if !ok {
31		return nil, errors.New("task agent not configured")
32	}
33	prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
34	if err != nil {
35		return nil, err
36	}
37
38	agent, err := c.buildAgent(prompt, agentCfg)
39	if err != nil {
40		return nil, err
41	}
42	return ai.NewAgentTool(
43		AgentToolName,
44		string(agentToolDescription),
45		func(ctx context.Context, params AgentParams, call ai.ToolCall) (ai.ToolResponse, error) {
46			if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
47				return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
48			}
49			if params.Prompt == "" {
50				return ai.NewTextErrorResponse("prompt is required"), nil
51			}
52
53			sessionID := tools.GetSessionFromContext(ctx)
54			if sessionID == "" {
55				return ai.ToolResponse{}, errors.New("session id missing from context")
56			}
57
58			session, err := c.sessions.CreateTaskSession(ctx, call.ID, sessionID, "New Agent Session")
59			if err != nil {
60				return ai.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
61			}
62			model := agent.Model()
63			maxTokens := model.CatwalkCfg.DefaultMaxTokens
64			if model.ModelCfg.MaxTokens != 0 {
65				maxTokens = model.ModelCfg.MaxTokens
66			}
67			result, err := agent.Run(ctx, SessionAgentCall{
68				SessionID:        sessionID,
69				Prompt:           params.Prompt,
70				MaxOutputTokens:  maxTokens,
71				ProviderOptions:  c.getProviderOptions(model),
72				Temperature:      model.ModelCfg.Temperature,
73				TopP:             model.ModelCfg.TopP,
74				TopK:             model.ModelCfg.TopK,
75				FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
76				PresencePenalty:  model.ModelCfg.PresencePenalty,
77			})
78			if err != nil {
79				return ai.ToolResponse{}, fmt.Errorf("error generating agent: %s", err)
80			}
81			updatedSession, err := c.sessions.Get(ctx, session.ID)
82			if err != nil {
83				return ai.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
84			}
85			parentSession, err := c.sessions.Get(ctx, sessionID)
86			if err != nil {
87				return ai.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
88			}
89
90			parentSession.Cost += updatedSession.Cost
91
92			_, err = c.sessions.Save(ctx, parentSession)
93			if err != nil {
94				return ai.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
95			}
96			return ai.NewTextResponse(result.Response.Content.Text()), nil
97		}), nil
98}