memory_search_tool.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"os"
  9	"path/filepath"
 10
 11	"charm.land/fantasy"
 12
 13	"github.com/charmbracelet/crush/internal/agent/prompt"
 14	"github.com/charmbracelet/crush/internal/agent/tools"
 15)
 16
// memorySearchToolDescription is the description of the memory_search tool
// that is passed to fantasy.NewParallelAgentTool in memorySearchTool. Its
// contents are embedded at build time from templates/memory_search.md.
//go:embed templates/memory_search.md
var memorySearchToolDescription []byte
 19
// memorySearchPromptTmpl is the template for the memory_search sub-agent's
// system prompt. memorySearchTool parses it with prompt.NewPrompt and renders
// it with the transcript directory as the working dir. Its contents are
// embedded at build time from templates/memory_search_prompt.md.tpl.
//go:embed templates/memory_search_prompt.md.tpl
var memorySearchPromptTmpl []byte
 22
 23// memorySearchValidationResult holds the validated parameters from the tool call context.
 24type memorySearchValidationResult struct {
 25	SessionID      string
 26	AgentMessageID string
 27}
 28
 29// validateMemorySearchParams validates the tool call parameters and extracts required context values.
 30func validateMemorySearchParams(ctx context.Context, params tools.MemorySearchParams) (memorySearchValidationResult, error) {
 31	if params.Query == "" {
 32		return memorySearchValidationResult{}, errors.New("query is required")
 33	}
 34
 35	sessionID := tools.GetSessionFromContext(ctx)
 36	if sessionID == "" {
 37		return memorySearchValidationResult{}, errors.New("session id missing from context")
 38	}
 39
 40	agentMessageID := tools.GetMessageFromContext(ctx)
 41	if agentMessageID == "" {
 42		return memorySearchValidationResult{}, errors.New("agent message id missing from context")
 43	}
 44
 45	return memorySearchValidationResult{
 46		SessionID:      sessionID,
 47		AgentMessageID: agentMessageID,
 48	}, nil
 49}
 50
 51func (c *coordinator) memorySearchTool(_ context.Context) (fantasy.AgentTool, error) {
 52	return fantasy.NewParallelAgentTool(
 53		tools.MemorySearchToolName,
 54		string(memorySearchToolDescription),
 55		func(ctx context.Context, params tools.MemorySearchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 56			validationResult, err := validateMemorySearchParams(ctx, params)
 57			if err != nil {
 58				return fantasy.NewTextErrorResponse(err.Error()), nil
 59			}
 60
 61			// Get the parent session to find the transcript.
 62			parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
 63			if err != nil {
 64				return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to get session: %s", err)), nil
 65			}
 66
 67			// Check if session has been summarized.
 68			if parentSession.SummaryMessageID == "" {
 69				return fantasy.NewTextErrorResponse("This session has not been summarized yet. The memory_search tool is only available after summarization."), nil
 70			}
 71
 72			// Find the transcript file.
 73			transcriptPath := TranscriptPath(parentSession.ID)
 74			if _, err := os.Stat(transcriptPath); os.IsNotExist(err) {
 75				return fantasy.NewTextErrorResponse(fmt.Sprintf("Transcript file not found at %s. The session may have been summarized before this feature was available.", transcriptPath)), nil
 76			}
 77
 78			// Build the sub-agent prompt.
 79			transcriptDir := filepath.Dir(transcriptPath)
 80			promptOpts := []prompt.Option{
 81				prompt.WithWorkingDir(transcriptDir),
 82			}
 83
 84			promptTemplate, err := prompt.NewPrompt("memory_search", string(memorySearchPromptTmpl), promptOpts...)
 85			if err != nil {
 86				return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
 87			}
 88
 89			_, small, err := c.buildAgentModels(ctx, true)
 90			if err != nil {
 91				return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
 92			}
 93
 94			systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
 95			if err != nil {
 96				return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
 97			}
 98
 99			smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
100			if !ok {
101				return fantasy.ToolResponse{}, errors.New("small model provider not configured")
102			}
103
104			// Create sub-agent with read-only tools scoped to the transcript directory.
105			searchTools := []fantasy.AgentTool{
106				tools.NewGlobTool(transcriptDir),
107				tools.NewGrepTool(transcriptDir),
108				tools.NewViewTool(c.lspClients, c.permissions, transcriptDir),
109			}
110
111			agent := NewSessionAgent(SessionAgentOptions{
112				LargeModel:           small, // Use small model for both (search doesn't need large)
113				SmallModel:           small,
114				SystemPromptPrefix:   smallProviderCfg.SystemPromptPrefix,
115				SystemPrompt:         systemPrompt,
116				DisableAutoSummarize: true, // Never summarize the sub-agent session
117				IsYolo:               c.permissions.SkipRequests(),
118				Sessions:             c.sessions,
119				Messages:             c.messages,
120				Tools:                searchTools,
121			})
122
123			agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
124			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Memory Search")
125			if err != nil {
126				return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
127			}
128
129			c.permissions.AutoApproveSession(session.ID)
130
131			// Build the full prompt including the transcript path.
132			fullPrompt := fmt.Sprintf("%s\n\nThe session transcript is located at: %s\n\nUse grep and view to search this file for the requested information.", params.Query, transcriptPath)
133
134			// Use small model for transcript search.
135			maxTokens := small.CatwalkCfg.DefaultMaxTokens
136			if small.ModelCfg.MaxTokens != 0 {
137				maxTokens = small.ModelCfg.MaxTokens
138			}
139
140			result, err := agent.Run(ctx, SessionAgentCall{
141				SessionID:        session.ID,
142				Prompt:           fullPrompt,
143				MaxOutputTokens:  maxTokens,
144				ProviderOptions:  getProviderOptions(small, smallProviderCfg),
145				Temperature:      small.ModelCfg.Temperature,
146				TopP:             small.ModelCfg.TopP,
147				TopK:             small.ModelCfg.TopK,
148				FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
149				PresencePenalty:  small.ModelCfg.PresencePenalty,
150			})
151			if err != nil {
152				return fantasy.NewTextErrorResponse("error generating response"), nil
153			}
154
155			// Update parent session cost.
156			updatedSession, err := c.sessions.Get(ctx, session.ID)
157			if err != nil {
158				return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
159			}
160			parentSession, err = c.sessions.Get(ctx, validationResult.SessionID)
161			if err != nil {
162				return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
163			}
164
165			parentSession.Cost += updatedSession.Cost
166
167			_, err = c.sessions.Save(ctx, parentSession)
168			if err != nil {
169				return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
170			}
171
172			return fantasy.NewTextResponse(result.Response.Content.Text()), nil
173		}), nil
174}