From 6808721f22c5b729b96eea9d17741c33bb09ce27 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 17 Jul 2025 15:34:17 +0200 Subject: [PATCH] chore: improve permissions & edit tool --- internal/app/app.go | 1 + internal/llm/agent/mcp-tools.go | 1 + internal/llm/prompt/coder.go | 6 +- internal/llm/tools/bash.go | 1 + internal/llm/tools/edit.go | 102 ++-- internal/llm/tools/fetch.go | 1 + internal/llm/tools/multiedit.go | 467 ++++++++++++++++++ internal/llm/tools/write.go | 1 + internal/permission/permission.go | 86 +++- internal/permission/permission_test.go | 159 ++++++ internal/tui/components/chat/chat.go | 32 +- .../tui/components/chat/messages/renderer.go | 60 ++- internal/tui/components/chat/messages/tool.go | 37 +- .../dialogs/permissions/permissions.go | 43 +- internal/tui/page/chat/chat.go | 6 + internal/tui/tui.go | 5 + 16 files changed, 950 insertions(+), 58 deletions(-) create mode 100644 internal/llm/tools/multiedit.go diff --git a/internal/app/app.go b/internal/app/app.go index 50e117ea1ae272156dbd11baa1a5f157a74333f1..f636395e58d65c100b5f68e31c704f2189bcf995 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -203,6 +203,7 @@ func (app *App) setupEvents() { setupSubscriber(ctx, app.serviceEventsWG, "sessions", app.Sessions.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "messages", app.Messages.Subscribe, app.events) setupSubscriber(ctx, app.serviceEventsWG, "permissions", app.Permissions.Subscribe, app.events) + setupSubscriber(ctx, app.serviceEventsWG, "permissions-notifications", app.Permissions.SubscribeNotifications, app.events) setupSubscriber(ctx, app.serviceEventsWG, "history", app.History.Subscribe, app.events) cleanupFunc := func() { cancel() diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index 0165b0f7194d029a6dee9113f82877820ce96c00..05b4fada88973608b94eb840a18d65efae70fccf 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -100,6 +100,7 @@ func (b *mcpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes p := b.permissions.Request( permission.CreatePermissionRequest{ SessionID: sessionID, + ToolCallID: params.ID, Path: b.workingDir, ToolName: b.Info().Name, Action: "execute", diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 2ffbf2111931ad111751af1bfcd492422da205ee..86394e6ce375ee7f4fb2d985e602075feb6180d0 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -107,7 +107,7 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN # Tool usage policy - When doing file search, prefer to use the Agent tool in order to reduce context usage. -- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in parallel. +- IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them). - IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user. # Proactiveness @@ -217,7 +217,7 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN # Tool usage policy - When doing file search, prefer to use the Agent tool in order to reduce context usage. -- If you intend to call multiple tools and there are no dependencies between the calls, make all of the independent calls in parallel. +- IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them). - IMPORTANT: The user does not see the full output of the tool responses, so if you need the output of the tool for the response make sure to summarize it for the user. VERY IMPORTANT NEVER use emojis in your responses. @@ -281,7 +281,7 @@ NEVER commit changes unless the user explicitly asks you to. It is VERY IMPORTAN ## Tool Usage - **File Paths:** Always use absolute paths when referring to files with tools like ` + "`view`" + ` or ` + "`write`" + `. Relative paths are not supported. You must provide an absolute path. -- **Parallelism:** Execute multiple independent tool calls in parallel when feasible (i.e. searching the codebase). +- **Parallelism:** IMPORTANT: All tools are executed in parallel when multiple tool calls are sent in a single message. Only send multiple tool calls when they are safe to run in parallel (no dependencies between them). - **Command Execution:** Use the ` + "`bash`" + ` tool for running shell commands, remembering the safety rule to explain modifying commands first. - **Background Processes:** Use background processes (via ` + "`&`" + `) for commands that are unlikely to stop on their own, e.g. ` + "`node server.js &`" + `. If unsure, ask the user. - **Interactive Commands:** Try to avoid shell commands that are likely to require user interaction (e.g. ` + "`git rebase -i`" + `). Use non-interactive versions of commands (e.g. ` + "`npm init -y`" + ` instead of ` + "`npm init`" + `) when available, and otherwise remind the user that interactive shell commands are not supported and may cause hangs until canceled by the user. diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 99ab86068a5effa1e631037f3340ba814055d709..1954c356cc634164a77bb51dec665bfb1405a4d9 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -373,6 +373,7 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) permission.CreatePermissionRequest{ SessionID: sessionID, Path: b.workingDir, + ToolCallID: call.ID, ToolName: BashToolName, Action: "execute", Description: fmt.Sprintf("Execute command: %s", params.Command), diff --git a/internal/llm/tools/edit.go b/internal/llm/tools/edit.go index e09151781cf7f3c53fd0d23de46f1b9ca7dd3607..77821b7119bcd3756bb531a031b0d99307361718 100644 --- a/internal/llm/tools/edit.go +++ b/internal/llm/tools/edit.go @@ -18,9 +18,10 @@ import ( ) type EditParams struct { - FilePath string `json:"file_path"` - OldString string `json:"old_string"` - NewString string `json:"new_string"` + FilePath string `json:"file_path"` + OldString string `json:"old_string"` + NewString string `json:"new_string"` + ReplaceAll bool `json:"replace_all,omitempty"` } type EditPermissionsParams struct { @@ -58,31 +59,33 @@ To make a file edit, provide the following: 1. file_path: The absolute path to the file to modify (must be absolute, not relative) 2. old_string: The text to replace (must be unique within the file, and must match the file contents exactly, including all whitespace and indentation) 3. new_string: The edited text to replace the old_string +4. replace_all: Replace all occurrences of old_string (default false) Special cases: - To create a new file: provide file_path and new_string, leave old_string empty - To delete content: provide file_path and old_string, leave new_string empty -The tool will replace ONE occurrence of old_string with new_string in the specified file. +The tool will replace ONE occurrence of old_string with new_string in the specified file by default. Set replace_all to true to replace all occurrences. CRITICAL REQUIREMENTS FOR USING THIS TOOL: -1. UNIQUENESS: The old_string MUST uniquely identify the specific instance you want to change. This means: +1. UNIQUENESS: When replace_all is false (default), the old_string MUST uniquely identify the specific instance you want to change. This means: - Include AT LEAST 3-5 lines of context BEFORE the change point - Include AT LEAST 3-5 lines of context AFTER the change point - Include all whitespace, indentation, and surrounding code exactly as it appears in the file -2. SINGLE INSTANCE: This tool can only change ONE instance at a time. If you need to change multiple instances: - - Make separate calls to this tool for each instance +2. SINGLE INSTANCE: When replace_all is false, this tool can only change ONE instance at a time. If you need to change multiple instances: + - Set replace_all to true to replace all occurrences at once + - Or make separate calls to this tool for each instance - Each call must uniquely identify its specific instance using extensive context 3. VERIFICATION: Before using this tool: - Check how many instances of the target text exist in the file - - If multiple instances exist, gather enough context to uniquely identify each one - - Plan separate tool calls for each instance + - If multiple instances exist and replace_all is false, gather enough context to uniquely identify each one + - Plan separate tool calls for each instance or use replace_all WARNING: If you do not follow these requirements: - - The tool will fail if old_string matches multiple locations + - The tool will fail if old_string matches multiple locations and replace_all is false - The tool will fail if old_string doesn't match exactly (including whitespace) - You may change the wrong instance if you don't include enough context @@ -129,6 +132,10 @@ func (e *editTool) Info() ToolInfo { "type": "string", "description": "The text to replace it with", }, + "replace_all": map[string]any{ + "type": "boolean", + "description": "Replace all occurrences of old_string (default false)", + }, }, Required: []string{"file_path", "old_string", "new_string"}, } @@ -152,20 +159,20 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) var err error if params.OldString == "" { - response, err = e.createNewFile(ctx, params.FilePath, params.NewString) + response, err = e.createNewFile(ctx, params.FilePath, params.NewString, call) if err != nil { return response, err } } if params.NewString == "" { - response, err = e.deleteContent(ctx, params.FilePath, params.OldString) + response, err = e.deleteContent(ctx, params.FilePath, params.OldString, params.ReplaceAll, call) if err != nil { return response, err } } - response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString) + response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString, params.ReplaceAll, call) if err != nil { return response, err } @@ -182,7 +189,7 @@ func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) return response, nil } -func (e *editTool) createNewFile(ctx context.Context, filePath, content string) (ToolResponse, error) { +func (e *editTool) createNewFile(ctx context.Context, filePath, content string, call ToolCall) (ToolResponse, error) { fileInfo, err := os.Stat(filePath) if err == nil { if fileInfo.IsDir() { @@ -217,6 +224,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) permission.CreatePermissionRequest{ SessionID: sessionID, Path: permissionPath, + ToolCallID: call.ID, ToolName: EditToolName, Action: "write", Description: fmt.Sprintf("Create file %s", filePath), @@ -264,7 +272,7 @@ func (e *editTool) createNewFile(ctx context.Context, filePath, content string) ), nil } -func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string) (ToolResponse, error) { +func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string, replaceAll bool, call ToolCall) (ToolResponse, error) { fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { @@ -297,17 +305,29 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string oldContent := string(content) - index := strings.Index(oldContent, oldString) - if index == -1 { - return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil - } + var newContent string + var deletionCount int - lastIndex := strings.LastIndex(oldContent, oldString) - if index != lastIndex { - return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil - } + if replaceAll { + newContent = strings.ReplaceAll(oldContent, oldString, "") + deletionCount = strings.Count(oldContent, oldString) + if deletionCount == 0 { + return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil + } + } else { + index := strings.Index(oldContent, oldString) + if index == -1 { + return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil + } + + lastIndex := strings.LastIndex(oldContent, oldString) + if index != lastIndex { + return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil + } - newContent := oldContent[:index] + oldContent[index+len(oldString):] + newContent = oldContent[:index] + oldContent[index+len(oldString):] + deletionCount = 1 + } sessionID, messageID := GetContextValues(ctx) @@ -330,6 +350,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string permission.CreatePermissionRequest{ SessionID: sessionID, Path: permissionPath, + ToolCallID: call.ID, ToolName: EditToolName, Action: "write", Description: fmt.Sprintf("Delete content from file %s", filePath), @@ -385,7 +406,7 @@ func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string ), nil } -func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string) (ToolResponse, error) { +func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string, replaceAll bool, call ToolCall) (ToolResponse, error) { fileInfo, err := os.Stat(filePath) if err != nil { if os.IsNotExist(err) { @@ -418,17 +439,29 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS oldContent := string(content) - index := strings.Index(oldContent, oldString) - if index == -1 { - return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil - } + var newContent string + var replacementCount int - lastIndex := strings.LastIndex(oldContent, oldString) - if index != lastIndex { - return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match"), nil - } + if replaceAll { + newContent = strings.ReplaceAll(oldContent, oldString, newString) + replacementCount = strings.Count(oldContent, oldString) + if replacementCount == 0 { + return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil + } + } else { + index := strings.Index(oldContent, oldString) + if index == -1 { + return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil + } + + lastIndex := strings.LastIndex(oldContent, oldString) + if index != lastIndex { + return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil + } - newContent := oldContent[:index] + newString + oldContent[index+len(oldString):] + newContent = oldContent[:index] + newString + oldContent[index+len(oldString):] + replacementCount = 1 + } if oldContent == newContent { return NewTextErrorResponse("new content is the same as old content. No changes made."), nil @@ -452,6 +485,7 @@ func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newS permission.CreatePermissionRequest{ SessionID: sessionID, Path: permissionPath, + ToolCallID: call.ID, ToolName: EditToolName, Action: "write", Description: fmt.Sprintf("Replace content in file %s", filePath), diff --git a/internal/llm/tools/fetch.go b/internal/llm/tools/fetch.go index 1e44151b1124c643d2ddd428144e66c5d365e609..156dbff7edd5747c4e758fc09cf94a5230c50deb 100644 --- a/internal/llm/tools/fetch.go +++ b/internal/llm/tools/fetch.go @@ -136,6 +136,7 @@ func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error permission.CreatePermissionRequest{ SessionID: sessionID, Path: t.workingDir, + ToolCallID: call.ID, ToolName: FetchToolName, Action: "fetch", Description: fmt.Sprintf("Fetch content from URL: %s", params.URL), diff --git a/internal/llm/tools/multiedit.go b/internal/llm/tools/multiedit.go new file mode 100644 index 0000000000000000000000000000000000000000..2038140e7a6bc33741772eb315a5cf69258b7c1e --- /dev/null +++ b/internal/llm/tools/multiedit.go @@ -0,0 +1,467 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "time" + + "github.com/charmbracelet/crush/internal/diff" + "github.com/charmbracelet/crush/internal/history" + "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/permission" +) + +type MultiEditOperation struct { + OldString string `json:"old_string"` + NewString string `json:"new_string"` + ReplaceAll bool `json:"replace_all,omitempty"` +} + +type MultiEditParams struct { + FilePath string `json:"file_path"` + Edits []MultiEditOperation `json:"edits"` +} + +type MultiEditPermissionsParams struct { + FilePath string `json:"file_path"` + OldContent string `json:"old_content,omitempty"` + NewContent string `json:"new_content,omitempty"` +} + +type MultiEditResponseMetadata struct { + Additions int `json:"additions"` + Removals int `json:"removals"` + OldContent string `json:"old_content,omitempty"` + NewContent string `json:"new_content,omitempty"` + EditsApplied int `json:"edits_applied"` +} + +type multiEditTool struct { + lspClients map[string]*lsp.Client + permissions permission.Service + files history.Service + workingDir string +} + +const ( + MultiEditToolName = "multiedit" + multiEditDescription = `This is a tool for making multiple edits to a single file in one operation. It is built on top of the Edit tool and allows you to perform multiple find-and-replace operations efficiently. Prefer this tool over the Edit tool when you need to make multiple edits to the same file. + +Before using this tool: + +1. Use the Read tool to understand the file's contents and context + +2. Verify the directory path is correct + +To make multiple file edits, provide the following: +1. file_path: The absolute path to the file to modify (must be absolute, not relative) +2. edits: An array of edit operations to perform, where each edit contains: + - old_string: The text to replace (must match the file contents exactly, including all whitespace and indentation) + - new_string: The edited text to replace the old_string + - replace_all: Replace all occurrences of old_string. This parameter is optional and defaults to false. + +IMPORTANT: +- All edits are applied in sequence, in the order they are provided +- Each edit operates on the result of the previous edit +- All edits must be valid for the operation to succeed - if any edit fails, none will be applied +- This tool is ideal when you need to make several changes to different parts of the same file + +CRITICAL REQUIREMENTS: +1. All edits follow the same requirements as the single Edit tool +2. The edits are atomic - either all succeed or none are applied +3. Plan your edits carefully to avoid conflicts between sequential operations + +WARNING: +- The tool will fail if edits.old_string doesn't match the file contents exactly (including whitespace) +- The tool will fail if edits.old_string and edits.new_string are the same +- Since edits are applied in sequence, ensure that earlier edits don't affect the text that later edits are trying to find + +When making edits: +- Ensure all edits result in idiomatic, correct code +- Do not leave the code in a broken state +- Always use absolute file paths (starting with /) +- Only use emojis if the user explicitly requests it. Avoid adding emojis to files unless asked. +- Use replace_all for replacing and renaming strings across the file. This parameter is useful if you want to rename a variable for instance. + +If you want to create a new file, use: +- A new file path, including dir name if needed +- First edit: empty old_string and the new file's contents as new_string +- Subsequent edits: normal edit operations on the created content` +) + +func NewMultiEditTool(lspClients map[string]*lsp.Client, permissions permission.Service, files history.Service, workingDir string) BaseTool { + return &multiEditTool{ + lspClients: lspClients, + permissions: permissions, + files: files, + workingDir: workingDir, + } +} + +func (m *multiEditTool) Name() string { + return MultiEditToolName +} + +func (m *multiEditTool) Info() ToolInfo { + return ToolInfo{ + Name: MultiEditToolName, + Description: multiEditDescription, + Parameters: map[string]any{ + "file_path": map[string]any{ + "type": "string", + "description": "The absolute path to the file to modify", + }, + "edits": map[string]any{ + "type": "array", + "items": map[string]any{ + "type": "object", + "properties": map[string]any{ + "old_string": map[string]any{ + "type": "string", + "description": "The text to replace", + }, + "new_string": map[string]any{ + "type": "string", + "description": "The text to replace it with", + }, + "replace_all": map[string]any{ + "type": "boolean", + "default": false, + "description": "Replace all occurrences of old_string (default false).", + }, + }, + "required": []string{"old_string", "new_string"}, + "additionalProperties": false, + }, + "minItems": 1, + "description": "Array of edit operations to perform sequentially on the file", + }, + }, + Required: []string{"file_path", "edits"}, + } +} + +func (m *multiEditTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) { + var params MultiEditParams + if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil { + return NewTextErrorResponse("invalid parameters"), nil + } + + if params.FilePath == "" { + return NewTextErrorResponse("file_path is required"), nil + } + + if len(params.Edits) == 0 { + return NewTextErrorResponse("at least one edit operation is required"), nil + } + + if !filepath.IsAbs(params.FilePath) { + params.FilePath = filepath.Join(m.workingDir, params.FilePath) + } + + // Validate all edits before applying any + if err := m.validateEdits(params.Edits); err != nil { + return NewTextErrorResponse(err.Error()), nil + } + + var response ToolResponse + var err error + + // Handle file creation case (first edit has empty old_string) + if len(params.Edits) > 0 && params.Edits[0].OldString == "" { + response, err = m.processMultiEditWithCreation(ctx, params, call) + } else { + response, err = m.processMultiEditExistingFile(ctx, params, call) + } + + if err != nil { + return response, err + } + + if response.IsError { + return response, nil + } + + // Wait for LSP diagnostics and add them to the response + waitForLspDiagnostics(ctx, params.FilePath, m.lspClients) + text := fmt.Sprintf("\n%s\n\n", response.Content) + text += getDiagnostics(params.FilePath, m.lspClients) + response.Content = text + return response, nil +} + +func (m *multiEditTool) validateEdits(edits []MultiEditOperation) error { + for i, edit := range edits { + if edit.OldString == edit.NewString { + return fmt.Errorf("edit %d: old_string and new_string are identical", i+1) + } + // Only the first edit can have empty old_string (for file creation) + if i > 0 && edit.OldString == "" { + return fmt.Errorf("edit %d: only the first edit can have empty old_string (for file creation)", i+1) + } + } + return nil +} + +func (m *multiEditTool) processMultiEditWithCreation(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) { + // First edit creates the file + firstEdit := params.Edits[0] + if firstEdit.OldString != "" { + return NewTextErrorResponse("first edit must have empty old_string for file creation"), nil + } + + // Check if file already exists + if _, err := os.Stat(params.FilePath); err == nil { + return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", params.FilePath)), nil + } else if !os.IsNotExist(err) { + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) + } + + // Create parent directories + dir := filepath.Dir(params.FilePath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err) + } + + // Start with the content from the first edit + currentContent := firstEdit.NewString + + // Apply remaining edits to the content + for i := 1; i < len(params.Edits); i++ { + edit := params.Edits[i] + newContent, err := m.applyEditToContent(currentContent, edit) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil + } + currentContent = newContent + } + + // Get session and message IDs + sessionID, messageID := GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file") + } + + // Check permissions + _, additions, removals := diff.GenerateDiff("", currentContent, strings.TrimPrefix(params.FilePath, m.workingDir)) + rootDir := m.workingDir + permissionPath := filepath.Dir(params.FilePath) + if strings.HasPrefix(params.FilePath, rootDir) { + permissionPath = rootDir + } + + p := m.permissions.Request(permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: permissionPath, + ToolCallID: call.ID, + ToolName: MultiEditToolName, + Action: "write", + Description: fmt.Sprintf("Create file %s with %d edits", params.FilePath, len(params.Edits)), + Params: MultiEditPermissionsParams{ + FilePath: params.FilePath, + OldContent: "", + NewContent: currentContent, + }, + }) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + + // Write the file + err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) + } + + // Update file history + _, err = m.files.Create(ctx, sessionID, params.FilePath, "") + if err != nil { + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + + _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent) + if err != nil { + slog.Debug("Error creating file history version", "error", err) + } + + recordFileWrite(params.FilePath) + recordFileRead(params.FilePath) + + return WithResponseMetadata( + NewTextResponse(fmt.Sprintf("File created with %d edits: %s", len(params.Edits), params.FilePath)), + MultiEditResponseMetadata{ + OldContent: "", + NewContent: currentContent, + Additions: additions, + Removals: removals, + EditsApplied: len(params.Edits), + }, + ), nil +} + +func (m *multiEditTool) processMultiEditExistingFile(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) { + // Validate file exists and is readable + fileInfo, err := os.Stat(params.FilePath) + if err != nil { + if os.IsNotExist(err) { + return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil + } + return ToolResponse{}, fmt.Errorf("failed to access file: %w", err) + } + + if fileInfo.IsDir() { + return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil + } + + // Check if file was read before editing + if getLastReadTime(params.FilePath).IsZero() { + return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil + } + + // Check if file was modified since last read + modTime := fileInfo.ModTime() + lastRead := getLastReadTime(params.FilePath) + if modTime.After(lastRead) { + return NewTextErrorResponse( + fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", + params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339), + )), nil + } + + // Read current file content + content, err := os.ReadFile(params.FilePath) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to read file: %w", err) + } + + oldContent := string(content) + currentContent := oldContent + + // Apply all edits sequentially + for i, edit := range params.Edits { + newContent, err := m.applyEditToContent(currentContent, edit) + if err != nil { + return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil + } + currentContent = newContent + } + + // Check if content actually changed + if oldContent == currentContent { + return NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil + } + + // Get session and message IDs + sessionID, messageID := GetContextValues(ctx) + if sessionID == "" || messageID == "" { + return ToolResponse{}, fmt.Errorf("session ID and message ID are required for editing file") + } + + // Generate diff and check permissions + _, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, m.workingDir)) + rootDir := m.workingDir + permissionPath := filepath.Dir(params.FilePath) + if strings.HasPrefix(params.FilePath, rootDir) { + permissionPath = rootDir + } + + p := m.permissions.Request(permission.CreatePermissionRequest{ + SessionID: sessionID, + Path: permissionPath, + ToolCallID: call.ID, + ToolName: MultiEditToolName, + Action: "write", + Description: fmt.Sprintf("Apply %d edits to file %s", len(params.Edits), params.FilePath), + Params: MultiEditPermissionsParams{ + FilePath: params.FilePath, + OldContent: oldContent, + NewContent: currentContent, + }, + }) + if !p { + return ToolResponse{}, permission.ErrorPermissionDenied + } + + // Write the updated content + err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644) + if err != nil { + return ToolResponse{}, fmt.Errorf("failed to write file: %w", err) + } + + // Update file history + file, err := m.files.GetByPathAndSession(ctx, params.FilePath, sessionID) + if err != nil { + _, err = m.files.Create(ctx, sessionID, params.FilePath, oldContent) + if err != nil { + return ToolResponse{}, fmt.Errorf("error creating file history: %w", err) + } + } + if file.Content != oldContent { + // User manually changed the content, store an intermediate version + _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent) + if err != nil { + slog.Debug("Error creating file history version", "error", err) + } + } + + // Store the new version + _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent) + if err != nil { + slog.Debug("Error creating file history version", "error", err) + } + + recordFileWrite(params.FilePath) + recordFileRead(params.FilePath) + + return WithResponseMetadata( + NewTextResponse(fmt.Sprintf("Applied %d edits to file: %s", len(params.Edits), params.FilePath)), + MultiEditResponseMetadata{ + OldContent: oldContent, + NewContent: currentContent, + Additions: additions, + Removals: removals, + EditsApplied: len(params.Edits), + }, + ), nil +} + +func (m *multiEditTool) applyEditToContent(content string, edit MultiEditOperation) (string, error) { + if edit.OldString == "" && edit.NewString == "" { + return content, nil + } + + if edit.OldString == "" { + return "", fmt.Errorf("old_string cannot be empty for content replacement") + } + + var newContent string + var replacementCount int + + if edit.ReplaceAll { + newContent = strings.ReplaceAll(content, edit.OldString, edit.NewString) + replacementCount = strings.Count(content, edit.OldString) + if replacementCount == 0 { + return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks") + } + } else { + index := strings.Index(content, edit.OldString) + if index == -1 { + return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks") + } + + lastIndex := strings.LastIndex(content, edit.OldString) + if index != lastIndex { + return "", fmt.Errorf("old_string appears multiple times in the content. Please provide more context to ensure a unique match, or set replace_all to true") + } + + newContent = content[:index] + edit.NewString + content[index+len(edit.OldString):] + replacementCount = 1 + } + + return newContent, nil +} diff --git a/internal/llm/tools/write.go b/internal/llm/tools/write.go index 50f472bf2e65dba2b3c7e9efd9ecc88136764d2f..7d8d6f567955ae69f35bb4ac38d1d8331dd375a3 100644 --- a/internal/llm/tools/write.go +++ b/internal/llm/tools/write.go @@ -181,6 +181,7 @@ func (w *writeTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error permission.CreatePermissionRequest{ SessionID: sessionID, Path: permissionPath, + ToolCallID: call.ID, ToolName: WriteToolName, Action: "write", Description: fmt.Sprintf("Create file %s", filePath), diff --git a/internal/permission/permission.go b/internal/permission/permission.go index cd149a49890b54086bd52e562eed0d44f00c407e..77df812dd2379d1e6e9c6149d47018e7359d631f 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -1,7 +1,9 @@ package permission import ( + "context" "errors" + "log/slog" "path/filepath" "slices" "sync" @@ -14,6 +16,7 @@ var ErrorPermissionDenied = errors.New("permission denied") type CreatePermissionRequest struct { SessionID string `json:"session_id"` + ToolCallID string `json:"tool_call_id"` ToolName string `json:"tool_name"` Description string `json:"description"` Action string `json:"action"` @@ -21,9 +24,16 @@ type CreatePermissionRequest struct { Path string `json:"path"` } +type PermissionNotification struct { + ToolCallID string `json:"tool_call_id"` + Granted bool `json:"granted"` + Denied bool `json:"denied"` +} + type PermissionRequest struct { ID string `json:"id"` SessionID string `json:"session_id"` + ToolCallID string `json:"tool_call_id"` ToolName string `json:"tool_name"` Description string `json:"description"` Action string `json:"action"` @@ -38,22 +48,32 @@ type Service interface { Deny(permission PermissionRequest) Request(opts CreatePermissionRequest) bool AutoApproveSession(sessionID string) + SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification] } type permissionService struct { *pubsub.Broker[PermissionRequest] + notificationBroker *pubsub.Broker[PermissionNotification] workingDir string sessionPermissions []PermissionRequest sessionPermissionsMu sync.RWMutex pendingRequests sync.Map - autoApproveSessions []string + autoApproveSessions map[string]bool autoApproveSessionsMu sync.RWMutex skip bool allowedTools []string + + // used to make sure we only process one request at a time + requestMu sync.Mutex + activeRequest *PermissionRequest } func (s *permissionService) GrantPersistent(permission PermissionRequest) { + s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ + ToolCallID: permission.ToolCallID, + Granted: true, + }) respCh, ok := s.pendingRequests.Load(permission.ID) if ok { respCh.(chan bool) <- true @@ -62,20 +82,41 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { s.sessionPermissionsMu.Lock() s.sessionPermissions = append(s.sessionPermissions, permission) s.sessionPermissionsMu.Unlock() + + if s.activeRequest != nil && s.activeRequest.ID == permission.ID { + s.activeRequest = nil + } } func (s *permissionService) Grant(permission PermissionRequest) { + s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ + ToolCallID: permission.ToolCallID, + Granted: true, + }) respCh, ok := s.pendingRequests.Load(permission.ID) if ok { respCh.(chan bool) <- true } + + if s.activeRequest != nil && s.activeRequest.ID == permission.ID { + s.activeRequest = nil + } } func (s *permissionService) Deny(permission PermissionRequest) { + s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ + ToolCallID: permission.ToolCallID, + Granted: false, + Denied: true, + }) respCh, ok := s.pendingRequests.Load(permission.ID) if ok { respCh.(chan bool) <- false } + + if s.activeRequest != nil && s.activeRequest.ID == permission.ID { + s.activeRequest = nil + } } func (s *permissionService) Request(opts CreatePermissionRequest) bool { @@ -83,6 +124,13 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { return true } + // tell the UI that a permission was requested + s.notificationBroker.Publish(pubsub.CreatedEvent, PermissionNotification{ + ToolCallID: opts.ToolCallID, + }) + s.requestMu.Lock() + defer s.requestMu.Unlock() + // Check if the tool/action combination is in the allowlist commandKey := opts.ToolName + ":" + opts.Action if slices.Contains(s.allowedTools, commandKey) || slices.Contains(s.allowedTools, opts.ToolName) { @@ -90,7 +138,7 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { } s.autoApproveSessionsMu.RLock() - autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID) + autoApprove := s.autoApproveSessions[opts.SessionID] s.autoApproveSessionsMu.RUnlock() if autoApprove { @@ -101,10 +149,12 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { if dir == "." { dir = s.workingDir } + slog.Info("Requesting permission", "session_id", opts.SessionID, "tool_name", opts.ToolName, "action", opts.Action, "path", dir) permission := PermissionRequest{ ID: uuid.New().String(), Path: dir, SessionID: opts.SessionID, + ToolCallID: opts.ToolCallID, ToolName: opts.ToolName, Description: opts.Description, Action: opts.Action, @@ -120,29 +170,45 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { } s.sessionPermissionsMu.RUnlock() - respCh := make(chan bool, 1) + s.sessionPermissionsMu.RLock() + for _, p := range s.sessionPermissions { + if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path { + s.sessionPermissionsMu.RUnlock() + return true + } + } + s.sessionPermissionsMu.RUnlock() + + s.activeRequest = &permission + respCh := make(chan bool, 1) s.pendingRequests.Store(permission.ID, respCh) defer s.pendingRequests.Delete(permission.ID) + // Publish the request s.Publish(pubsub.CreatedEvent, permission) - // Wait for the response indefinitely return <-respCh } func (s *permissionService) AutoApproveSession(sessionID string) { s.autoApproveSessionsMu.Lock() - s.autoApproveSessions = append(s.autoApproveSessions, sessionID) + s.autoApproveSessions[sessionID] = true s.autoApproveSessionsMu.Unlock() } +func (s *permissionService) SubscribeNotifications(ctx context.Context) <-chan pubsub.Event[PermissionNotification] { + return s.notificationBroker.Subscribe(ctx) +} + func NewPermissionService(workingDir string, skip bool, allowedTools []string) Service { return &permissionService{ - Broker: pubsub.NewBroker[PermissionRequest](), - workingDir: workingDir, - sessionPermissions: make([]PermissionRequest, 0), - skip: skip, - allowedTools: allowedTools, + Broker: pubsub.NewBroker[PermissionRequest](), + notificationBroker: pubsub.NewBroker[PermissionNotification](), + workingDir: workingDir, + sessionPermissions: make([]PermissionRequest, 0), + autoApproveSessions: make(map[string]bool), + skip: skip, + allowedTools: allowedTools, } } diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go index 5d10fbd240da6a171e345938cb3382a7f7fcf19b..c3c646ecd97f51a0f91d8209e2a34c6855d6547b 100644 --- a/internal/permission/permission_test.go +++ b/internal/permission/permission_test.go @@ -1,7 +1,10 @@ package permission import ( + "sync" "testing" + + "github.com/stretchr/testify/assert" ) func TestPermissionService_AllowedCommands(t *testing.T) { @@ -90,3 +93,159 @@ func TestPermissionService_SkipMode(t *testing.T) { t.Error("expected permission to be granted in skip mode") } } + +func TestPermissionService_SequentialProperties(t *testing.T) { + t.Run("Sequential permission requests with persistent grants", func(t *testing.T) { + service := NewPermissionService("/tmp", false, []string{}) + + req1 := CreatePermissionRequest{ + SessionID: "session1", + ToolName: "file_tool", + Description: "Read file", + Action: "read", + Params: map[string]string{"file": "test.txt"}, + Path: "/tmp/test.txt", + } + + var result1 bool + var wg sync.WaitGroup + wg.Add(1) + + events := service.Subscribe(t.Context()) + + go func() { + defer wg.Done() + result1 = service.Request(req1) + }() + + var permissionReq PermissionRequest + event := <-events + + permissionReq = event.Payload + service.GrantPersistent(permissionReq) + + wg.Wait() + assert.True(t, result1, "First request should be granted") + + // Second identical request should be automatically approved due to persistent permission + req2 := CreatePermissionRequest{ + SessionID: "session1", + ToolName: "file_tool", + Description: "Read file again", + Action: "read", + Params: map[string]string{"file": "test.txt"}, + Path: "/tmp/test.txt", + } + result2 := service.Request(req2) + assert.True(t, result2, "Second request should be auto-approved") + }) + t.Run("Sequential requests with temporary grants", func(t *testing.T) { + service := NewPermissionService("/tmp", false, []string{}) + + req := CreatePermissionRequest{ + SessionID: "session2", + ToolName: "file_tool", + Description: "Write file", + Action: "write", + Params: map[string]string{"file": "test.txt"}, + Path: "/tmp/test.txt", + } + + events := service.Subscribe(t.Context()) + var result1 bool + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + result1 = service.Request(req) + }() + + var permissionReq PermissionRequest + event := <-events + permissionReq = event.Payload + + service.Grant(permissionReq) + wg.Wait() + assert.True(t, result1, "First request should be granted") + + var result2 bool + wg.Add(1) + + go func() { + defer wg.Done() + result2 = service.Request(req) + }() + + event = <-events + permissionReq = event.Payload + service.Deny(permissionReq) + wg.Wait() + assert.False(t, result2, "Second request should be denied") + }) + t.Run("Concurrent requests with different outcomes", func(t *testing.T) { + service := NewPermissionService("/tmp", false, []string{}) + + events := service.Subscribe(t.Context()) + + var wg sync.WaitGroup + results := make([]bool, 0) + + requests := []CreatePermissionRequest{ + { + SessionID: "concurrent1", + ToolName: "tool1", + Action: "action1", + Path: "/tmp/file1.txt", + Description: "First concurrent request", + }, + { + SessionID: "concurrent2", + ToolName: "tool2", + Action: "action2", + Path: "/tmp/file2.txt", + Description: "Second concurrent request", + }, + { + SessionID: "concurrent3", + ToolName: "tool3", + Action: "action3", + Path: "/tmp/file3.txt", + Description: "Third concurrent request", + }, + } + + for i, req := range requests { + wg.Add(1) + go func(index int, request CreatePermissionRequest) { + defer wg.Done() + results = append(results, service.Request(request)) + }(i, req) + } + + for range 3 { + event := <-events + switch event.Payload.ToolName { + case "tool1": + service.Grant(event.Payload) + case "tool2": + service.GrantPersistent(event.Payload) + case "tool3": + service.Deny(event.Payload) + } + } + wg.Wait() + grantedCount := 0 + for _, result := range results { + if result { + grantedCount++ + } + } + + assert.Equal(t, 2, grantedCount, "Should have 2 granted and 1 denied") + secondReq := requests[1] + secondReq.Description = "Repeat of second request" + result := service.Request(secondReq) + assert.True(t, result, "Repeated request should be auto-approved due to persistent permission") + }) +} diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index 90d117e64dec449d09f8ef301de661a1feefd22c..211808b88b1291ed2359dc137e14d1eeea8f2c14 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -2,6 +2,7 @@ package chat import ( "context" + "log/slog" "time" "github.com/charmbracelet/bubbles/v2/key" @@ -9,6 +10,7 @@ import ( "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/llm/agent" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/tui/components/chat/messages" @@ -85,6 +87,8 @@ func (m *messageListCmp) Init() tea.Cmd { // Update handles incoming messages and updates the component state. func (m *messageListCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg := msg.(type) { + case pubsub.Event[permission.PermissionNotification]: + return m, m.handlePermissionRequest(msg.Payload) case SessionSelectedMsg: if msg.ID != m.session.ID { cmd := m.SetSession(msg) @@ -124,6 +128,20 @@ func (m *messageListCmp) View() string { ) } +func (m *messageListCmp) handlePermissionRequest(permission permission.PermissionNotification) tea.Cmd { + items := m.listCmp.Items() + slog.Info("Handling permission request", "tool_call_id", permission.ToolCallID, "granted", permission.Granted) + if toolCallIndex := m.findToolCallByID(items, permission.ToolCallID); toolCallIndex != NotFound { + toolCall := items[toolCallIndex].(messages.ToolCallCmp) + toolCall.SetPermissionRequested() + if permission.Granted { + toolCall.SetPermissionGranted() + } + m.listCmp.UpdateItem(toolCall.ID(), toolCall) + } + return nil +} + // handleChildSession handles messages from child sessions (agent tools). func (m *messageListCmp) handleChildSession(event pubsub.Event[message.Message]) tea.Cmd { var cmds []tea.Cmd @@ -158,6 +176,7 @@ func (m *messageListCmp) handleChildSession(event pubsub.Event[message.Message]) nestedCall := messages.NewToolCallCmp( event.Payload.ID, tc, + m.app.Permissions, messages.WithToolCallNested(true), ) cmds = append(cmds, nestedCall.Init()) @@ -199,7 +218,12 @@ func (m *messageListCmp) handleMessageEvent(event pubsub.Event[message.Message]) if event.Payload.SessionID != m.session.ID { return m.handleChildSession(event) } - return m.handleUpdateAssistantMessage(event.Payload) + switch event.Payload.Role { + case message.Assistant: + return m.handleUpdateAssistantMessage(event.Payload) + case message.Tool: + return m.handleToolMessage(event.Payload) + } } return nil } @@ -371,7 +395,7 @@ func (m *messageListCmp) updateOrAddToolCall(msg message.Message, tc message.Too } // Add new tool call if not found - return m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc)) + return m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc, m.app.Permissions)) } // handleNewAssistantMessage processes new assistant messages and their tool calls. @@ -390,7 +414,7 @@ func (m *messageListCmp) handleNewAssistantMessage(msg message.Message) tea.Cmd // Add tool calls for _, tc := range msg.ToolCalls() { - cmd := m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc)) + cmd := m.listCmp.AppendItem(messages.NewToolCallCmp(msg.ID, tc, m.app.Permissions)) cmds = append(cmds, cmd) } @@ -473,7 +497,7 @@ func (m *messageListCmp) convertAssistantMessage(msg message.Message, toolResult // Add tool calls with their results and status for _, tc := range msg.ToolCalls() { options := m.buildToolCallOptions(tc, msg, toolResultMap) - uiMessages = append(uiMessages, messages.NewToolCallCmp(msg.ID, tc, options...)) + uiMessages = append(uiMessages, messages.NewToolCallCmp(msg.ID, tc, m.app.Permissions, options...)) // If this tool call is the agent tool, fetch nested tool calls if tc.Name == agent.AgentToolName { nestedMessages, _ := m.app.Messages.List(context.Background(), tc.ID) diff --git a/internal/tui/components/chat/messages/renderer.go b/internal/tui/components/chat/messages/renderer.go index ace42420a26a47854313029e48ca4b3f495525c4..a4abfc909dfa4db867e0232a97478eb1ebc04eda 100644 --- a/internal/tui/components/chat/messages/renderer.go +++ b/internal/tui/components/chat/messages/renderer.go @@ -166,6 +166,7 @@ func init() { registry.register(tools.DownloadToolName, func() renderer { return downloadRenderer{} }) registry.register(tools.ViewToolName, func() renderer { return viewRenderer{} }) registry.register(tools.EditToolName, func() renderer { return editRenderer{} }) + registry.register(tools.MultiEditToolName, func() renderer { return multiEditRenderer{} }) registry.register(tools.WriteToolName, func() renderer { return writeRenderer{} }) registry.register(tools.FetchToolName, func() renderer { return fetchRenderer{} }) registry.register(tools.GlobToolName, func() renderer { return globRenderer{} }) @@ -316,6 +317,57 @@ func (er editRenderer) Render(v *toolCallCmp) string { }) } +// ----------------------------------------------------------------------------- +// Multi-Edit renderer +// ----------------------------------------------------------------------------- + +// multiEditRenderer handles multiple file edits with diff visualization +type multiEditRenderer struct { + baseRenderer +} + +// Render displays the multi-edited file with a formatted diff of changes +func (mer multiEditRenderer) Render(v *toolCallCmp) string { + t := styles.CurrentTheme() + var params tools.MultiEditParams + var args []string + if err := mer.unmarshalParams(v.call.Input, ¶ms); err == nil { + file := fsext.PrettyPath(params.FilePath) + editsCount := len(params.Edits) + args = newParamBuilder(). + addMain(file). + addKeyValue("edits", fmt.Sprintf("%d", editsCount)). + build() + } + + return mer.renderWithParams(v, "Multi-Edit", args, func() string { + var meta tools.MultiEditResponseMetadata + if err := mer.unmarshalParams(v.result.Metadata, &meta); err != nil { + return renderPlainContent(v, v.result.Content) + } + + formatter := core.DiffFormatter(). + Before(fsext.PrettyPath(params.FilePath), meta.OldContent). + After(fsext.PrettyPath(params.FilePath), meta.NewContent). + Width(v.textWidth() - 2) // -2 for padding + if v.textWidth() > 120 { + formatter = formatter.Split() + } + // add a message to the bottom if the content was truncated + formatted := formatter.String() + if lipgloss.Height(formatted) > responseContextHeight { + contentLines := strings.Split(formatted, "\n") + truncateMessage := t.S().Muted. + Background(t.BgBaseLighter). + PaddingLeft(2). + Width(v.textWidth() - 4). + Render(fmt.Sprintf("… (%d lines)", len(contentLines)-responseContextHeight)) + formatted = strings.Join(contentLines[:responseContextHeight], "\n") + "\n" + truncateMessage + } + return formatted + }) +} + // ----------------------------------------------------------------------------- // Write renderer // ----------------------------------------------------------------------------- @@ -672,7 +724,11 @@ func earlyState(header string, v *toolCallCmp) (string, bool) { case v.cancelled: message = t.S().Base.Foreground(t.FgSubtle).Render("Canceled.") case v.result.ToolCallID == "": - message = t.S().Base.Foreground(t.FgSubtle).Render("Waiting for tool to start...") + if v.permissionRequested && !v.permissionGranted { + message = t.S().Base.Foreground(t.FgSubtle).Render("Requesting for permission...") + } else { + message = t.S().Base.Foreground(t.FgSubtle).Render("Waiting for tool response...") + } default: return "", false } @@ -799,6 +855,8 @@ func prettifyToolName(name string) string { return "Download" case tools.EditToolName: return "Edit" + case tools.MultiEditToolName: + return "Multi-Edit" case tools.FetchToolName: return "Fetch" case tools.GlobToolName: diff --git a/internal/tui/components/chat/messages/tool.go b/internal/tui/components/chat/messages/tool.go index 2f639c5c5d192ba9c59402976e552462d8ebcd0b..51375c1b1bb11956069a733ffa509aae112eb073 100644 --- a/internal/tui/components/chat/messages/tool.go +++ b/internal/tui/components/chat/messages/tool.go @@ -5,6 +5,7 @@ import ( tea "github.com/charmbracelet/bubbletea/v2" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/tui/components/anim" "github.com/charmbracelet/crush/internal/tui/components/core/layout" "github.com/charmbracelet/crush/internal/tui/styles" @@ -30,6 +31,8 @@ type ToolCallCmp interface { SetNestedToolCalls([]ToolCallCmp) // Set nested tool calls SetIsNested(bool) // Set whether this tool call is nested ID() string + SetPermissionRequested() // Mark permission request + SetPermissionGranted() // Mark permission granted } // toolCallCmp implements the ToolCallCmp interface for displaying tool calls. @@ -40,10 +43,12 @@ type toolCallCmp struct { isNested bool // Whether this tool call is nested within another // Tool call data and state - parentMessageID string // ID of the message that initiated this tool call - call message.ToolCall // The tool call being executed - result message.ToolResult // The result of the tool execution - cancelled bool // Whether the tool call was cancelled + parentMessageID string // ID of the message that initiated this tool call + call message.ToolCall // The tool call being executed + result message.ToolResult // The result of the tool execution + cancelled bool // Whether the tool call was cancelled + permissionRequested bool + permissionGranted bool // Animation state for pending tool calls spinning bool // Whether to show loading animation @@ -81,9 +86,21 @@ func WithToolCallNestedCalls(calls []ToolCallCmp) ToolCallOption { } } +func WithToolPermissionRequested() ToolCallOption { + return func(m *toolCallCmp) { + m.permissionRequested = true + } +} + +func WithToolPermissionGranted() ToolCallOption { + return func(m *toolCallCmp) { + m.permissionGranted = true + } +} + // NewToolCallCmp creates a new tool call component with the given parent message ID, // tool call, and optional configuration -func NewToolCallCmp(parentMessageID string, tc message.ToolCall, opts ...ToolCallOption) ToolCallCmp { +func NewToolCallCmp(parentMessageID string, tc message.ToolCall, permissions permission.Service, opts ...ToolCallOption) ToolCallCmp { m := &toolCallCmp{ call: tc, parentMessageID: parentMessageID, @@ -316,3 +333,13 @@ func (m *toolCallCmp) Spinning() bool { func (m *toolCallCmp) ID() string { return m.call.ID } + +// SetPermissionRequested marks that a permission request was made for this tool call +func (m *toolCallCmp) SetPermissionRequested() { + m.permissionRequested = true +} + +// SetPermissionGranted marks that permission was granted for this tool call +func (m *toolCallCmp) SetPermissionGranted() { + m.permissionGranted = true +} diff --git a/internal/tui/components/dialogs/permissions/permissions.go b/internal/tui/components/dialogs/permissions/permissions.go index 1b41094c9c69ba91bbbefdf86e7040cd77d3ce8e..2e7a04dc7416baf6fdfec90ab56d61f60dad81f1 100644 --- a/internal/tui/components/dialogs/permissions/permissions.go +++ b/internal/tui/components/dialogs/permissions/permissions.go @@ -84,7 +84,7 @@ func (p *permissionDialogCmp) Init() tea.Cmd { } func (p *permissionDialogCmp) supportsDiffView() bool { - return p.permission.ToolName == tools.EditToolName || p.permission.ToolName == tools.WriteToolName + return p.permission.ToolName == tools.EditToolName || p.permission.ToolName == tools.WriteToolName || p.permission.ToolName == tools.MultiEditToolName } func (p *permissionDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -305,6 +305,20 @@ func (p *permissionDialogCmp) renderHeader() string { ), baseStyle.Render(strings.Repeat(" ", p.width)), ) + case tools.MultiEditToolName: + params := p.permission.Params.(tools.MultiEditPermissionsParams) + fileKey := t.S().Muted.Render("File") + filePath := t.S().Text. + Width(p.width - lipgloss.Width(fileKey)). + Render(fmt.Sprintf(" %s", fsext.PrettyPath(params.FilePath))) + headerParts = append(headerParts, + lipgloss.JoinHorizontal( + lipgloss.Left, + fileKey, + filePath, + ), + baseStyle.Render(strings.Repeat(" ", p.width)), + ) case tools.FetchToolName: headerParts = append(headerParts, t.S().Muted.Width(p.width).Bold(true).Render("URL")) } @@ -329,6 +343,8 @@ func (p *permissionDialogCmp) getOrGenerateContent() string { content = p.generateEditContent() case tools.WriteToolName: content = p.generateWriteContent() + case tools.MultiEditToolName: + content = p.generateMultiEditContent() case tools.FetchToolName: content = p.generateFetchContent() default: @@ -435,6 +451,28 @@ func (p *permissionDialogCmp) generateDownloadContent() string { return "" } +func (p *permissionDialogCmp) generateMultiEditContent() string { + if pr, ok := p.permission.Params.(tools.MultiEditPermissionsParams); ok { + // Use the cache for diff rendering + formatter := core.DiffFormatter(). + Before(fsext.PrettyPath(pr.FilePath), pr.OldContent). + After(fsext.PrettyPath(pr.FilePath), pr.NewContent). + Height(p.contentViewPort.Height()). + Width(p.contentViewPort.Width()). + XOffset(p.diffXOffset). + YOffset(p.diffYOffset) + if p.useDiffSplitMode() { + formatter = formatter.Split() + } else { + formatter = formatter.Unified() + } + + diff := formatter.String() + return diff + } + return "" +} + func (p *permissionDialogCmp) generateFetchContent() string { t := styles.CurrentTheme() baseStyle := t.S().Base.Background(t.BgSubtle) @@ -579,6 +617,9 @@ func (p *permissionDialogCmp) SetSize() tea.Cmd { case tools.WriteToolName: p.width = int(float64(p.wWidth) * 0.8) p.height = int(float64(p.wHeight) * 0.8) + case tools.MultiEditToolName: + p.width = int(float64(p.wWidth) * 0.8) + p.height = int(float64(p.wHeight) * 0.8) case tools.FetchToolName: p.width = int(float64(p.wWidth) * 0.8) p.height = int(float64(p.wHeight) * 0.3) diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 073ac869bb5f3916e5eccbb37da135c0b012f251..e8e9e97bce5f98cef91886c05df6988a4561825c 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -12,6 +12,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/tui/components/anim" @@ -251,6 +252,11 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.sidebar = u.(sidebar.Sidebar) cmds = append(cmds, cmd) return p, tea.Batch(cmds...) + case pubsub.Event[permission.PermissionNotification]: + u, cmd := p.chat.Update(msg) + p.chat = u.(chat.MessageListCmp) + cmds = append(cmds, cmd) + return p, tea.Batch(cmds...) case commands.CommandRunCustomMsg: if p.app.CoderAgent.IsBusy() { diff --git a/internal/tui/tui.go b/internal/tui/tui.go index c4c88199de49fd9145dcf21fc78d452b8de14e9a..af89ce9bfa2165caa25c2671f4ba8096e11bd4f9 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -205,6 +205,11 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { Model: filepicker.NewFilePickerCmp(a.app.Config().WorkingDir()), }) // Permissions + case pubsub.Event[permission.PermissionNotification]: + // forward to page + updated, cmd := a.pages[a.currentPage].Update(msg) + a.pages[a.currentPage] = updated.(util.Model) + return a, cmd case pubsub.Event[permission.PermissionRequest]: return a, util.CmdHandler(dialogs.OpenDialogMsg{ Model: permissions.NewPermissionDialogCmp(msg.Payload),