agent_tool.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9
 10	"github.com/charmbracelet/fantasy/ai"
 11
 12	"github.com/charmbracelet/crush/internal/agent/prompt"
 13	"github.com/charmbracelet/crush/internal/agent/tools"
 14	"github.com/charmbracelet/crush/internal/config"
 15)
 16
// agentToolDescription holds the contents of templates/agent_tool.md, embedded
// at build time. It is passed (as a string) as the description argument when
// the agent tool is built.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
 19
// AgentParams is the typed input of the agent tool; call inputs are decoded
// from JSON into it.
type AgentParams struct {
	// Prompt is the task handed to the sub-agent. The description tag is
	// presumably consumed when generating the tool's input schema — confirm
	// against the ai package.
	Prompt string `json:"prompt" description:"The task for the agent to perform"`
}
 23
 24const (
 25	AgentToolName = "agent"
 26)
 27
 28func (c *coordinator) agentTool() (ai.AgentTool, error) {
 29	agentCfg, ok := c.cfg.Agents[config.AgentTask]
 30	if !ok {
 31		return nil, errors.New("task agent not configured")
 32	}
 33	prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 34	if err != nil {
 35		return nil, err
 36	}
 37
 38	agent, err := c.buildAgent(prompt, agentCfg)
 39	if err != nil {
 40		return nil, err
 41	}
 42	return ai.NewAgentTool(
 43		AgentToolName,
 44		string(agentToolDescription),
 45		func(ctx context.Context, params AgentParams, call ai.ToolCall) (ai.ToolResponse, error) {
 46			if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 47				return ai.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 48			}
 49			if params.Prompt == "" {
 50				return ai.NewTextErrorResponse("prompt is required"), nil
 51			}
 52
 53			sessionID := tools.GetSessionFromContext(ctx)
 54			if sessionID == "" {
 55				return ai.ToolResponse{}, errors.New("session id missing from context")
 56			}
 57
 58			agentMessageID := tools.GetMessageFromContext(ctx)
 59			if agentMessageID == "" {
 60				return ai.ToolResponse{}, errors.New("agent message id missing from context")
 61			}
 62
 63			agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
 64			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
 65			if err != nil {
 66				return ai.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
 67			}
 68			model := agent.Model()
 69			maxTokens := model.CatwalkCfg.DefaultMaxTokens
 70			if model.ModelCfg.MaxTokens != 0 {
 71				maxTokens = model.ModelCfg.MaxTokens
 72			}
 73			result, err := agent.Run(ctx, SessionAgentCall{
 74				SessionID:        session.ID,
 75				Prompt:           params.Prompt,
 76				MaxOutputTokens:  maxTokens,
 77				ProviderOptions:  c.getProviderOptions(model),
 78				Temperature:      model.ModelCfg.Temperature,
 79				TopP:             model.ModelCfg.TopP,
 80				TopK:             model.ModelCfg.TopK,
 81				FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
 82				PresencePenalty:  model.ModelCfg.PresencePenalty,
 83			})
 84			if err != nil {
 85				return ai.NewTextErrorResponse("error generating response"), nil
 86			}
 87			updatedSession, err := c.sessions.Get(ctx, session.ID)
 88			if err != nil {
 89				return ai.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
 90			}
 91			parentSession, err := c.sessions.Get(ctx, sessionID)
 92			if err != nil {
 93				return ai.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
 94			}
 95
 96			parentSession.Cost += updatedSession.Cost
 97
 98			_, err = c.sessions.Save(ctx, parentSession)
 99			if err != nil {
100				return ai.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
101			}
102			return ai.NewTextResponse(result.Response.Content.Text()), nil
103		}), nil
104}