1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/fsext"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/proto"
 19
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/permission"
 22)
 23
 24type (
 25	EditParams            = proto.EditParams
 26	EditPermissionsParams = proto.EditPermissionsParams
 27	EditResponseMetadata  = proto.EditResponseMetadata
 28)
 29
// editTool implements the edit tool: it creates new files, deletes text from
// existing files, and replaces text in existing files.
type editTool struct {
	// lspClients is passed to notifyLSPs and getDiagnostics after a
	// successful edit.
	lspClients  *csync.Map[string, *lsp.Client]
	// permissions is asked (via Request) to approve every write before the
	// file is touched.
	permissions permission.Service
	// files records per-session file history (Create / CreateVersion).
	files       history.Service
	// workingDir resolves relative file paths and is trimmed from paths shown
	// in diffs and permission requests.
	workingDir  string
}
 36
// EditToolName is the name reported by the edit tool's Name and Info
// methods; it aliases proto.EditToolName.
const EditToolName = proto.EditToolName
 38
// editDescription is the tool description text, embedded at build time from
// edit.md and returned as ToolInfo.Description.
//
//go:embed edit.md
var editDescription []byte
 41
 42func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 43	return &editTool{
 44		lspClients:  lspClients,
 45		permissions: permissions,
 46		files:       files,
 47		workingDir:  workingDir,
 48	}
 49}
 50
 51func (e *editTool) Name() string {
 52	return EditToolName
 53}
 54
 55func (e *editTool) Info() ToolInfo {
 56	return ToolInfo{
 57		Name:        EditToolName,
 58		Description: string(editDescription),
 59		Parameters: map[string]any{
 60			"file_path": map[string]any{
 61				"type":        "string",
 62				"description": "The absolute path to the file to modify",
 63			},
 64			"old_string": map[string]any{
 65				"type":        "string",
 66				"description": "The text to replace",
 67			},
 68			"new_string": map[string]any{
 69				"type":        "string",
 70				"description": "The text to replace it with",
 71			},
 72			"replace_all": map[string]any{
 73				"type":        "boolean",
 74				"description": "Replace all occurrences of old_string (default false)",
 75			},
 76		},
 77		Required: []string{"file_path", "old_string", "new_string"},
 78	}
 79}
 80
 81func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 82	var params EditParams
 83	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 84		return NewTextErrorResponse("invalid parameters"), nil
 85	}
 86
 87	if params.FilePath == "" {
 88		return NewTextErrorResponse("file_path is required"), nil
 89	}
 90
 91	if !filepath.IsAbs(params.FilePath) {
 92		params.FilePath = filepath.Join(e.workingDir, params.FilePath)
 93	}
 94
 95	var response ToolResponse
 96	var err error
 97
 98	if params.OldString == "" {
 99		response, err = e.createNewFile(ctx, params.FilePath, params.NewString, call)
100		if err != nil {
101			return response, err
102		}
103	}
104
105	if params.NewString == "" {
106		response, err = e.deleteContent(ctx, params.FilePath, params.OldString, params.ReplaceAll, call)
107		if err != nil {
108			return response, err
109		}
110	}
111
112	response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString, params.ReplaceAll, call)
113	if err != nil {
114		return response, err
115	}
116	if response.IsError {
117		// Return early if there was an error during content replacement
118		// This prevents unnecessary LSP diagnostics processing
119		return response, nil
120	}
121
122	notifyLSPs(ctx, e.lspClients, params.FilePath)
123
124	text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
125	text += getDiagnostics(params.FilePath, e.lspClients)
126	response.Content = text
127	return response, nil
128}
129
130func (e *editTool) createNewFile(ctx context.Context, filePath, content string, call ToolCall) (ToolResponse, error) {
131	fileInfo, err := os.Stat(filePath)
132	if err == nil {
133		if fileInfo.IsDir() {
134			return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
135		}
136		return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil
137	} else if !os.IsNotExist(err) {
138		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
139	}
140
141	dir := filepath.Dir(filePath)
142	if err = os.MkdirAll(dir, 0o755); err != nil {
143		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
144	}
145
146	sessionID, messageID := GetContextValues(ctx)
147	if sessionID == "" || messageID == "" {
148		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
149	}
150
151	_, additions, removals := diff.GenerateDiff(
152		"",
153		content,
154		strings.TrimPrefix(filePath, e.workingDir),
155	)
156	p := e.permissions.Request(
157		permission.CreatePermissionRequest{
158			SessionID:   sessionID,
159			Path:        fsext.PathOrPrefix(filePath, e.workingDir),
160			ToolCallID:  call.ID,
161			ToolName:    EditToolName,
162			Action:      "write",
163			Description: fmt.Sprintf("Create file %s", filePath),
164			Params: EditPermissionsParams{
165				FilePath:   filePath,
166				OldContent: "",
167				NewContent: content,
168			},
169		},
170	)
171	if !p {
172		return ToolResponse{}, permission.ErrorPermissionDenied
173	}
174
175	err = os.WriteFile(filePath, []byte(content), 0o644)
176	if err != nil {
177		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
178	}
179
180	// File can't be in the history so we create a new file history
181	_, err = e.files.Create(ctx, sessionID, filePath, "")
182	if err != nil {
183		// Log error but don't fail the operation
184		return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
185	}
186
187	// Add the new content to the file history
188	_, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
189	if err != nil {
190		// Log error but don't fail the operation
191		slog.Debug("Error creating file history version", "error", err)
192	}
193
194	recordFileWrite(filePath)
195	recordFileRead(filePath)
196
197	return WithResponseMetadata(
198		NewTextResponse("File created: "+filePath),
199		EditResponseMetadata{
200			OldContent: "",
201			NewContent: content,
202			Additions:  additions,
203			Removals:   removals,
204		},
205	), nil
206}
207
208func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
209	fileInfo, err := os.Stat(filePath)
210	if err != nil {
211		if os.IsNotExist(err) {
212			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
213		}
214		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
215	}
216
217	if fileInfo.IsDir() {
218		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
219	}
220
221	if getLastReadTime(filePath).IsZero() {
222		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
223	}
224
225	modTime := fileInfo.ModTime()
226	lastRead := getLastReadTime(filePath)
227	if modTime.After(lastRead) {
228		return NewTextErrorResponse(
229			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
230				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
231			)), nil
232	}
233
234	content, err := os.ReadFile(filePath)
235	if err != nil {
236		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
237	}
238
239	oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
240
241	var newContent string
242	var deletionCount int
243
244	if replaceAll {
245		newContent = strings.ReplaceAll(oldContent, oldString, "")
246		deletionCount = strings.Count(oldContent, oldString)
247		if deletionCount == 0 {
248			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
249		}
250	} else {
251		index := strings.Index(oldContent, oldString)
252		if index == -1 {
253			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
254		}
255
256		lastIndex := strings.LastIndex(oldContent, oldString)
257		if index != lastIndex {
258			return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
259		}
260
261		newContent = oldContent[:index] + oldContent[index+len(oldString):]
262		deletionCount = 1
263	}
264
265	sessionID, messageID := GetContextValues(ctx)
266
267	if sessionID == "" || messageID == "" {
268		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
269	}
270
271	_, additions, removals := diff.GenerateDiff(
272		oldContent,
273		newContent,
274		strings.TrimPrefix(filePath, e.workingDir),
275	)
276
277	p := e.permissions.Request(
278		permission.CreatePermissionRequest{
279			SessionID:   sessionID,
280			Path:        fsext.PathOrPrefix(filePath, e.workingDir),
281			ToolCallID:  call.ID,
282			ToolName:    EditToolName,
283			Action:      "write",
284			Description: fmt.Sprintf("Delete content from file %s", filePath),
285			Params: EditPermissionsParams{
286				FilePath:   filePath,
287				OldContent: oldContent,
288				NewContent: newContent,
289			},
290		},
291	)
292	if !p {
293		return ToolResponse{}, permission.ErrorPermissionDenied
294	}
295
296	if isCrlf {
297		newContent, _ = fsext.ToWindowsLineEndings(newContent)
298	}
299
300	err = os.WriteFile(filePath, []byte(newContent), 0o644)
301	if err != nil {
302		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
303	}
304
305	// Check if file exists in history
306	file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
307	if err != nil {
308		_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
309		if err != nil {
310			// Log error but don't fail the operation
311			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
312		}
313	}
314	if file.Content != oldContent {
315		// User Manually changed the content store an intermediate version
316		_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
317		if err != nil {
318			slog.Debug("Error creating file history version", "error", err)
319		}
320	}
321	// Store the new version
322	_, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
323	if err != nil {
324		slog.Debug("Error creating file history version", "error", err)
325	}
326
327	recordFileWrite(filePath)
328	recordFileRead(filePath)
329
330	return WithResponseMetadata(
331		NewTextResponse("Content deleted from file: "+filePath),
332		EditResponseMetadata{
333			OldContent: oldContent,
334			NewContent: newContent,
335			Additions:  additions,
336			Removals:   removals,
337		},
338	), nil
339}
340
341func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
342	fileInfo, err := os.Stat(filePath)
343	if err != nil {
344		if os.IsNotExist(err) {
345			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
346		}
347		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
348	}
349
350	if fileInfo.IsDir() {
351		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
352	}
353
354	if getLastReadTime(filePath).IsZero() {
355		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
356	}
357
358	modTime := fileInfo.ModTime()
359	lastRead := getLastReadTime(filePath)
360	if modTime.After(lastRead) {
361		return NewTextErrorResponse(
362			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
363				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
364			)), nil
365	}
366
367	content, err := os.ReadFile(filePath)
368	if err != nil {
369		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
370	}
371
372	oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
373
374	var newContent string
375	var replacementCount int
376
377	if replaceAll {
378		newContent = strings.ReplaceAll(oldContent, oldString, newString)
379		replacementCount = strings.Count(oldContent, oldString)
380		if replacementCount == 0 {
381			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
382		}
383	} else {
384		index := strings.Index(oldContent, oldString)
385		if index == -1 {
386			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
387		}
388
389		lastIndex := strings.LastIndex(oldContent, oldString)
390		if index != lastIndex {
391			return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
392		}
393
394		newContent = oldContent[:index] + newString + oldContent[index+len(oldString):]
395		replacementCount = 1
396	}
397
398	if oldContent == newContent {
399		return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
400	}
401	sessionID, messageID := GetContextValues(ctx)
402
403	if sessionID == "" || messageID == "" {
404		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
405	}
406	_, additions, removals := diff.GenerateDiff(
407		oldContent,
408		newContent,
409		strings.TrimPrefix(filePath, e.workingDir),
410	)
411
412	p := e.permissions.Request(
413		permission.CreatePermissionRequest{
414			SessionID:   sessionID,
415			Path:        fsext.PathOrPrefix(filePath, e.workingDir),
416			ToolCallID:  call.ID,
417			ToolName:    EditToolName,
418			Action:      "write",
419			Description: fmt.Sprintf("Replace content in file %s", filePath),
420			Params: EditPermissionsParams{
421				FilePath:   filePath,
422				OldContent: oldContent,
423				NewContent: newContent,
424			},
425		},
426	)
427	if !p {
428		return ToolResponse{}, permission.ErrorPermissionDenied
429	}
430
431	if isCrlf {
432		newContent, _ = fsext.ToWindowsLineEndings(newContent)
433	}
434
435	err = os.WriteFile(filePath, []byte(newContent), 0o644)
436	if err != nil {
437		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
438	}
439
440	// Check if file exists in history
441	file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
442	if err != nil {
443		_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
444		if err != nil {
445			// Log error but don't fail the operation
446			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
447		}
448	}
449	if file.Content != oldContent {
450		// User Manually changed the content store an intermediate version
451		_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
452		if err != nil {
453			slog.Debug("Error creating file history version", "error", err)
454		}
455	}
456	// Store the new version
457	_, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
458	if err != nil {
459		slog.Debug("Error creating file history version", "error", err)
460	}
461
462	recordFileWrite(filePath)
463	recordFileRead(filePath)
464
465	return WithResponseMetadata(
466		NewTextResponse("Content replaced in file: "+filePath),
467		EditResponseMetadata{
468			OldContent: oldContent,
469			NewContent: newContent,
470			Additions:  additions,
471			Removals:   removals,
472		}), nil
473}