From 7949ea5a5af79ef1efd6368d6e387235690ebf7f Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Mon, 19 Jan 2026 22:11:30 +0100 Subject: [PATCH] wip: summary memory tool --- internal/agent/agent.go | 160 +++++++++++++++- internal/agent/agent_test.go | 100 +++++++++- internal/agent/coordinator.go | 8 + internal/agent/memory_search_tool.go | 174 ++++++++++++++++++ internal/agent/templates/memory_search.md | 50 +++++ .../templates/memory_search_prompt.md.tpl | 54 ++++++ internal/agent/tools/memory_search_types.go | 9 + internal/config/config.go | 1 + internal/config/load_test.go | 4 +- internal/tui/components/chat/chat.go | 4 +- .../tui/components/chat/messages/renderer.go | 74 ++++++++ internal/tui/components/chat/messages/tool.go | 22 +++ 12 files changed, 653 insertions(+), 7 deletions(-) create mode 100644 internal/agent/memory_search_tool.go create mode 100644 internal/agent/templates/memory_search.md create mode 100644 internal/agent/templates/memory_search_prompt.md.tpl create mode 100644 internal/agent/tools/memory_search_types.go diff --git a/internal/agent/agent.go b/internal/agent/agent.go index c916cfd886372ab86f6d1fbb0e8b7bde2c87dabb..29a4ade558fbd487db31611d83c35b7d587322d3 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -16,6 +16,7 @@ import ( "fmt" "log/slog" "os" + "path/filepath" "regexp" "strconv" "strings" @@ -562,6 +563,11 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan return nil } + // Save transcript for later search via memory_search tool. + if err := a.saveTranscript(ctx, sessionID); err != nil { + slog.Warn("failed to save transcript", "error", err) + } + aiMsgs, _ := a.preparePrompt(msgs) genCtx, cancel := context.WithCancel(ctx) @@ -582,7 +588,7 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan return err } - summaryPromptText := buildSummaryPrompt(currentSession.Todos) + summaryPromptText := buildSummaryPrompt(sessionID, currentSession.Todos) resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{ Prompt: summaryPromptText, @@ -686,6 +692,13 @@ func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentC func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) { var history []fantasy.Message + hasSummary := false + for _, msg := range msgs { + if msg.IsSummaryMessage { + hasSummary = true + break + } + } if !a.isSubAgent { history = append(history, fantasy.NewUserMessage( fmt.Sprintf("%s", @@ -694,6 +707,13 @@ If you are working on tasks that would benefit from a todo list please use the " If not, please feel free to ignore. Again do not mention this message to the user.`, ), )) + if hasSummary { + history = append(history, fantasy.NewUserMessage( + fmt.Sprintf("%s", + `This session was summarized. If you need specific details from before the summary (commands, code, file paths, errors, decisions), use the "memory_search" tool to search the full transcript instead of guessing.`, + ), + )) + } } for _, m := range msgs { if len(m.Parts) == 0 { @@ -1122,9 +1142,16 @@ func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Mes } // buildSummaryPrompt constructs the prompt text for session summarization. -func buildSummaryPrompt(todos []session.Todo) string { +func buildSummaryPrompt(sessionID string, todos []session.Todo) string { var sb strings.Builder sb.WriteString("Provide a detailed summary of our conversation above.") + + // Include transcript path for memory search. + transcriptPath := TranscriptPath(sessionID) + sb.WriteString("\n\n## Session Transcript\n\n") + sb.WriteString(fmt.Sprintf("The full conversation transcript has been saved to: `%s`\n", transcriptPath)) + sb.WriteString("The resuming assistant can use the `memory_search` tool to search this transcript for specific details from the conversation.\n") + if len(todos) > 0 { sb.WriteString("\n\n## Current Todo List\n\n") for _, t := range todos { @@ -1135,3 +1162,132 @@ func buildSummaryPrompt(todos []session.Todo) string { } return sb.String() } + +// serializeTranscript converts a slice of messages to a searchable markdown +// transcript format. The transcript includes user messages, assistant +// responses, tool calls, tool results, and reasoning content. +func serializeTranscript(msgs []message.Message) string { + var sb strings.Builder + sb.WriteString("# Session Transcript\n\n") + + for _, msg := range msgs { + roleHeader := "Message" + switch msg.Role { + case message.User: + roleHeader = "User" + case message.Assistant: + roleHeader = "Assistant" + case message.Tool: + roleHeader = "Tool Results" + } + sb.WriteString(fmt.Sprintf("## %s\n\n", roleHeader)) + + switch msg.Role { + case message.User: + if text := msg.Content().Text; text != "" { + sb.WriteString("### Content\n\n") + sb.WriteString(text) + sb.WriteString("\n\n") + } + // Include binary content paths. + attachments := msg.BinaryContent() + if len(attachments) > 0 { + sb.WriteString("### Attachments\n\n") + for _, bc := range attachments { + sb.WriteString(fmt.Sprintf("- %s (%s)\n", bc.Path, bc.MIMEType)) + } + sb.WriteString("\n") + } + + case message.Assistant: + if msg.Model != "" { + sb.WriteString(fmt.Sprintf("**Model:** %s (%s)\n", msg.Model, msg.Provider)) + } + sb.WriteString("\n") + + // Reasoning content. + if reasoning := msg.ReasoningContent(); reasoning.Thinking != "" { + sb.WriteString("### Reasoning\n\n") + sb.WriteString("\n") + sb.WriteString(reasoning.Thinking) + sb.WriteString("\n\n\n") + } + + // Text content. + if text := msg.Content().Text; text != "" { + sb.WriteString("### Response\n\n") + sb.WriteString(text) + sb.WriteString("\n\n") + } + + // Tool calls. + toolCalls := msg.ToolCalls() + if len(toolCalls) > 0 { + sb.WriteString("### Tool Calls\n\n") + for _, tc := range toolCalls { + sb.WriteString("#### Tool Call\n\n") + sb.WriteString(fmt.Sprintf("**Tool:** `%s`\n\n", tc.Name)) + sb.WriteString("**Input:**\n\n") + sb.WriteString("```json\n") + sb.WriteString(tc.Input) + sb.WriteString("\n```\n\n") + } + } + + case message.Tool: + for _, tr := range msg.ToolResults() { + sb.WriteString("#### Tool Result\n\n") + sb.WriteString(fmt.Sprintf("**Tool:** `%s`\n", tr.Name)) + if tr.IsError { + sb.WriteString("**Status:** Error\n") + } else { + sb.WriteString("**Status:** Success\n") + } + // Truncate very long tool results. + content := tr.Content + const maxToolResultLen = 10000 + if len(content) > maxToolResultLen { + content = content[:maxToolResultLen] + "\n... (truncated)" + sb.WriteString("**Output:** (truncated)\n\n") + } else { + sb.WriteString("**Output:**\n\n") + } + sb.WriteString("```\n") + sb.WriteString(content) + sb.WriteString("\n```\n\n") + } + } + + sb.WriteString("---\n\n") + } + + return sb.String() +} + +// saveTranscript serializes messages to a markdown file for later search. +func (a *sessionAgent) saveTranscript(ctx context.Context, sessionID string) error { + msgs, err := a.messages.List(ctx, sessionID) + if err != nil { + return fmt.Errorf("failed to list messages: %w", err) + } + cfg := config.Get() + transcriptsDir := filepath.Join(cfg.Options.DataDirectory, "transcripts") + if err := os.MkdirAll(transcriptsDir, 0o755); err != nil { + return fmt.Errorf("failed to create transcripts directory: %w", err) + } + + transcriptPath := filepath.Join(transcriptsDir, sessionID+".md") + transcript := serializeTranscript(msgs) + if err := os.WriteFile(transcriptPath, []byte(transcript), 0o644); err != nil { + return fmt.Errorf("failed to write transcript: %w", err) + } + + slog.Debug("saved transcript", "path", transcriptPath, "messages", len(msgs)) + return nil +} + +// TranscriptPath returns the path where a session's transcript would be saved. +func TranscriptPath(sessionID string) string { + cfg := config.Get() + return filepath.Join(cfg.Options.DataDirectory, "transcripts", sessionID+".md") +} diff --git a/internal/agent/agent_test.go b/internal/agent/agent_test.go index 222f9575d867e6977ee5062e57008f3c154df89c..604b580d34c933c87a4ea8352c83329738c72e75 100644 --- a/internal/agent/agent_test.go +++ b/internal/agent/agent_test.go @@ -7,6 +7,7 @@ import ( "runtime" "strings" "testing" + "time" "charm.land/fantasy" "charm.land/x/vcr" @@ -646,8 +647,105 @@ func BenchmarkBuildSummaryPrompt(b *testing.B) { b.Run(tc.name, func(b *testing.B) { b.ReportAllocs() for range b.N { - _ = buildSummaryPrompt(todos) + _ = buildSummaryPrompt("test-session-id", todos) } }) } } + +func TestSerializeTranscript(t *testing.T) { + now := time.Now().Unix() + + msgs := []message.Message{ + { + ID: "msg1", + Role: message.User, + SessionID: "sess1", + CreatedAt: now, + Parts: []message.ContentPart{ + message.TextContent{Text: "Hello, can you help me?"}, + }, + }, + { + ID: "msg2", + Role: message.Assistant, + SessionID: "sess1", + Model: "claude-sonnet-4-20250514", + Provider: "anthropic", + CreatedAt: now + 1, + Parts: []message.ContentPart{ + message.TextContent{Text: "Of course! What do you need help with?"}, + message.ToolCall{ + ID: "tc1", + Name: "view", + Input: `{"file_path": "/test/file.go"}`, + }, + }, + }, + { + ID: "msg3", + Role: message.Tool, + SessionID: "sess1", + CreatedAt: now + 2, + Parts: []message.ContentPart{ + message.ToolResult{ + ToolCallID: "tc1", + Name: "view", + Content: "package main\n\nfunc main() {}", + IsError: false, + }, + }, + }, + } + + transcript := serializeTranscript(msgs) + + // Verify structure. + require.Contains(t, transcript, "# Session Transcript") + require.Contains(t, transcript, "## User") + require.Contains(t, transcript, "## Assistant") + require.Contains(t, transcript, "## Tool Results") + + // Verify user message. + require.Contains(t, transcript, "Hello, can you help me?") + + // Verify assistant message. + require.Contains(t, transcript, "claude-sonnet-4-20250514") + require.Contains(t, transcript, "Of course! What do you need help with?") + require.Contains(t, transcript, "**Tool:** `view`") + require.Contains(t, transcript, `"file_path": "/test/file.go"`) + + // Verify tool result. + require.Contains(t, transcript, "**Status:** Success") + require.Contains(t, transcript, "package main") +} + +func TestSerializeTranscript_TruncatesLongToolResults(t *testing.T) { + now := time.Now().Unix() + + // Create a tool result with content larger than the truncation threshold. + longContent := strings.Repeat("x", 15000) + + msgs := []message.Message{ + { + ID: "msg1", + Role: message.Tool, + SessionID: "sess1", + CreatedAt: now, + Parts: []message.ContentPart{ + message.ToolResult{ + ToolCallID: "tc1", + Name: "bash", + Content: longContent, + IsError: false, + }, + }, + }, + } + + transcript := serializeTranscript(msgs) + + // Verify truncation happened. + require.Contains(t, transcript, "... (truncated)") + require.Less(t, len(transcript), 15000, "Transcript should be smaller than original content") +} diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 943c3efc41b33ea9f261b4ffc7256b6f544beff9..397472dca3474c4dea8cecd63e707d7fde51d210 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -375,6 +375,14 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan allTools = append(allTools, agenticFetchTool) } + if slices.Contains(agent.AllowedTools, tools.MemorySearchToolName) { + memorySearchTool, err := c.memorySearchTool(ctx) + if err != nil { + return nil, err + } + allTools = append(allTools, memorySearchTool) + } + // Get the model name for the agent modelName := "" if modelCfg, ok := c.cfg.Models[agent.Model]; ok { diff --git a/internal/agent/memory_search_tool.go b/internal/agent/memory_search_tool.go new file mode 100644 index 0000000000000000000000000000000000000000..87223c6649fda47b0a2a0fba2f7df75d83ee6d9d --- /dev/null +++ b/internal/agent/memory_search_tool.go @@ -0,0 +1,174 @@ +package agent + +import ( + "context" + _ "embed" + "errors" + "fmt" + "os" + "path/filepath" + + "charm.land/fantasy" + + "github.com/charmbracelet/crush/internal/agent/prompt" + "github.com/charmbracelet/crush/internal/agent/tools" +) + +//go:embed templates/memory_search.md +var memorySearchToolDescription []byte + +//go:embed templates/memory_search_prompt.md.tpl +var memorySearchPromptTmpl []byte + +// memorySearchValidationResult holds the validated parameters from the tool call context. +type memorySearchValidationResult struct { + SessionID string + AgentMessageID string +} + +// validateMemorySearchParams validates the tool call parameters and extracts required context values. +func validateMemorySearchParams(ctx context.Context, params tools.MemorySearchParams) (memorySearchValidationResult, error) { + if params.Query == "" { + return memorySearchValidationResult{}, errors.New("query is required") + } + + sessionID := tools.GetSessionFromContext(ctx) + if sessionID == "" { + return memorySearchValidationResult{}, errors.New("session id missing from context") + } + + agentMessageID := tools.GetMessageFromContext(ctx) + if agentMessageID == "" { + return memorySearchValidationResult{}, errors.New("agent message id missing from context") + } + + return memorySearchValidationResult{ + SessionID: sessionID, + AgentMessageID: agentMessageID, + }, nil +} + +func (c *coordinator) memorySearchTool(_ context.Context) (fantasy.AgentTool, error) { + return fantasy.NewParallelAgentTool( + tools.MemorySearchToolName, + string(memorySearchToolDescription), + func(ctx context.Context, params tools.MemorySearchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) { + validationResult, err := validateMemorySearchParams(ctx, params) + if err != nil { + return fantasy.NewTextErrorResponse(err.Error()), nil + } + + // Get the parent session to find the transcript. + parentSession, err := c.sessions.Get(ctx, validationResult.SessionID) + if err != nil { + return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to get session: %s", err)), nil + } + + // Check if session has been summarized. + if parentSession.SummaryMessageID == "" { + return fantasy.NewTextErrorResponse("This session has not been summarized yet. The memory_search tool is only available after summarization."), nil + } + + // Find the transcript file. + transcriptPath := TranscriptPath(parentSession.ID) + if _, err := os.Stat(transcriptPath); os.IsNotExist(err) { + return fantasy.NewTextErrorResponse(fmt.Sprintf("Transcript file not found at %s. The session may have been summarized before this feature was available.", transcriptPath)), nil + } + + // Build the sub-agent prompt. + transcriptDir := filepath.Dir(transcriptPath) + promptOpts := []prompt.Option{ + prompt.WithWorkingDir(transcriptDir), + } + + promptTemplate, err := prompt.NewPrompt("memory_search", string(memorySearchPromptTmpl), promptOpts...) + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err) + } + + _, small, err := c.buildAgentModels(ctx, true) + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err) + } + + systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg) + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err) + } + + smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider) + if !ok { + return fantasy.ToolResponse{}, errors.New("small model provider not configured") + } + + // Create sub-agent with read-only tools scoped to the transcript directory. + searchTools := []fantasy.AgentTool{ + tools.NewGlobTool(transcriptDir), + tools.NewGrepTool(transcriptDir), + tools.NewViewTool(c.lspClients, c.permissions, transcriptDir), + } + + agent := NewSessionAgent(SessionAgentOptions{ + LargeModel: small, // Use small model for both (search doesn't need large) + SmallModel: small, + SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix, + SystemPrompt: systemPrompt, + DisableAutoSummarize: true, // Never summarize the sub-agent session + IsYolo: c.permissions.SkipRequests(), + Sessions: c.sessions, + Messages: c.messages, + Tools: searchTools, + }) + + agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID) + session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Memory Search") + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err) + } + + c.permissions.AutoApproveSession(session.ID) + + // Build the full prompt including the transcript path. + fullPrompt := fmt.Sprintf("%s\n\nThe session transcript is located at: %s\n\nUse grep and view to search this file for the requested information.", params.Query, transcriptPath) + + // Use small model for transcript search. + maxTokens := small.CatwalkCfg.DefaultMaxTokens + if small.ModelCfg.MaxTokens != 0 { + maxTokens = small.ModelCfg.MaxTokens + } + + result, err := agent.Run(ctx, SessionAgentCall{ + SessionID: session.ID, + Prompt: fullPrompt, + MaxOutputTokens: maxTokens, + ProviderOptions: getProviderOptions(small, smallProviderCfg), + Temperature: small.ModelCfg.Temperature, + TopP: small.ModelCfg.TopP, + TopK: small.ModelCfg.TopK, + FrequencyPenalty: small.ModelCfg.FrequencyPenalty, + PresencePenalty: small.ModelCfg.PresencePenalty, + }) + if err != nil { + return fantasy.NewTextErrorResponse("error generating response"), nil + } + + // Update parent session cost. + updatedSession, err := c.sessions.Get(ctx, session.ID) + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err) + } + parentSession, err = c.sessions.Get(ctx, validationResult.SessionID) + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err) + } + + parentSession.Cost += updatedSession.Cost + + _, err = c.sessions.Save(ctx, parentSession) + if err != nil { + return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err) + } + + return fantasy.NewTextResponse(result.Response.Content.Text()), nil + }), nil +} diff --git a/internal/agent/templates/memory_search.md b/internal/agent/templates/memory_search.md new file mode 100644 index 0000000000000000000000000000000000000000..c5be6dc037f99b118a3cdedee64fb039eb147ecd --- /dev/null +++ b/internal/agent/templates/memory_search.md @@ -0,0 +1,50 @@ +Searches the full conversation transcript from a summarized session to recover specific details that were condensed in the summary. + + +Use this tool when you need to: +- Recover exact details that are missing from the summary +- Find specific code snippets, file paths, or commands discussed earlier +- Locate tool calls or their results from earlier in the session +- Retrieve precise error messages or decisions made earlier + +DO NOT use this tool when: +- The information you need is already in the current conversation context +- You're starting a fresh session with no summarization +- You need information from a different session + + + +- Provide a natural language query describing the information you want +- The query is interpreted by a sub-agent that searches the transcript with grep/view +- Returns a concise answer with quoted excerpts as evidence + + + +- query: A natural language description of what you want to find (required) + + + +- Only available after a session has been summarized +- The transcript contains the full conversation, including code blocks and tool calls +- The tool searches only the current session transcript +- Results include supporting excerpts to ground the answer + + + +- Cannot search across multiple sessions +- Very long tool results may be truncated in the transcript +- Binary file contents are not included (only paths are recorded) + + + +- Be specific: include names, URLs, filenames, error text, or distinctive phrases +- If the first query is too broad, refine it with additional constraints +- Ask for exact quotes when you need verbatim text + + + +- query: "What was the exact error message when running the tests?" +- query: "Find the implementation of the serializeTranscript function" +- query: "What file paths were modified in the refactoring?" +- query: "What approach did we try first that didn't work?" + diff --git a/internal/agent/templates/memory_search_prompt.md.tpl b/internal/agent/templates/memory_search_prompt.md.tpl new file mode 100644 index 0000000000000000000000000000000000000000..c3b1ca40b67f30f74521f29dbf8105e60ba64880 --- /dev/null +++ b/internal/agent/templates/memory_search_prompt.md.tpl @@ -0,0 +1,54 @@ +You are a memory search agent for Crush. The main agent has only a summary and needs you to recover specific details from the full transcript. + + +Treat the transcript as external memory: probe, filter, and retrieve only the minimal evidence needed to answer correctly. + + + +1. Be concise and direct in your responses +2. Focus only on the information requested in the user's query +3. Start with grep to filter; do not scan the entire transcript +4. Prefer targeted view reads around the best matches +5. Use alternative keywords or synonyms if the first search fails +6. Minimize tool calls by batching related searches +7. Avoid redundant verification once evidence is sufficient +8. Quote the exact lines that support the answer +9. If the requested information is not found, clearly state that +10. Any file paths you use MUST be absolute +11. Include enough surrounding context to interpret the match + + + +The transcript is a markdown file with this structure: +- Each message is marked with "## Message N [timestamp]" +- Messages have **Role:** (User, Assistant, or Tool Results) +- User messages have ### Content sections +- Assistant messages have ### Reasoning, ### Response, and ### Tool Calls sections +- Tool results show the tool name, status, and output +- Messages are separated by "---" +This is a full conversation like a live session, including code blocks and tool calls. + + + +1. Extract concrete keywords from the query (names, URLs, error text, identifiers, dates, file paths) +2. Grep for multiple keywords or regex patterns in a single pass when possible +3. If there are too many hits, add constraints (exact phrases, nearby terms) +4. If there are no hits, expand with synonyms or related terms +5. View surrounding context for the strongest hits to confirm relevance +6. If the answer is distributed, gather minimal supporting excerpts and aggregate +7. Stop when you have sufficient evidence to answer + + + +Your response should include: +1. A direct answer to the query +2. Relevant excerpts from the transcript (quoted) + +If nothing is found, explain what you searched for and suggest alternative search terms the user might try. + + + +Working directory: {{.WorkingDir}} +Platform: {{.Platform}} +Today's date: {{.Date}} + diff --git a/internal/agent/tools/memory_search_types.go b/internal/agent/tools/memory_search_types.go new file mode 100644 index 0000000000000000000000000000000000000000..bce404aaba0bc5f63f4f84da1efc0e35922ff7c2 --- /dev/null +++ b/internal/agent/tools/memory_search_types.go @@ -0,0 +1,9 @@ +package tools + +// MemorySearchToolName is the name of the memory_search tool. +const MemorySearchToolName = "memory_search" + +// MemorySearchParams defines the parameters for the memory_search tool. +type MemorySearchParams struct { + Query string `json:"query" description:"The query describing what information to search for in the session transcript"` +} diff --git a/internal/config/config.go b/internal/config/config.go index 2c414e3e9e35d6f232e00762f50aca1066aca321..3b6eec802ca3047055bcdab21c5ca1394d17043c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -692,6 +692,7 @@ func allToolNames() []string { "lsp_references", "fetch", "agentic_fetch", + "memory_search", "glob", "grep", "ls", diff --git a/internal/config/load_test.go b/internal/config/load_test.go index 8924475ef9c652ea1962e4f032a0e62e560bce7a..3a1bff45ea1147eadf52d80ebe2149c6ed1212ac 100644 --- a/internal/config/load_test.go +++ b/internal/config/load_test.go @@ -486,7 +486,7 @@ func TestConfig_setupAgentsWithDisabledTools(t *testing.T) { coderAgent, ok := cfg.Agents[AgentCoder] require.True(t, ok) - assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "glob", "ls", "sourcegraph", "todos", "view", "write"}, coderAgent.AllowedTools) + assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "memory_search", "glob", "ls", "sourcegraph", "todos", "view", "write"}, coderAgent.AllowedTools) taskAgent, ok := cfg.Agents[AgentTask] require.True(t, ok) @@ -509,7 +509,7 @@ func TestConfig_setupAgentsWithEveryReadOnlyToolDisabled(t *testing.T) { cfg.SetupAgents() coderAgent, ok := cfg.Agents[AgentCoder] require.True(t, ok) - assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "todos", "write"}, coderAgent.AllowedTools) + assert.Equal(t, []string{"agent", "bash", "job_output", "job_kill", "download", "edit", "multiedit", "lsp_diagnostics", "lsp_references", "fetch", "agentic_fetch", "memory_search", "todos", "write"}, coderAgent.AllowedTools) taskAgent, ok := cfg.Agents[AgentTask] require.True(t, ok) diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index 036c8262d2b0d8419bf89b64afd922767b6be12a..10ad8f10bf4315c6930b466c58380bbff5dabe3f 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -614,8 +614,8 @@ func (m *messageListCmp) convertAssistantMessage(msg message.Message, toolResult for _, tc := range msg.ToolCalls() { options := m.buildToolCallOptions(tc, msg, toolResultMap) uiMessages = append(uiMessages, messages.NewToolCallCmp(msg.ID, tc, m.app.Permissions, options...)) - // If this tool call is the agent tool or agentic fetch, fetch nested tool calls - if tc.Name == agent.AgentToolName || tc.Name == tools.AgenticFetchToolName { + // If this tool call is the agent tool, agentic fetch, or memory search, fetch nested tool calls + if tc.Name == agent.AgentToolName || tc.Name == tools.AgenticFetchToolName || tc.Name == tools.MemorySearchToolName { agentToolSessionID := m.app.Sessions.CreateAgentToolSessionID(msg.ID, tc.ID) nestedMessages, _ := m.app.Messages.List(context.Background(), agentToolSessionID) nestedToolResultMap := m.buildToolResultMap(nestedMessages) diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index 5fbd8a653c0b0374029bf13b31721d8ad5150948..da4e1f2785c5ebae3e3a7ae6316fc602f969242d 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -191,6 +191,7 @@ func init() { registry.register(tools.WriteToolName, func() renderer { return writeRenderer{} }) registry.register(tools.FetchToolName, func() renderer { return simpleFetchRenderer{} }) registry.register(tools.AgenticFetchToolName, func() renderer { return agenticFetchRenderer{} }) + registry.register(tools.MemorySearchToolName, func() renderer { return memorySearchRenderer{} }) registry.register(tools.WebFetchToolName, func() renderer { return webFetchRenderer{} }) registry.register(tools.WebSearchToolName, func() renderer { return webSearchRenderer{} }) registry.register(tools.GlobToolName, func() renderer { return globRenderer{} }) @@ -704,6 +705,77 @@ func (fr agenticFetchRenderer) Render(v *toolCallCmp) string { return joinHeaderBody(header, body) } +// ----------------------------------------------------------------------------- +// Memory Search renderer +// ----------------------------------------------------------------------------- + +// memorySearchRenderer handles session transcript searching with nested tool calls +type memorySearchRenderer struct { + baseRenderer +} + +// Render displays the memory search query and nested tool calls +func (mr memorySearchRenderer) Render(v *toolCallCmp) string { + t := styles.CurrentTheme() + var params tools.MemorySearchParams + if err := mr.unmarshalParams(v.call.Input, ¶ms); err != nil { + return mr.renderWithParams(v, "Memory Search", []string{v.call.Input}, func() string { + return renderPlainContent(v, v.result.Content) + }) + } + + query := params.Query + query = strings.ReplaceAll(query, "\n", " ") + + header := mr.makeHeader(v, "Memory Search", v.textWidth()) + if res, done := earlyState(header, v); v.cancelled && done { + return res + } + + taskTag := t.S().Base.Bold(true).Padding(0, 1).MarginLeft(2).Background(t.Citron).Foreground(t.Border).Render("Query") + remainingWidth := v.textWidth() - (lipgloss.Width(taskTag) + 1) + remainingWidth = min(remainingWidth, 120-(lipgloss.Width(taskTag)+1)) + query = t.S().Base.Width(remainingWidth).Render(query) + header = lipgloss.JoinVertical( + lipgloss.Left, + header, + "", + lipgloss.JoinHorizontal( + lipgloss.Left, + taskTag, + " ", + query, + ), + ) + childTools := tree.Root(header) + + for _, call := range v.nestedToolCalls { + call.SetSize(remainingWidth, 1) + childTools.Child(call.View()) + } + parts := []string{ + childTools.Enumerator(RoundedEnumeratorWithWidth(2, lipgloss.Width(taskTag)-5)).String(), + } + + if v.result.ToolCallID == "" { + v.spinning = true + parts = append(parts, "", v.anim.View()) + } else { + v.spinning = false + } + + header = lipgloss.JoinVertical( + lipgloss.Left, + parts..., + ) + + if v.result.ToolCallID == "" { + return header + } + body := renderMarkdownContent(v, v.result.Content) + return joinHeaderBody(header, body) +} + // formatTimeout converts timeout seconds to duration string func formatTimeout(timeout int) string { if timeout == 0 { @@ -1288,6 +1360,8 @@ func prettifyToolName(name string) string { return "Fetch" case tools.AgenticFetchToolName: return "Agentic Fetch" + case tools.MemorySearchToolName: + return "Memory Search" case tools.WebFetchToolName: return "Fetch" case tools.WebSearchToolName: diff --git a/internal/tui/components/chat/messages/tool.go b/internal/tui/components/chat/messages/tool.go index b8163f5a4c2a51f13ebd7ba2650bb7c3f33dac44..517413683a3be89ae30a2567b18b751c08166d4b 100644 --- a/internal/tui/components/chat/messages/tool.go +++ b/internal/tui/components/chat/messages/tool.go @@ -309,6 +309,11 @@ func (m *toolCallCmp) formatParametersForCopy() string { } return strings.Join(parts, "\n") } + case tools.MemorySearchToolName: + var params tools.MemorySearchParams + if json.Unmarshal([]byte(m.call.Input), ¶ms) == nil { + return fmt.Sprintf("**Query:** %s", params.Query) + } case tools.WebFetchToolName: var params tools.WebFetchParams if json.Unmarshal([]byte(m.call.Input), ¶ms) == nil { @@ -421,6 +426,8 @@ func (m *toolCallCmp) formatResultForCopy() string { return m.formatFetchResultForCopy() case tools.AgenticFetchToolName: return m.formatAgenticFetchResultForCopy() + case tools.MemorySearchToolName: + return m.formatMemorySearchResultForCopy() case tools.WebFetchToolName: return m.formatWebFetchResultForCopy() case agent.AgentToolName: @@ -670,6 +677,21 @@ func (m *toolCallCmp) formatAgenticFetchResultForCopy() string { return result.String() } +func (m *toolCallCmp) formatMemorySearchResultForCopy() string { + var params tools.MemorySearchParams + if json.Unmarshal([]byte(m.call.Input), ¶ms) != nil { + return m.result.Content + } + + var result strings.Builder + result.WriteString(fmt.Sprintf("Query: %s\n\n", params.Query)) + result.WriteString("```markdown\n") + result.WriteString(m.result.Content) + result.WriteString("\n```") + + return result.String() +} + func (m *toolCallCmp) formatWebFetchResultForCopy() string { var params tools.WebFetchParams if json.Unmarshal([]byte(m.call.Input), ¶ms) != nil {