1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "os"
9 "path/filepath"
10
11 "charm.land/fantasy"
12
13 "github.com/charmbracelet/crush/internal/agent/prompt"
14 "github.com/charmbracelet/crush/internal/agent/tools"
15)
16
17//go:embed templates/memory_search.md
18var memorySearchToolDescription []byte
19
20//go:embed templates/memory_search_prompt.md.tpl
21var memorySearchPromptTmpl []byte
22
// memorySearchValidationResult bundles the identifiers that
// validateMemorySearchParams pulls out of the tool-call context: the ID of
// the invoking session and the current agent message ID.
type memorySearchValidationResult struct {
	SessionID, AgentMessageID string
}
28
29// validateMemorySearchParams validates the tool call parameters and extracts required context values.
30func validateMemorySearchParams(ctx context.Context, params tools.MemorySearchParams) (memorySearchValidationResult, error) {
31 if params.Query == "" {
32 return memorySearchValidationResult{}, errors.New("query is required")
33 }
34
35 sessionID := tools.GetSessionFromContext(ctx)
36 if sessionID == "" {
37 return memorySearchValidationResult{}, errors.New("session id missing from context")
38 }
39
40 agentMessageID := tools.GetMessageFromContext(ctx)
41 if agentMessageID == "" {
42 return memorySearchValidationResult{}, errors.New("agent message id missing from context")
43 }
44
45 return memorySearchValidationResult{
46 SessionID: sessionID,
47 AgentMessageID: agentMessageID,
48 }, nil
49}
50
51func (c *coordinator) memorySearchTool(_ context.Context) (fantasy.AgentTool, error) {
52 return fantasy.NewParallelAgentTool(
53 tools.MemorySearchToolName,
54 string(memorySearchToolDescription),
55 func(ctx context.Context, params tools.MemorySearchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
56 validationResult, err := validateMemorySearchParams(ctx, params)
57 if err != nil {
58 return fantasy.NewTextErrorResponse(err.Error()), nil
59 }
60
61 // Get the parent session to find the transcript.
62 parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
63 if err != nil {
64 return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to get session: %s", err)), nil
65 }
66
67 // Check if session has been summarized.
68 if parentSession.SummaryMessageID == "" {
69 return fantasy.NewTextErrorResponse("This session has not been summarized yet. The memory_search tool is only available after summarization."), nil
70 }
71
72 // Find the transcript file.
73 transcriptPath := TranscriptPath(parentSession.ID)
74 if _, err := os.Stat(transcriptPath); os.IsNotExist(err) {
75 return fantasy.NewTextErrorResponse(fmt.Sprintf("Transcript file not found at %s. The session may have been summarized before this feature was available.", transcriptPath)), nil
76 }
77
78 // Build the sub-agent prompt.
79 transcriptDir := filepath.Dir(transcriptPath)
80 promptOpts := []prompt.Option{
81 prompt.WithWorkingDir(transcriptDir),
82 }
83
84 promptTemplate, err := prompt.NewPrompt("memory_search", string(memorySearchPromptTmpl), promptOpts...)
85 if err != nil {
86 return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
87 }
88
89 _, small, err := c.buildAgentModels(ctx, true)
90 if err != nil {
91 return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
92 }
93
94 systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
95 if err != nil {
96 return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
97 }
98
99 smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
100 if !ok {
101 return fantasy.ToolResponse{}, errors.New("small model provider not configured")
102 }
103
104 // Create sub-agent with read-only tools scoped to the transcript directory.
105 searchTools := []fantasy.AgentTool{
106 tools.NewGlobTool(transcriptDir),
107 tools.NewGrepTool(transcriptDir),
108 tools.NewViewTool(c.lspClients, c.permissions, transcriptDir),
109 }
110
111 agent := NewSessionAgent(SessionAgentOptions{
112 LargeModel: small, // Use small model for both (search doesn't need large)
113 SmallModel: small,
114 SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix,
115 SystemPrompt: systemPrompt,
116 DisableAutoSummarize: true, // Never summarize the sub-agent session
117 IsYolo: c.permissions.SkipRequests(),
118 Sessions: c.sessions,
119 Messages: c.messages,
120 Tools: searchTools,
121 })
122
123 agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
124 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Memory Search")
125 if err != nil {
126 return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
127 }
128
129 c.permissions.AutoApproveSession(session.ID)
130
131 // Build the full prompt including the transcript path.
132 fullPrompt := fmt.Sprintf("%s\n\nThe session transcript is located at: %s\n\nUse grep and view to search this file for the requested information.", params.Query, transcriptPath)
133
134 // Use small model for transcript search.
135 maxTokens := small.CatwalkCfg.DefaultMaxTokens
136 if small.ModelCfg.MaxTokens != 0 {
137 maxTokens = small.ModelCfg.MaxTokens
138 }
139
140 result, err := agent.Run(ctx, SessionAgentCall{
141 SessionID: session.ID,
142 Prompt: fullPrompt,
143 MaxOutputTokens: maxTokens,
144 ProviderOptions: getProviderOptions(small, smallProviderCfg),
145 Temperature: small.ModelCfg.Temperature,
146 TopP: small.ModelCfg.TopP,
147 TopK: small.ModelCfg.TopK,
148 FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
149 PresencePenalty: small.ModelCfg.PresencePenalty,
150 })
151 if err != nil {
152 return fantasy.NewTextErrorResponse("error generating response"), nil
153 }
154
155 // Update parent session cost.
156 updatedSession, err := c.sessions.Get(ctx, session.ID)
157 if err != nil {
158 return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
159 }
160 parentSession, err = c.sessions.Get(ctx, validationResult.SessionID)
161 if err != nil {
162 return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
163 }
164
165 parentSession.Cost += updatedSession.Cost
166
167 _, err = c.sessions.Save(ctx, parentSession)
168 if err != nil {
169 return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
170 }
171
172 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
173 }), nil
174}