agent_tool.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9
 10	"charm.land/fantasy"
 11
 12	"github.com/charmbracelet/crush/internal/agent/prompt"
 13	"github.com/charmbracelet/crush/internal/agent/tools"
 14	"github.com/charmbracelet/crush/internal/config"
 15)
 16
// agentToolDescription is the description text for the agent tool, embedded
// at build time from templates/agent_tool.md and passed to
// fantasy.NewAgentTool.
//
//go:embed templates/agent_tool.md
var agentToolDescription []byte
 19
// AgentParams is the input accepted by the agent tool.
type AgentParams struct {
	// Prompt is the task handed to the sub-agent. The agent tool rejects an
	// empty prompt.
	Prompt string `json:"prompt" description:"The task for the agent to perform"`
}
 23
 24const (
 25	AgentToolName = "agent"
 26)
 27
 28func (c *coordinator) agentTool() (fantasy.AgentTool, error) {
 29	agentCfg, ok := c.cfg.Agents[config.AgentTask]
 30	if !ok {
 31		return nil, errors.New("task agent not configured")
 32	}
 33	prompt, err := taskPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
 34	if err != nil {
 35		return nil, err
 36	}
 37
 38	agent, err := c.buildAgent(prompt, agentCfg)
 39	if err != nil {
 40		return nil, err
 41	}
 42	return fantasy.NewAgentTool(
 43		AgentToolName,
 44		string(agentToolDescription),
 45		func(ctx context.Context, params AgentParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 46			if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 47				return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 48			}
 49			if params.Prompt == "" {
 50				return fantasy.NewTextErrorResponse("prompt is required"), nil
 51			}
 52
 53			sessionID := tools.GetSessionFromContext(ctx)
 54			if sessionID == "" {
 55				return fantasy.ToolResponse{}, errors.New("session id missing from context")
 56			}
 57
 58			agentMessageID := tools.GetMessageFromContext(ctx)
 59			if agentMessageID == "" {
 60				return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
 61			}
 62
 63			agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
 64			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "New Agent Session")
 65			if err != nil {
 66				return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
 67			}
 68			model := agent.Model()
 69			maxTokens := model.CatwalkCfg.DefaultMaxTokens
 70			if model.ModelCfg.MaxTokens != 0 {
 71				maxTokens = model.ModelCfg.MaxTokens
 72			}
 73
 74			providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
 75			if !ok {
 76				return fantasy.ToolResponse{}, errors.New("model provider not configured")
 77			}
 78			result, err := agent.Run(ctx, SessionAgentCall{
 79				SessionID:        session.ID,
 80				Prompt:           params.Prompt,
 81				MaxOutputTokens:  maxTokens,
 82				ProviderOptions:  getProviderOptions(model, providerCfg.Type),
 83				Temperature:      model.ModelCfg.Temperature,
 84				TopP:             model.ModelCfg.TopP,
 85				TopK:             model.ModelCfg.TopK,
 86				FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
 87				PresencePenalty:  model.ModelCfg.PresencePenalty,
 88			})
 89			if err != nil {
 90				return fantasy.NewTextErrorResponse("error generating response"), nil
 91			}
 92			updatedSession, err := c.sessions.Get(ctx, session.ID)
 93			if err != nil {
 94				return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
 95			}
 96			parentSession, err := c.sessions.Get(ctx, sessionID)
 97			if err != nil {
 98				return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
 99			}
100
101			parentSession.Cost += updatedSession.Cost
102
103			_, err = c.sessions.Save(ctx, parentSession)
104			if err != nil {
105				return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
106			}
107			return fantasy.NewTextResponse(result.Response.Content.Text()), nil
108		}), nil
109}